Unverified Commit fe4e185a authored by Myle Ott, committed by GitHub

Merge internal changes

Changelog:
- `f472d141`: Support tied embeddings in LSTM encoder/decoder
- `89e19d42`: Don't print alignment by default (use `--print-alignment` to re-enable it)
- `d2e2a1d4`: Add Transformer-based language model
- `c2794070`: Add new Transformer configuration for IWSLT
- `2fbfda0d`: Misc changes for pytorch-translate
- Miscellaneous bug fixes
parents 7358296b 2fbfda0d
@@ -37,11 +37,13 @@ def main(args):
         if args.fp16:
             model.half()
 
+    assert len(models) > 0
+
     itr = data.EpochBatchIterator(
         dataset=task.dataset(args.gen_subset),
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences or 4,
-        max_positions=model.max_positions(),
+        max_tokens=args.max_tokens or 36000,
+        max_sentences=args.max_sentences,
+        max_positions=models[0].max_positions(),
         num_shards=args.num_shards,
         shard_id=args.shard_id,
         ignore_invalid_inputs=True,
@@ -54,19 +56,51 @@ def main(args):
     score_sum = 0.
     count = 0
+
+    if args.remove_bpe is not None:
+        bpe_cont = args.remove_bpe.rstrip()
+        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
+        bpe_len = len(bpe_cont)
+    else:
+        bpe_toks = None
+        bpe_len = 0
+
     with progress_bar.build_progress_bar(args, itr) as t:
         results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
         wps_meter = TimeMeter()
         for _, src_tokens, __, hypos in results:
             for hypo in hypos:
                 pos_scores = hypo['positional_scores']
+
+                skipped_toks = 0
+                if bpe_toks is not None:
+                    for i in range(len(hypo['tokens']) - 1):
+                        if hypo['tokens'][i].item() in bpe_toks:
+                            skipped_toks += 1
+                            pos_scores[i + 1] += pos_scores[i]
+                            pos_scores[i] = 0
+
                 inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                 if inf_scores.any():
                     print('| Skipping tokens with inf scores:',
                           task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                     pos_scores = pos_scores[(~inf_scores).nonzero()]
                 score_sum += pos_scores.sum()
-                count += pos_scores.numel()
+                count += pos_scores.numel() - skipped_toks
+
+                if args.output_word_probs:
+                    w = ''
+                    word_prob = []
+                    for i in range(len(hypo['tokens'])):
+                        w_ind = hypo['tokens'][i].item()
+                        w += task.dictionary[w_ind]
+                        if bpe_toks is not None and w_ind in bpe_toks:
+                            w = w[:-bpe_len]
+                        else:
+                            word_prob.append((w, pos_scores[i].item()))
+                            w = ''
+                    print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
+
             wps_meter.update(src_tokens.size(0))
             t.log({'wps': round(wps_meter.avg)})
...
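The loop added above folds the score of each BPE continuation token into the token that follows it, so that eval_lm reports word-level rather than subword-level statistics when `--remove-bpe` is set. A minimal standalone sketch of that accumulation, assuming subword-nmt style segmentation where continuation pieces end in `@@` (the token list and scores below are made-up illustrations):

```python
import torch

def merge_bpe_scores(pieces, pos_scores, bpe_cont='@@'):
    """Fold each continuation piece's log-prob into the next piece,
    mirroring the loop added to eval_lm.py above (a sketch, not the real code path)."""
    pos_scores = pos_scores.clone()
    skipped = 0
    for i in range(len(pieces) - 1):
        if pieces[i].endswith(bpe_cont):
            skipped += 1
            pos_scores[i + 1] += pos_scores[i]
            pos_scores[i] = 0
    # word-level count excludes the merged continuation pieces
    return pos_scores, len(pieces) - skipped

pieces = ['new@@', 'found@@', 'land', 'is', 'cold']      # hypothetical segmentation
scores = torch.tensor([-1.2, -0.7, -0.3, -0.9, -1.5])    # hypothetical log-probs
merged, count = merge_bpe_scores(pieces, scores)
print(merged.sum().item() / count)                       # per-word average log-prob
```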
@@ -36,6 +36,31 @@ $ python generate.py data-bin/iwslt14.tokenized.de-en \
 ```
 
+To train a Transformer model on IWSLT'14 German to English:
+```
+# Preparation steps are the same as for the fconv model.
+
+# Train the model (better for a single GPU setup):
+$ mkdir -p checkpoints/transformer
+$ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
+  -a transformer_iwslt_de_en --optimizer adam --lr 0.0005 -s de -t en \
+  --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 \
+  --min-lr '1e-09' --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
+  --criterion label_smoothed_cross_entropy --max-update 50000 \
+  --warmup-updates 4000 --warmup-init-lr '1e-07' \
+  --adam-betas '(0.9, 0.98)' --save-dir checkpoints/transformer
+
+# Average the 10 latest checkpoints:
+$ python scripts/average_checkpoints.py --inputs checkpoints/transformer \
+  --num-epoch-checkpoints 10 --output checkpoints/transformer/model.pt
+
+# Generate:
+$ python generate.py data-bin/iwslt14.tokenized.de-en \
+  --path checkpoints/transformer/model.pt \
+  --batch-size 128 --beam 5 --remove-bpe
+```
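The averaging step above just takes the element-wise mean of the parameters stored in the last few epoch checkpoints; `scripts/average_checkpoints.py` adds argument handling and copies the rest of the checkpoint metadata, but the core operation is a short loop. A sketch under the assumption that each checkpoint stores its parameters under the `'model'` key (the file names below are hypothetical):

```python
import torch

def average_model_params(paths):
    """Element-wise mean of the 'model' state dicts across checkpoints (sketch)."""
    avg = None
    for path in paths:
        state = torch.load(path, map_location='cpu')['model']
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / len(paths) for k, v in avg.items()}

# hypothetical checkpoint names for the last 10 epochs
paths = ['checkpoints/transformer/checkpoint{}.pt'.format(i) for i in range(41, 51)]
# averaged = average_model_params(paths)  # then written back into a full checkpoint
```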
 ### prepare-wmt14en2de.sh
...
@@ -7,7 +7,7 @@
 from .dictionary import Dictionary
 from .fairseq_dataset import FairseqDataset
-from .indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
+from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset  # noqa: F401
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .token_block_dataset import TokenBlockDataset
...
@@ -47,7 +47,7 @@ class TokenBlockDataset(torch.utils.data.Dataset):
             self.slice_indices = [block_at(i) for i in range(length)]
         elif break_mode == 'complete':
-            assert sizes is not None and sum(sizes) == len(tokens)
+            assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
             tok_idx = 0
             sz_idx = 0
             curr_size = 0
@@ -62,7 +62,7 @@ class TokenBlockDataset(torch.utils.data.Dataset):
             if curr_size > 0:
                 self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
-            assert sizes is not None and sum(sizes) == len(tokens)
+            assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
             curr = 0
             for sz in sizes:
                 # skip samples with just 1 example (which would be just the eos token)
@@ -76,14 +76,18 @@ class TokenBlockDataset(torch.utils.data.Dataset):
     def __getitem__(self, index):
         s, e = self.slice_indices[index]
         item = torch.LongTensor(self.tokens[s:e])
         if self.include_targets:
-            if e == self.total_size:
-                return item[:-1], item[1:]
-            else:
-                return item, torch.LongTensor(self.tokens[s + 1:e + 1])
-        else:
-            return item
+            # target is the sentence, for source, rotate item one token to the left (would start with eos)
+            if s == 0:
+                source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
+            else:
+                source = self.tokens[s - 1:e - 1]
+
+            return torch.LongTensor(source), item
+        return item
 
     def __len__(self):
         return len(self.slice_indices)
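The new `__getitem__` builds language-model training pairs by rotating the block one position to the left: the target is the block itself and the source is the same span shifted back by one token (starting from the stream's final eos for the very first block). A toy sketch of the same indexing on a flat token array (the data and the `lm_pair` helper are made up for illustration):

```python
import numpy as np
import torch

tokens = np.array([2, 11, 12, 13, 2, 21, 22, 2])  # toy token stream; 2 = eos

def lm_pair(tokens, s, e):
    """Return (source, target) for the block tokens[s:e], as in the new __getitem__."""
    target = torch.LongTensor(tokens[s:e])
    if s == 0:
        # first block: the source starts with the last token of the stream (an eos)
        source = np.concatenate([tokens[-1:], tokens[0:e - 1]])
    else:
        source = tokens[s - 1:e - 1]
    return torch.LongTensor(source), target

src, tgt = lm_pair(tokens, 0, 4)
print(src.tolist(), tgt.tolist())   # [2, 2, 11, 12] [2, 11, 12, 13]
```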
@@ -19,8 +19,14 @@ class FairseqDecoder(nn.Module):
     def forward(self, prev_output_tokens, encoder_out):
         raise NotImplementedError
 
-    def get_normalized_probs(self, net_output, log_probs, _):
+    def get_normalized_probs(self, net_output, log_probs, sample):
         """Get normalized probabilities (or log probs) from a net's output."""
+        if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
+            assert sample is not None and 'target' in sample
+            out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
+            return out.exp_() if not log_probs else out
+
         logits = net_output[0].float()
         if log_probs:
             return F.log_softmax(logits, dim=-1)
...
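`get_normalized_probs` now checks for an adaptive softmax on the decoder and, if present, asks it for log-probabilities directly; the sample (with its `'target'` tensor) has to be threaded through because an adaptive softmax only materializes scores for the clusters the target tokens fall into. The idea can be illustrated with PyTorch's built-in `nn.AdaptiveLogSoftmaxWithLoss` rather than fairseq's `AdaptiveSoftmax` module (the sizes are made up):

```python
import torch
import torch.nn as nn

vocab, dim = 10000, 512
asm = nn.AdaptiveLogSoftmaxWithLoss(dim, vocab, cutoffs=[1000, 5000])

hidden = torch.randn(8, dim)            # decoder outputs for 8 positions
target = torch.randint(0, vocab, (8,))  # gold tokens for those positions

# cheap path: log-probs of the target tokens only, which is all the loss needs
out = asm(hidden, target)
print(out.output.shape, out.loss)       # torch.Size([8]), mean negative log-likelihood

# the full distribution is still available, but costs a pass over every cluster
full = asm.log_prob(hidden)             # shape (8, vocab)
```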
@@ -268,16 +268,16 @@ class FConvEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask,  # B x T
         }
 
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_out'] is not None:
-            encoder_out_dict['encoder_out'] = (
-                encoder_out_dict['encoder_out'][0].index_select(0, new_order),
-                encoder_out_dict['encoder_out'][1].index_select(0, new_order),
+    def reorder_encoder_out(self, encoder_out, new_order):
+        if encoder_out['encoder_out'] is not None:
+            encoder_out['encoder_out'] = (
+                encoder_out['encoder_out'][0].index_select(0, new_order),
+                encoder_out['encoder_out'][1].index_select(0, new_order),
             )
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
+        if encoder_out['encoder_padding_mask'] is not None:
+            encoder_out['encoder_padding_mask'] = \
+                encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out
 
     def max_positions(self):
         """Maximum input length supported by the encoder."""
@@ -352,6 +352,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self.dropout = dropout
         self.normalization_constant = normalization_constant
         self.left_pad = left_pad
+        self.need_attn = True
 
         convolutions = extend_conv_spec(convolutions)
         in_channels = convolutions[0][0]
@@ -466,11 +467,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 x = self._transpose_if_training(x, incremental_state)
 
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
-                attn_scores = attn_scores / num_attn_layers
-                if avg_attn_scores is None:
-                    avg_attn_scores = attn_scores
-                else:
-                    avg_attn_scores.add_(attn_scores)
+                if not self.training and self.need_attn:
+                    attn_scores = attn_scores / num_attn_layers
+                    if avg_attn_scores is None:
+                        avg_attn_scores = attn_scores
+                    else:
+                        avg_attn_scores.add_(attn_scores)
 
                 x = self._transpose_if_training(x, incremental_state)
@@ -490,16 +493,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
         return x, avg_attn_scores
 
-    def get_normalized_probs(self, net_output, log_probs, sample):
-        """Get normalized probabilities (or log probs) from a net's output."""
-        if self.adaptive_softmax is not None:
-            assert sample is not None and 'target' in sample
-            out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
-            return out.exp_() if not log_probs else out
-        else:
-            return super().get_normalized_probs(net_output, log_probs, sample)
-
     def reorder_incremental_state(self, incremental_state, new_order):
         super().reorder_incremental_state(incremental_state, new_order)
         encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
@@ -521,6 +514,9 @@ class FConvDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.version'] = torch.Tensor([1])
         return state_dict
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
     def _embed_tokens(self, tokens, incremental_state):
         if incremental_state is not None:
             # keep only the last token for incremental forward pass
...
@@ -226,18 +226,18 @@ class FConvEncoder(FairseqEncoder):
             'encoder_out': (x, y),
         }
 
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder_out'] = tuple(
-            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder_out']
+    def reorder_encoder_out(self, encoder_out, new_order):
+        encoder_out['encoder_out'] = tuple(
+            eo.index_select(0, new_order) for eo in encoder_out['encoder_out']
         )
 
-        if 'pretrained' in encoder_out_dict:
-            encoder_out_dict['pretrained']['encoder_out'] = tuple(
+        if 'pretrained' in encoder_out:
+            encoder_out['pretrained']['encoder_out'] = tuple(
                 eo.index_select(0, new_order)
-                for eo in encoder_out_dict['pretrained']['encoder_out']
+                for eo in encoder_out['pretrained']['encoder_out']
             )
 
-        return encoder_out_dict
+        return encoder_out
 
     def max_positions(self):
         """Maximum input length supported by the encoder."""
@@ -259,6 +259,7 @@ class FConvDecoder(FairseqDecoder):
         self.pretrained_decoder = trained_decoder
         self.dropout = dropout
         self.left_pad = left_pad
+        self.need_attn = True
         in_channels = convolutions[0][0]
 
         def expand_bool_array(val):
@@ -388,10 +389,11 @@ class FConvDecoder(FairseqDecoder):
                 r = x
                 x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
                 x = x + r
-                if avg_attn_scores is None:
-                    avg_attn_scores = attn_scores
-                else:
-                    avg_attn_scores.add_(attn_scores)
+                if not self.training and self.need_attn:
+                    if avg_attn_scores is None:
+                        avg_attn_scores = attn_scores
+                    else:
+                        avg_attn_scores.add_(attn_scores)
 
             if selfattention is not None:
                 x = selfattention(x)
@@ -426,6 +428,9 @@ class FConvDecoder(FairseqDecoder):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
     def _split_encoder_out(self, encoder_out):
         """Split and transpose encoder outputs."""
         # transpose only once to speed up attention layers
...
@@ -59,6 +59,12 @@ class LSTMModel(FairseqModel):
                             help='dropout probability for decoder input embedding')
         parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
                             help='dropout probability for decoder output')
+        parser.add_argument('--share-decoder-input-output-embed', default=False,
+                            action='store_true',
+                            help='share decoder input and output embeddings')
+        parser.add_argument('--share-all-embeddings', default=False, action='store_true',
+                            help='share encoder, decoder and output embeddings'
+                                 ' (requires shared dictionary and embed dim)')
 
     @classmethod
     def build_model(cls, args, task):
@@ -74,14 +80,47 @@ class LSTMModel(FairseqModel):
             utils.print_embed_overlap(embed_dict, dictionary)
             return utils.load_embedding(embed_dict, dictionary, embed_tokens)
 
-        pretrained_encoder_embed = None
         if args.encoder_embed_path:
             pretrained_encoder_embed = load_pretrained_embedding_from_file(
                 args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim)
-        pretrained_decoder_embed = None
-        if args.decoder_embed_path:
-            pretrained_decoder_embed = load_pretrained_embedding_from_file(
-                args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim)
+        else:
+            num_embeddings = len(task.source_dictionary)
+            pretrained_encoder_embed = Embedding(
+                num_embeddings, args.encoder_embed_dim, task.source_dictionary.pad()
+            )
+
+        if args.share_all_embeddings:
+            # double check all parameters combinations are valid
+            if task.source_dictionary != task.target_dictionary:
+                raise RuntimeError('--share-all-embeddings requires a joint dictionary')
+            if args.decoder_embed_path and (
+                    args.decoder_embed_path != args.encoder_embed_path):
+                raise RuntimeError(
+                    '--share-all-embed not compatible with --decoder-embed-path'
+                )
+            if args.encoder_embed_dim != args.decoder_embed_dim:
+                raise RuntimeError(
+                    '--share-all-embeddings requires --encoder-embed-dim to '
+                    'match --decoder-embed-dim'
+                )
+            pretrained_decoder_embed = pretrained_encoder_embed
+            args.share_decoder_input_output_embed = True
+        else:
+            # separate decoder input embeddings
+            pretrained_decoder_embed = None
+            if args.decoder_embed_path:
+                pretrained_decoder_embed = load_pretrained_embedding_from_file(
+                    args.decoder_embed_path,
+                    task.target_dictionary,
+                    args.decoder_embed_dim
+                )
+
+        # one last double check of parameter combinations
+        if args.share_decoder_input_output_embed and (
+                args.decoder_embed_dim != args.decoder_out_embed_dim):
+            raise RuntimeError(
+                '--share-decoder-input-output-embeddings requires '
+                '--decoder-embed-dim to match --decoder-out-embed-dim'
+            )
 
         encoder = LSTMEncoder(
             dictionary=task.source_dictionary,
@@ -105,6 +144,7 @@ class LSTMModel(FairseqModel):
             encoder_embed_dim=args.encoder_embed_dim,
             encoder_output_units=encoder.output_units,
             pretrained_embed=pretrained_decoder_embed,
+            share_input_output_embed=args.share_decoder_input_output_embed,
         )
         return cls(encoder, decoder)
@@ -197,15 +237,15 @@ class LSTMEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
         }
 
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder_out'] = tuple(
+    def reorder_encoder_out(self, encoder_out, new_order):
+        encoder_out['encoder_out'] = tuple(
             eo.index_select(1, new_order)
-            for eo in encoder_out_dict['encoder_out']
+            for eo in encoder_out['encoder_out']
         )
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
-        return encoder_out_dict
+        if encoder_out['encoder_padding_mask'] is not None:
+            encoder_out['encoder_padding_mask'] = \
+                encoder_out['encoder_padding_mask'].index_select(1, new_order)
+        return encoder_out
 
     def max_positions(self):
         """Maximum input length supported by the encoder."""
@@ -251,11 +291,14 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
         num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
         encoder_embed_dim=512, encoder_output_units=512, pretrained_embed=None,
+        share_input_output_embed=False,
     ):
         super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
         self.hidden_size = hidden_size
+        self.share_input_output_embed = share_input_output_embed
+        self.need_attn = True
 
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
@@ -279,7 +322,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self.attention = AttentionLayer(encoder_output_units, hidden_size) if attention else None
         if hidden_size != out_embed_dim:
             self.additional_fc = Linear(hidden_size, out_embed_dim)
-        self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
+        if not self.share_input_output_embed:
+            self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
     def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
         encoder_out = encoder_out_dict['encoder_out']
@@ -352,13 +396,19 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(1, 0)
 
         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        attn_scores = attn_scores.transpose(0, 2)
+        if not self.training and self.need_attn:
+            attn_scores = attn_scores.transpose(0, 2)
+        else:
+            attn_scores = None
 
         # project back to size of vocabulary
         if hasattr(self, 'additional_fc'):
             x = self.additional_fc(x)
             x = F.dropout(x, p=self.dropout_out, training=self.training)
-        x = self.fc_out(x)
+        if self.share_input_output_embed:
+            x = F.linear(x, self.embed_tokens.weight)
+        else:
+            x = self.fc_out(x)
 
         return x, attn_scores
@@ -380,6 +430,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         """Maximum output length supported by the decoder."""
         return int(1e5)  # an arbitrary large number
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
 
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
@@ -405,7 +458,7 @@ def LSTMCell(input_size, hidden_size, **kwargs):
 
 def Linear(in_features, out_features, bias=True, dropout=0):
-    """Weight-normalized Linear layer (input: N x T x C)"""
+    """Linear layer (input: N x T x C)"""
     m = nn.Linear(in_features, out_features, bias=bias)
     m.weight.data.uniform_(-0.1, 0.1)
     if bias:
@@ -431,6 +484,8 @@ def base_architecture(args):
     args.decoder_attention = getattr(args, 'decoder_attention', '1')
     args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
    args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
+    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
+    args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
 
 
 @register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
...
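With `--share-decoder-input-output-embed`, the LSTM decoder reuses its input embedding matrix as the output projection, which is why `fc_out` is only created when the weights are not shared. The tying itself is just a linear map whose weight is the embedding table, as in this sketch with toy dimensions (sharing also requires `--decoder-embed-dim` to equal `--decoder-out-embed-dim`, the check added in `build_model` above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab, embed_dim = 1000, 32
embed_tokens = nn.Embedding(vocab, embed_dim, padding_idx=1)

hidden = torch.randn(2, 5, embed_dim)  # decoder states: batch x time x dim

# shared: project with the embedding matrix itself, no extra parameters
logits_shared = F.linear(hidden, embed_tokens.weight)    # (2, 5, vocab)

# unshared: a separate output matrix, like fc_out in LSTMDecoder
fc_out = nn.Linear(embed_dim, vocab)
logits_separate = fc_out(hidden)                          # (2, 5, vocab)
```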
@@ -11,16 +11,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from fairseq import options
 from fairseq import utils
 from fairseq.modules import (
-    LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding
+    LearnedPositionalEmbedding, MultiheadAttention, AdaptiveSoftmax,
+    SinusoidalPositionalEmbedding,
 )
 
 from . import (
-    FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
-    register_model, register_model_architecture,
+    FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model,
+    register_model_architecture,
 )
@@ -71,13 +71,22 @@ class TransformerModel(FairseqModel):
         parser.add_argument('--share-all-embeddings', action='store_true',
                             help='share encoder, decoder and output embeddings'
                                  ' (requires shared dictionary and embed dim)')
+        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
+                            help='comma separated list of adaptive softmax cutoff points. '
+                                 'Must be used with adaptive_loss criterion')
 
     @classmethod
     def build_model(cls, args, task):
         """Build a new model instance."""
-        # make sure that all args are properly defaulted (in case there are any new ones)
+        # make sure all arguments are present in older models
         base_architecture(args)
 
+        if not hasattr(args, 'max_source_positions'):
+            args.max_source_positions = 1024
+        if not hasattr(args, 'max_target_positions'):
+            args.max_target_positions = 1024
+
         src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
 
         def build_embedding(dictionary, embed_dim, path=None):
@@ -117,6 +126,56 @@ class TransformerModel(FairseqModel):
 
         return TransformerModel(encoder, decoder)
+@register_model('transformer_lm')
+class TransformerLanguageModel(FairseqLanguageModel):
+    def __init__(self, decoder):
+        super().__init__(decoder)
+
+    @staticmethod
+    def add_args(parser):
+        """Add model-specific arguments to the parser."""
+        parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
+                            help='dropout probability')
+        parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
+                            help='dropout probability for attention weights')
+        parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
+                            help='dropout probability after ReLU in FFN')
+        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
+                            help='decoder embedding dimension')
+        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
+                            help='decoder embedding dimension for FFN')
+        parser.add_argument('--decoder-layers', type=int, metavar='N',
+                            help='num decoder layers')
+        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
+                            help='num decoder attention heads')
+        parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
+                            help='apply layernorm before each decoder block')
+        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
+                            help='comma separated list of adaptive softmax cutoff points. '
+                                 'Must be used with adaptive_loss criterion')
+        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
+                            help='if set, disables positional embeddings (outside self attention)')
+        parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
+                            help='share decoder input and output embeddings')
+
+    @classmethod
+    def build_model(cls, args, task):
+        """Build a new model instance."""
+        # make sure all arguments are present in older models
+        base_lm_architecture(args)
+
+        if not hasattr(args, 'max_source_positions'):
+            args.max_source_positions = args.tokens_per_sample
+        if not hasattr(args, 'max_target_positions'):
+            args.max_target_positions = args.tokens_per_sample
+
+        embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())
+
+        decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
+        return TransformerLanguageModel(decoder)
 class TransformerEncoder(FairseqEncoder):
     """Transformer encoder."""
 
@@ -126,14 +185,15 @@ class TransformerEncoder(FairseqEncoder):
         embed_dim = embed_tokens.embedding_dim
         self.padding_idx = embed_tokens.padding_idx
+        self.max_source_positions = args.max_source_positions
 
         self.embed_tokens = embed_tokens
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
-            1024, embed_dim, self.padding_idx,
+            args.max_source_positions, embed_dim, self.padding_idx,
             left_pad=left_pad,
             learned=args.encoder_learned_pos,
-        )
+        ) if not args.no_token_positional_embeddings else None
 
         self.layers = nn.ModuleList([])
         self.layers.extend([
@@ -144,7 +204,8 @@ class TransformerEncoder(FairseqEncoder):
     def forward(self, src_tokens, src_lengths):
         # embed tokens and positions
         x = self.embed_scale * self.embed_tokens(src_tokens)
-        x += self.embed_positions(src_tokens)
+        if self.embed_positions is not None:
+            x += self.embed_positions(src_tokens)
         x = F.dropout(x, p=self.dropout, training=self.training)
 
         # B x T x C -> T x B x C
@@ -164,106 +225,136 @@ class TransformerEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask,  # B x T
         }
 
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_out'] is not None:
-            encoder_out_dict['encoder_out'] = \
-                encoder_out_dict['encoder_out'].index_select(1, new_order)
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
+    def reorder_encoder_out(self, encoder_out, new_order):
+        if encoder_out['encoder_out'] is not None:
+            encoder_out['encoder_out'] = \
+                encoder_out['encoder_out'].index_select(1, new_order)
+        if encoder_out['encoder_padding_mask'] is not None:
+            encoder_out['encoder_padding_mask'] = \
+                encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out
 
     def max_positions(self):
         """Maximum input length supported by the encoder."""
-        return self.embed_positions.max_positions()
+        if self.embed_positions is None:
+            return self.max_source_positions
+        return min(self.max_source_positions, self.embed_positions.max_positions())
 
     def upgrade_state_dict(self, state_dict):
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
             if 'encoder.embed_positions.weights' in state_dict:
                 del state_dict['encoder.embed_positions.weights']
-            if 'encoder.embed_positions._float_tensor' not in state_dict:
-                state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor()
+            state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
         return state_dict
 class TransformerDecoder(FairseqIncrementalDecoder):
     """Transformer decoder."""
 
-    def __init__(self, args, dictionary, embed_tokens, left_pad=False):
+    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
 
         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
+        self.max_target_positions = args.max_target_positions
 
         self.embed_tokens = embed_tokens
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
-            1024, embed_dim, padding_idx,
+            args.max_target_positions, embed_dim, padding_idx,
             left_pad=left_pad,
             learned=args.decoder_learned_pos,
-        )
+        ) if not args.no_token_positional_embeddings else None
 
         self.layers = nn.ModuleList([])
         self.layers.extend([
-            TransformerDecoderLayer(args)
-            for i in range(args.decoder_layers)
+            TransformerDecoderLayer(args, no_encoder_attn)
+            for _ in range(args.decoder_layers)
         ])
 
-        if not self.share_input_output_embed:
+        self.adaptive_softmax = None
+
+        if args.adaptive_softmax_cutoff is not None:
+            self.adaptive_softmax = AdaptiveSoftmax(
+                len(dictionary), args.decoder_embed_dim,
+                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
+                dropout=args.dropout
+            )
+        elif not self.share_input_output_embed:
            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens,
             incremental_state=incremental_state,
-        )
+        ) if self.embed_positions is not None else None
 
         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
-            positions = positions[:, -1:]
+            if positions is not None:
+                positions = positions[:, -1:]
 
         # embed tokens and positions
         x = self.embed_scale * self.embed_tokens(prev_output_tokens)
-        x += positions
+        if positions is not None:
+            x += positions
         x = F.dropout(x, p=self.dropout, training=self.training)
 
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
+        attn = None
 
         # decoder layers
         for layer in self.layers:
             x, attn = layer(
                 x,
-                encoder_out['encoder_out'],
-                encoder_out['encoder_padding_mask'],
+                encoder_out['encoder_out'] if encoder_out is not None else None,
+                encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
                 incremental_state,
             )
 
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)
 
-        # project back to size of vocabulary
-        if self.share_input_output_embed:
-            x = F.linear(x, self.embed_tokens.weight)
-        else:
-            x = F.linear(x, self.embed_out)
+        if self.adaptive_softmax is None:
+            # project back to size of vocabulary
+            if self.share_input_output_embed:
+                x = F.linear(x, self.embed_tokens.weight)
+            else:
+                x = F.linear(x, self.embed_out)
 
         return x, attn
     def max_positions(self):
         """Maximum output length supported by the decoder."""
-        return self.embed_positions.max_positions()
+        if self.embed_positions is None:
+            return self.max_target_positions
+        return min(self.max_target_positions, self.embed_positions.max_positions())
 
     def upgrade_state_dict(self, state_dict):
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
             if 'decoder.embed_positions.weights' in state_dict:
                 del state_dict['decoder.embed_positions.weights']
-            if 'decoder.embed_positions._float_tensor' not in state_dict:
-                state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
+            state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
+
+        for i in range(len(self.layers)):
+            # update layer norms
+            layer_norm_map = {
+                '0': 'self_attn_layer_norm',
+                '1': 'encoder_attn_layer_norm',
+                '2': 'final_layer_norm'
+            }
+            for old, new in layer_norm_map.items():
+                for m in ('weight', 'bias'):
+                    k = 'decoder.layers.{}.layer_norms.{}.{}'.format(i, old, m)
+                    if k in state_dict:
+                        state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k]
+                        del state_dict[k]
+
         return state_dict
@@ -322,7 +413,7 @@ class TransformerEncoderLayer(nn.Module):
 class TransformerDecoderLayer(nn.Module):
     """Decoder layer block."""
 
-    def __init__(self, args):
+    def __init__(self, args, no_encoder_attn=False):
         super().__init__()
         self.embed_dim = args.decoder_embed_dim
         self.self_attn = MultiheadAttention(
@@ -332,17 +423,28 @@ class TransformerDecoderLayer(nn.Module):
         self.dropout = args.dropout
         self.relu_dropout = args.relu_dropout
         self.normalize_before = args.decoder_normalize_before
-        self.encoder_attn = MultiheadAttention(
-            self.embed_dim, args.decoder_attention_heads,
-            dropout=args.attention_dropout,
-        )
+
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
+
+        if no_encoder_attn:
+            self.encoder_attn = None
+            self.encoder_attn_layer_norm = None
+        else:
+            self.encoder_attn = MultiheadAttention(
+                self.embed_dim, args.decoder_attention_heads,
+                dropout=args.attention_dropout,
+            )
+            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
+
         self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
         self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
+
+        self.final_layer_norm = LayerNorm(self.embed_dim)
+        self.need_attn = True
     def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
         residual = x
-        x = self.maybe_layer_norm(0, x, before=True)
+        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
         x, _ = self.self_attn(
             query=x,
             key=x,
@@ -353,43 +455,50 @@ class TransformerDecoderLayer(nn.Module):
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.maybe_layer_norm(0, x, after=True)
+        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
 
-        residual = x
-        x = self.maybe_layer_norm(1, x, before=True)
-        x, attn = self.encoder_attn(
-            query=x,
-            key=encoder_out,
-            value=encoder_out,
-            key_padding_mask=encoder_padding_mask,
-            incremental_state=incremental_state,
-            static_kv=True,
-        )
-        x = F.dropout(x, p=self.dropout, training=self.training)
-        x = residual + x
-        x = self.maybe_layer_norm(1, x, after=True)
+        attn = None
+        if self.encoder_attn is not None:
+            residual = x
+            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
+            x, attn = self.encoder_attn(
+                query=x,
+                key=encoder_out,
+                value=encoder_out,
+                key_padding_mask=encoder_padding_mask,
+                incremental_state=incremental_state,
+                static_kv=True,
+                need_weights=(not self.training and self.need_attn),
+            )
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = residual + x
+            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
 
         residual = x
-        x = self.maybe_layer_norm(2, x, before=True)
+        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
         x = F.relu(self.fc1(x))
         x = F.dropout(x, p=self.relu_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.maybe_layer_norm(2, x, after=True)
+        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
         return x, attn
 
-    def maybe_layer_norm(self, i, x, before=False, after=False):
+    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
         assert before ^ after
         if after ^ self.normalize_before:
-            return self.layer_norms[i](x)
+            return layer_norm(x)
         else:
             return x
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
     nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+    nn.init.constant_(m.weight[padding_idx], 0)
     return m
 
@@ -407,14 +516,49 @@ def Linear(in_features, out_features, bias=True):
 def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
     if learned:
-        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
+        m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
         nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
         nn.init.constant_(m.weight[padding_idx], 0)
     else:
-        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings)
+        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1)
     return m
+@register_model_architecture('transformer_lm', 'transformer_lm')
+def base_lm_architecture(args):
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
+    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
+    args.decoder_layers = getattr(args, 'decoder_layers', 6)
+    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
+    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
+    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
+
+    # The model training is not stable without this
+    args.decoder_normalize_before = True
+
+
+@register_model_architecture('transformer_lm', 'transformer_lm_big')
+def transformer_lm_big(args):
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
+    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
+    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
+    base_lm_architecture(args)
+
+
+@register_model_architecture('transformer_lm', 'transformer_lm_wiki103')
+def transformer_lm_wiki103(args):
+    args.dropout = getattr(args, 'dropout', 0.3)
+    base_lm_architecture(args)
+
+
+@register_model_architecture('transformer_lm', 'transformer_lm_gbw')
+def transformer_lm_gbw(args):
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
+    args.dropout = getattr(args, 'dropout', 0.1)
+    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
+    transformer_lm_big(args)
 @register_model_architecture('transformer', 'transformer')
 def base_architecture(args):
     args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
@@ -434,20 +578,22 @@ def base_architecture(args):
     args.attention_dropout = getattr(args, 'attention_dropout', 0.)
     args.relu_dropout = getattr(args, 'relu_dropout', 0.)
     args.dropout = getattr(args, 'dropout', 0.1)
+    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
     args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
+    args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
 
 
 @register_model_architecture('transformer', 'transformer_iwslt_de_en')
 def transformer_iwslt_de_en(args):
-    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
-    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512)
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
+    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
-    args.encoder_layers = getattr(args, 'encoder_layers', 3)
-    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
-    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 512)
+    args.encoder_layers = getattr(args, 'encoder_layers', 6)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
+    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
-    args.decoder_layers = getattr(args, 'decoder_layers', 3)
+    args.decoder_layers = getattr(args, 'decoder_layers', 6)
     base_architecture(args)
...
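The `transformer_lm*` presets above follow fairseq's `register_model_architecture` pattern: each function only fills in defaults via `getattr` (so command-line values win) and then delegates to a base configuration. A sketch of a user-defined variant in the same style; the preset name is hypothetical, and in this version of the code such a function would need to live inside `fairseq/models/transformer.py` itself to be registered before argument parsing:

```python
from fairseq.models import register_model_architecture
from fairseq.models.transformer import base_lm_architecture

@register_model_architecture('transformer_lm', 'transformer_lm_tiny')  # hypothetical preset
def transformer_lm_tiny(args):
    # getattr keeps any value the user already passed on the command line
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
    args.decoder_layers = getattr(args, 'decoder_layers', 4)
    base_lm_architecture(args)
```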
@@ -30,7 +30,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
             embedding_dim,
             padding_idx,
         )
-        self.register_buffer('_float_tensor', torch.FloatTensor())
+        self.register_buffer('_float_tensor', torch.FloatTensor(1))
 
     @staticmethod
     def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
...
@@ -16,7 +16,7 @@ class FixedSchedule(FairseqLRScheduler):
         super().__init__(args, optimizer)
 
         # set defaults
-        args.warmup_updates = getattr(args, 'warmup_updates', 0)
+        args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
 
         self.lr = args.lr[0]
         if args.warmup_updates > 0:
...
@@ -62,7 +62,7 @@ def eval_bool(x, default=False):
         return default
 
 
-def parse_args_and_arch(parser, input_args=None):
+def parse_args_and_arch(parser, input_args=None, parse_known=False):
     # The parser doesn't know about model/criterion/optimizer-specific args, so
     # we parse twice. First we parse the model/criterion/optimizer, then we
     # parse a second time after adding the *-specific arguments.
@@ -90,7 +90,11 @@ def parse_args_and_arch(parser, input_args=None):
         TASK_REGISTRY[args.task].add_args(parser)
 
     # Parse a second time.
-    args = parser.parse_args(input_args)
+    if parse_known:
+        args, extra = parser.parse_known_args(input_args)
+    else:
+        args = parser.parse_args(input_args)
+        extra = None
 
     # Post-process args.
     if hasattr(args, 'lr'):
@@ -104,7 +108,10 @@ def parse_args_and_arch(parser, input_args=None):
     if hasattr(args, 'arch'):
         ARCH_CONFIG_REGISTRY[args.arch](args)
 
-    return args
+    if parse_known:
+        return args, extra
+    else:
+        return args
 
 
 def get_parser(desc, default_task='translation'):
@@ -249,6 +256,8 @@ def add_common_eval_args(group):
 def add_eval_lm_args(parser):
     group = parser.add_argument_group('LM Evaluation')
     add_common_eval_args(group)
+    group.add_argument('--output-word-probs', action='store_true',
+                       help='if set, outputs words and their predicted log probabilities to standard output')
 
 
 def add_generation_args(parser):
@@ -290,6 +299,8 @@ def add_generation_args(parser):
                        help='sample from top K likely next words instead of all words')
     group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
                        help='temperature for random sampling')
+    group.add_argument('--print-alignment', action='store_true',
+                       help='if set, uses attention feedback to compute and print alignment to source tokens')
     group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
                        help='a dictionary used to override model args at generation that were used during model training')
     return group
...
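`parse_args_and_arch` gains a `parse_known` flag so callers such as pytorch-translate can tolerate flags that fairseq's parser does not know about; the unknown arguments are returned instead of aborting. In plain argparse terms, the new branch is the difference between `parse_args` and `parse_known_args` (the flags below are made up):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=0.25)

# strict parsing: an unknown flag such as --foo would exit with an error
args = parser.parse_args(['--lr', '0.1'])

# tolerant parsing, which is what parse_known=True routes to
args, extra = parser.parse_known_args(['--lr', '0.1', '--foo', 'bar'])
print(args.lr, extra)   # 0.1 ['--foo', 'bar']
```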
@@ -126,8 +126,8 @@ class SequenceGenerator(object):
         tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
         tokens_buf = tokens.clone()
         tokens[:, 0] = self.eos
-        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
-        attn_buf = attn.clone()
+        attn, attn_buf = None, None
+        nonpad_idxs = None
 
         # list of completed sentences
         finalized = [[] for i in range(bsz)]
@@ -191,7 +191,7 @@ class SequenceGenerator(object):
             tokens_clone = tokens.index_select(0, bbsz_idx)
             tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
             tokens_clone[:, step] = self.eos
-            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]
+            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
 
             # compute scores per token position
             pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
@@ -220,10 +220,13 @@ class SequenceGenerator(object):
                 def get_hypo():
 
-                    # remove padding tokens from attn scores
-                    nonpad_idxs = src_tokens[sent].ne(self.pad)
-                    hypo_attn = attn_clone[i][nonpad_idxs]
-                    _, alignment = hypo_attn.max(dim=0)
+                    if attn_clone is not None:
+                        # remove padding tokens from attn scores
+                        hypo_attn = attn_clone[i][nonpad_idxs[sent]]
+                        _, alignment = hypo_attn.max(dim=0)
+                    else:
+                        hypo_attn = None
+                        alignment = None
 
                     return {
                         'tokens': tokens_clone[i],
@@ -270,8 +273,7 @@ class SequenceGenerator(object):
                         model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                     encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
 
-            probs, avg_attn_scores = self._decode(
-                tokens[:, :step + 1], encoder_outs, incremental_states)
+            probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
 
             if step == 0:
                 # at the first step all hypotheses are equally likely, so use
                 # only the first beam
@@ -286,7 +288,12 @@ class SequenceGenerator(object):
             probs[:, self.unk] -= self.unk_penalty  # apply unk penalty
 
             # Record attention scores
-            attn[:, :, step + 1].copy_(avg_attn_scores)
+            if avg_attn_scores is not None:
+                if attn is None:
+                    attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
+                    attn_buf = attn.clone()
+                    nonpad_idxs = src_tokens.ne(self.pad)
+                attn[:, :, step + 1].copy_(avg_attn_scores)
 
             cand_scores = buffer('cand_scores', type_of=scores)
             cand_indices = buffer('cand_indices')
...@@ -417,8 +424,9 @@ class SequenceGenerator(object): ...@@ -417,8 +424,9 @@ class SequenceGenerator(object):
scores_buf.resize_as_(scores) scores_buf.resize_as_(scores)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens_buf.resize_as_(tokens) tokens_buf.resize_as_(tokens)
attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1) if attn is not None:
attn_buf.resize_as_(attn) attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
attn_buf.resize_as_(attn)
bsz = new_bsz bsz = new_bsz
else: else:
batch_idxs = None batch_idxs = None
...@@ -473,15 +481,17 @@ class SequenceGenerator(object): ...@@ -473,15 +481,17 @@ class SequenceGenerator(object):
) )
# copy attention for active hypotheses # copy attention for active hypotheses
torch.index_select( if attn is not None:
attn[:, :, :step + 2], dim=0, index=active_bbsz_idx, torch.index_select(
out=attn_buf[:, :, :step + 2], attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
) out=attn_buf[:, :, :step + 2],
)
# swap buffers # swap buffers
tokens, tokens_buf = tokens_buf, tokens tokens, tokens_buf = tokens_buf, tokens
scores, scores_buf = scores_buf, scores scores, scores_buf = scores_buf, scores
attn, attn_buf = attn_buf, attn if attn is not None:
attn, attn_buf = attn_buf, attn
# reorder incremental state in decoder # reorder incremental state in decoder
reorder_state = active_bbsz_idx reorder_state = active_bbsz_idx
...@@ -518,7 +528,7 @@ class SequenceGenerator(object): ...@@ -518,7 +528,7 @@ class SequenceGenerator(object):
def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs): def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
with torch.no_grad(): with torch.no_grad():
if incremental_states[model] is not None: if incremental_states[model] is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model])) decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model]))
else: else:
decoder_out = list(model.decoder(tokens, encoder_out)) decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1, :] decoder_out[0] = decoder_out[0][:, -1, :]
......
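The sequence_generator.py hunks above stop pre-allocating the attention buffer: it is created lazily the first time a decoder actually returns attention scores, and every later use (alignment extraction, reordering, buffer swapping) is guarded with `if attn is not None`. Below is a minimal, self-contained sketch of that lazy-allocation pattern; `record_attention` is an illustrative name, not a fairseq function.

```
import torch

def record_attention(attn, avg_attn_scores, scores, bsz, beam_size, src_len, maxlen, step):
    """Allocate the attention buffer on first use and record this step's scores."""
    if avg_attn_scores is None:
        # decoder produced no attention (e.g. need_attn=False): never allocate,
        # so downstream code returns hypo_attn=None and alignment=None
        return attn
    if attn is None:
        # same shape as in the diff: (bsz * beam_size, src_len, maxlen + 2)
        attn = scores.new(bsz * beam_size, src_len, maxlen + 2)
    attn[:, :, step + 1].copy_(avg_attn_scores)
    return attn

# example: nothing is allocated on step 0, a buffer appears once scores arrive
scores = torch.zeros(2 * 5, 10)
attn = record_attention(None, None, scores, bsz=2, beam_size=5, src_len=7, maxlen=9, step=0)
assert attn is None
attn = record_attention(attn, torch.rand(2 * 5, 7), scores, 2, 5, 7, 9, step=1)
assert attn.shape == (2 * 5, 7, 9 + 2)
```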
...@@ -5,9 +5,6 @@ ...@@ -5,9 +5,6 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
from fairseq import criterions, models
from fairseq.data import FairseqDataset
class FairseqTask(object): class FairseqTask(object):
""" """
...@@ -28,11 +25,12 @@ class FairseqTask(object): ...@@ -28,11 +25,12 @@ class FairseqTask(object):
def setup_task(cls, args, **kwargs): def setup_task(cls, args, **kwargs):
raise NotImplementedError raise NotImplementedError
def load_dataset(self, split): def load_dataset(self, split, combine=False):
raise NotImplementedError raise NotImplementedError
def dataset(self, split): def dataset(self, split):
"""Return a dataset split.""" """Return a dataset split."""
from fairseq.data import FairseqDataset
if split not in self.datasets: if split not in self.datasets:
raise KeyError('Dataset not loaded: ' + split) raise KeyError('Dataset not loaded: ' + split)
if not isinstance(self.datasets[split], FairseqDataset): if not isinstance(self.datasets[split], FairseqDataset):
...@@ -40,9 +38,11 @@ class FairseqTask(object): ...@@ -40,9 +38,11 @@ class FairseqTask(object):
return self.datasets[split] return self.datasets[split]
def build_model(self, args): def build_model(self, args):
from fairseq import models
return models.build_model(args, self) return models.build_model(args, self)
def build_criterion(self, args): def build_criterion(self, args):
from fairseq import criterions
return criterions.build_criterion(args, self) return criterions.build_criterion(args, self)
def get_loss(self, model, criterion, sample): def get_loss(self, model, criterion, sample):
......
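In fairseq_task.py the module-level imports of `models`, `criterions`, and `FairseqDataset` move inside the methods that need them (a standard way to break import cycles), and `load_dataset` gains a `combine` flag. The snippet below is a hypothetical toy task, not FairseqTask itself, meant only to show what `combine` does from a caller's point of view; the in-memory lists stand in for real datasets.

```
class ToyTask(object):
    """Keeps a dict of loaded splits, like FairseqTask.datasets."""

    def __init__(self, shards):
        self.shards = shards      # e.g. {'train': [[1, 2], [3]], 'valid': [[4]]}
        self.datasets = {}

    def load_dataset(self, split, combine=False):
        parts = self.shards[split]
        if combine:
            # merge every shard of the split into one dataset
            merged = [ex for part in parts for ex in part]
        else:
            # default behaviour: only the first shard
            merged = list(parts[0])
        self.datasets[split] = merged

    def dataset(self, split):
        if split not in self.datasets:
            raise KeyError('Dataset not loaded: ' + split)
        return self.datasets[split]

task = ToyTask({'train': [[1, 2], [3]], 'valid': [[4]]})
task.load_dataset('train', combine=True)
print(task.dataset('train'))   # [1, 2, 3]
```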
...@@ -5,8 +5,12 @@ ...@@ -5,8 +5,12 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os import os
from torch.utils.data import ConcatDataset
from fairseq.data import ( from fairseq.data import (
Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset, Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
MonolingualDataset, TokenBlockDataset, MonolingualDataset, TokenBlockDataset,
...@@ -43,23 +47,46 @@ class LanguageModelingTask(FairseqTask): ...@@ -43,23 +47,46 @@ class LanguageModelingTask(FairseqTask):
print('| dictionary: {} types'.format(len(dictionary))) print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary) return cls(args, dictionary)
def load_dataset(self, split): def load_dataset(self, split, combine=False):
"""Load a dataset split.""" """Load a dataset split."""
path = os.path.join(self.args.data, split)
if self.args.raw_text and IndexedRawTextDataset.exists(path): loaded_datasets = []
ds = IndexedRawTextDataset(path, self.dictionary)
tokens = ds.tokens_list for k in itertools.count():
elif not self.args.raw_text and IndexedInMemoryDataset.exists(path): split_k = split + (str(k) if k > 0 else '')
ds = IndexedInMemoryDataset(path, fix_lua_indexing=True) path = os.path.join(self.args.data, split_k)
tokens = ds.buffer
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
tokens = [t for l in ds.tokens_list for t in l]
elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
tokens = ds.buffer
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
loaded_datasets.append(
TokenBlockDataset(
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True
))
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
if not combine:
break
if len(loaded_datasets) == 1:
dataset = loaded_datasets[0]
sizes = dataset.sizes
else: else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data)) dataset = ConcatDataset(loaded_datasets)
sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
dataset = TokenBlockDataset( self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False)
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True, # return next tokens as targets
)
self.datasets[split] = MonolingualDataset(dataset, dataset.sizes, self.dictionary, shuffle=False)
@property @property
def target_dictionary(self): def target_dictionary(self):
......
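The language_modeling.py change replaces the single-path load with a loop over shard names `train`, `train1`, `train2`, ... that stops at the first missing shard (or after the first shard when `combine=False`), then wraps multiple shards in `ConcatDataset`. Here is a stripped-down sketch of just the shard-enumeration logic, using a plain `os.path.exists` check where the real code probes `IndexedRawTextDataset.exists` / `IndexedInMemoryDataset.exists`:

```
import itertools
import os

def collect_shard_paths(data_dir, split, combine=False):
    """Return the paths of all shards of `split` that should be loaded."""
    paths = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(data_dir, split_k)
        if not os.path.exists(path):
            if k > 0:
                break   # ran past the last shard
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_dir))
        paths.append(path)
        if not combine:
            break       # default: first shard only
    return paths
```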
...@@ -5,8 +5,12 @@ ...@@ -5,8 +5,12 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os import os
from torch.utils.data import ConcatDataset
from fairseq import options from fairseq import options
from fairseq.data import ( from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset, data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
...@@ -65,10 +69,10 @@ class TranslationTask(FairseqTask): ...@@ -65,10 +69,10 @@ class TranslationTask(FairseqTask):
return cls(args, src_dict, tgt_dict) return cls(args, src_dict, tgt_dict)
def load_dataset(self, split): def load_dataset(self, split, combine=False):
"""Load a dataset split.""" """Load a dataset split."""
def split_exists(src, tgt, lang): def split_exists(split, src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang)) filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename): if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True return True
...@@ -76,15 +80,6 @@ class TranslationTask(FairseqTask): ...@@ -76,15 +80,6 @@ class TranslationTask(FairseqTask):
return True return True
return False return False
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
def indexed_dataset(path, dictionary): def indexed_dataset(path, dictionary):
if self.args.raw_text: if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary) return IndexedRawTextDataset(path, dictionary)
...@@ -92,11 +87,48 @@ class TranslationTask(FairseqTask): ...@@ -92,11 +87,48 @@ class TranslationTask(FairseqTask):
return IndexedInMemoryDataset(path, fix_lua_indexing=True) return IndexedInMemoryDataset(path, fix_lua_indexing=True)
return None return None
src_dataset = indexed_dataset(prefix + src, self.src_dict) src_datasets = []
tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict) tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(split_k, src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
print('| {} {} {} examples'.format(self.args.data, split_k, len(src_datasets[-1])))
if not combine:
break
assert len(src_datasets) == len(tgt_datasets)
if len(src_datasets) == 1:
src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
src_sizes = src_dataset.sizes
tgt_sizes = tgt_dataset.sizes
else:
src_dataset = ConcatDataset(src_datasets)
tgt_dataset = ConcatDataset(tgt_datasets)
src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])
self.datasets[split] = LanguagePairDataset( self.datasets[split] = LanguagePairDataset(
src_dataset, src_dataset.sizes, self.src_dict, src_dataset, src_sizes, self.src_dict,
tgt_dataset, tgt_dataset.sizes, self.tgt_dict, tgt_dataset, tgt_sizes, self.tgt_dict,
left_pad_source=self.args.left_pad_source, left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target, left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions, max_source_positions=self.args.max_source_positions,
......
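translation.py gets the same shard loop for parallel data; the part worth calling out is how per-shard datasets and their length arrays are merged so batching still sees one contiguous dataset. Below is a self-contained sketch of that merge, with a trivial `Shard` stand-in for the indexed datasets:

```
import numpy as np
from torch.utils.data import ConcatDataset, Dataset

class Shard(Dataset):
    """Stand-in for IndexedInMemoryDataset: items plus a per-item size array."""
    def __init__(self, items):
        self.items = items
        self.sizes = np.array([len(x) for x in items])

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)

def merge_shards(shards):
    if len(shards) == 1:
        return shards[0], shards[0].sizes
    # several shards: concatenate both the data and the size arrays
    return ConcatDataset(shards), np.concatenate([s.sizes for s in shards])

dataset, sizes = merge_shards([Shard([[1, 2, 3]]), Shard([[4, 5], [6]])])
print(len(dataset), sizes)    # 3 [3 2 1]
```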
...@@ -140,6 +140,11 @@ class Trainer(object): ...@@ -140,6 +140,11 @@ class Trainer(object):
ooms_fwd = sum(ooms_fwd) ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd) ooms_bwd = sum(ooms_bwd)
if ooms_fwd == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch')
self.zero_grad()
return None
# aggregate stats and logging outputs # aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
...@@ -178,11 +183,6 @@ class Trainer(object): ...@@ -178,11 +183,6 @@ class Trainer(object):
return None # buffering updates return None # buffering updates
def _forward(self, sample, eval=False): def _forward(self, sample, eval=False):
# prepare model and optimizer
if eval:
self.model.eval()
else:
self.model.train()
loss = None loss = None
sample_size = 0 sample_size = 0
logging_output = { logging_output = {
...@@ -190,19 +190,26 @@ class Trainer(object): ...@@ -190,19 +190,26 @@ class Trainer(object):
'nsentences': sample['target'].size(0) if sample is not None else 0, 'nsentences': sample['target'].size(0) if sample is not None else 0,
} }
oom = 0 oom = 0
if sample is not None: try:
try: # prepare model and optimizer
if eval:
self.model.eval()
else:
self.model.train()
self.optimizer.zero_grad()
if sample is not None:
with torch.no_grad() if eval else contextlib.ExitStack(): with torch.no_grad() if eval else contextlib.ExitStack():
# calculate loss and sample size # calculate loss and sample size
loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample) loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
logging_output.update(logging_output_) logging_output.update(logging_output_)
except RuntimeError as e: except RuntimeError as e:
if not eval and 'out of memory' in str(e): if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch') print('| WARNING: ran out of memory, skipping batch')
oom = 1 oom = 1
loss = None loss = None
else: else:
raise e raise e
return loss, sample_size, logging_output, oom return loss, sample_size, logging_output, oom
def _backward(self, loss): def _backward(self, loss):
......
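The trainer.py hunks make two OOM-related changes: the whole forward pass (including switching the model between train/eval mode and zeroing gradients) now sits inside the try block, and a batch is dropped entirely when every worker reports an out-of-memory error. A simplified sketch of both guards, with `compute_loss` and `zero_grad` as caller-supplied callables rather than Trainer attributes:

```
def forward_with_oom_guard(compute_loss, sample, eval=False):
    """Run the forward pass, converting CUDA OOMs into an `oom` flag."""
    loss, oom = None, 0
    try:
        if sample is not None:
            loss = compute_loss(sample)
    except RuntimeError as e:
        if not eval and 'out of memory' in str(e):
            print('| WARNING: ran out of memory, skipping batch')
            oom = 1
        else:
            raise e
    return loss, oom

def skip_if_all_workers_oom(ooms_fwd, world_size, zero_grad):
    """Mirror of the new check: drop the batch if every worker hit an OOM."""
    if ooms_fwd == world_size:
        print('| WARNING: OOM in all workers, skipping batch')
        zero_grad()
        return True
    return False
```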
...@@ -42,7 +42,10 @@ def main(args): ...@@ -42,7 +42,10 @@ def main(args):
# Optimize ensemble for generation # Optimize ensemble for generation
for model in models: for model in models:
model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam) model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if args.fp16: if args.fp16:
model.half() model.half()
...@@ -115,7 +118,7 @@ def main(args): ...@@ -115,7 +118,7 @@ def main(args):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(), hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str, src_str=src_str,
alignment=hypo['alignment'].int().cpu(), alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict, align_dict=align_dict,
tgt_dict=tgt_dict, tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe, remove_bpe=args.remove_bpe,
...@@ -130,10 +133,12 @@ def main(args): ...@@ -130,10 +133,12 @@ def main(args):
hypo['positional_scores'].tolist(), hypo['positional_scores'].tolist(),
)) ))
)) ))
print('A-{}\t{}'.format(
sample_id, if args.print_alignment:
' '.join(map(lambda x: str(utils.item(x)), alignment)) print('A-{}\t{}'.format(
)) sample_id,
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
# Score only the top hypothesis # Score only the top hypothesis
if has_target and i == 0: if has_target and i == 0:
......
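In generate.py, attention is only requested from the models when `--print-alignment` is set (via `need_attn`), so `hypo['alignment']` can now be None and the A-line is printed only on demand. A tiny sketch of that guard; `maybe_print_alignment` is an illustrative helper, not part of fairseq:

```
from argparse import Namespace

def maybe_print_alignment(args, sample_id, alignment):
    # alignment can be None when attention was never computed, so both the
    # flag and the value are checked before printing an A-line
    if args.print_alignment and alignment is not None:
        print('A-{}\t{}'.format(sample_id, ' '.join(str(int(a)) for a in alignment)))

maybe_print_alignment(Namespace(print_alignment=True), 0, [2, 0, 1, 3])   # A-0	2 0 1 3
maybe_print_alignment(Namespace(print_alignment=False), 0, [2, 0, 1, 3])  # prints nothing
```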
...@@ -17,7 +17,7 @@ from fairseq.sequence_generator import SequenceGenerator ...@@ -17,7 +17,7 @@ from fairseq.sequence_generator import SequenceGenerator
Batch = namedtuple('Batch', 'srcs tokens lengths') Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos alignments') Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
def buffered_read(buffer_size): def buffered_read(buffer_size):
...@@ -81,7 +81,10 @@ def main(args): ...@@ -81,7 +81,10 @@ def main(args):
# Optimize ensemble for generation # Optimize ensemble for generation
for model in models: for model in models:
model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam) model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if args.fp16: if args.fp16:
model.half() model.half()
...@@ -104,6 +107,7 @@ def main(args): ...@@ -104,6 +107,7 @@ def main(args):
result = Translation( result = Translation(
src_str='O\t{}'.format(src_str), src_str='O\t{}'.format(src_str),
hypos=[], hypos=[],
pos_scores=[],
alignments=[], alignments=[],
) )
...@@ -112,13 +116,22 @@ def main(args): ...@@ -112,13 +116,22 @@ def main(args):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(), hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str, src_str=src_str,
alignment=hypo['alignment'].int().cpu(), alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict, align_dict=align_dict,
tgt_dict=tgt_dict, tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe, remove_bpe=args.remove_bpe,
) )
result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str)) result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
result.alignments.append('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))) result.pos_scores.append('P\t{}'.format(
' '.join(map(
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
))
result.alignments.append(
'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
if args.print_alignment else None
)
return result return result
def process_batch(batch): def process_batch(batch):
...@@ -150,9 +163,11 @@ def main(args): ...@@ -150,9 +163,11 @@ def main(args):
for i in np.argsort(indices): for i in np.argsort(indices):
result = results[i] result = results[i]
print(result.src_str) print(result.src_str)
for hypo, align in zip(result.hypos, result.alignments): for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
print(hypo) print(hypo)
print(align) print(pos_scores)
if align is not None:
print(align)
if __name__ == '__main__': if __name__ == '__main__':
......
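interactive.py now also emits per-token scores: the `Translation` result gains a `pos_scores` field, printed as a P-line, while the A-line is again conditional on `--print-alignment`. A short sketch of the P-line formatting, with made-up score values:

```
def format_pos_scores(positional_scores):
    """Join per-token log-probabilities into the tab-separated P-line."""
    return 'P\t{}'.format(' '.join('{:.4f}'.format(x) for x in positional_scores))

print(format_pos_scores([-0.1234, -2.5, -0.0042]))   # P	-0.1234 -2.5000 -0.0042
```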
...@@ -203,6 +203,7 @@ def generate_main(data_dir, extra_flags=None): ...@@ -203,6 +203,7 @@ def generate_main(data_dir, extra_flags=None):
'--max-len-b', '5', '--max-len-b', '5',
'--gen-subset', 'valid', '--gen-subset', 'valid',
'--no-progress-bar', '--no-progress-bar',
'--print-alignment',
] + (extra_flags or []), ] + (extra_flags or []),
) )
......