Commit 5eddda8b authored by Myle Ott's avatar Myle Ott
Browse files

Save dictionary in model base classes

parent 08a74a32
......@@ -13,8 +13,9 @@ import torch.nn.functional as F
class FairseqDecoder(nn.Module):
    """Base class for decoders.

    Saves the target-side dictionary on the instance so every decoder
    subclass can reach it uniformly via ``self.dictionary``.
    """

    def __init__(self, dictionary):
        """
        Args:
            dictionary: target-side vocabulary object; stored as
                ``self.dictionary`` for use by subclasses.
        """
        super().__init__()
        self.dictionary = dictionary
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
......
......@@ -12,8 +12,9 @@ import torch.nn as nn
class FairseqEncoder(nn.Module):
    """Base class for encoders.

    Saves the source-side dictionary on the instance so every encoder
    subclass can reach it uniformly via ``self.dictionary``.
    """

    def __init__(self, dictionary):
        """
        Args:
            dictionary: source-side vocabulary object; stored as
                ``self.dictionary`` for use by subclasses.
        """
        super().__init__()
        self.dictionary = dictionary
def max_positions(self):
"""Maximum input length supported by the encoder."""
......
......@@ -12,8 +12,8 @@ from . import FairseqDecoder
class FairseqIncrementalDecoder(FairseqDecoder):
    """Base class for incremental decoders.

    Extends :class:`FairseqDecoder` with state used for step-by-step
    (incremental) generation.
    """

    def __init__(self, dictionary):
        """
        Args:
            dictionary: target-side vocabulary; forwarded to
                :class:`FairseqDecoder`, which stores it.
        """
        super().__init__(dictionary)
        # Flag toggled when the decoder enters/leaves incremental
        # evaluation mode (managed by methods not visible here).
        self._is_incremental_eval = False
        # Cache of per-step state reused across incremental decoding steps.
        self._incremental_state = {}
......
......@@ -28,8 +28,7 @@ class FConvEncoder(FairseqEncoder):
"""Convolutional encoder"""
def __init__(self, dictionary, embed_dim=512, max_positions=1024,
convolutions=((512, 3),) * 20, dropout=0.1):
super().__init__()
self.dictionary = dictionary
super().__init__(dictionary)
self.dropout = dropout
self.num_attention_layers = None
......@@ -137,9 +136,8 @@ class FConvDecoder(FairseqIncrementalDecoder):
def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
max_positions=1024, convolutions=((512, 3),) * 20,
attention=True, dropout=0.1):
super().__init__()
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([2]))
self.dictionary = dictionary
self.dropout = dropout
in_channels = convolutions[0][0]
......
......@@ -23,8 +23,7 @@ class LSTMEncoder(FairseqEncoder):
"""LSTM encoder."""
def __init__(self, dictionary, embed_dim=512, num_layers=1, dropout_in=0.1,
dropout_out=0.1):
super().__init__()
self.dictionary = dictionary
super().__init__(dictionary)
self.dropout_in = dropout_in
self.dropout_out = dropout_out
......@@ -108,8 +107,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
def __init__(self, dictionary, encoder_embed_dim=512, embed_dim=512,
out_embed_dim=512, num_layers=1, dropout_in=0.1,
dropout_out=0.1, attention=True):
super().__init__()
self.dictionary = dictionary
super().__init__(dictionary)
self.dropout_in = dropout_in
self.dropout_out = dropout_out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment