"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "b66a85ae6c344b838e90d65740e68051ea69ffc8"
Commit 9998bbfa authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/505

Differential Revision: D14110201

Pulled By: myleott

fbshipit-source-id: 099ce61fa386c016f3a1d1815c6fe1a9a6c9005d
parent 184629a7
@@ -153,10 +153,6 @@ class BacktranslationDataset(FairseqDataset):
         """Just use the tgt dataset ordered_indices"""
         return self.tgt_dataset.ordered_indices()
 
-    def valid_size(self, index, max_positions):
-        """Just use the tgt dataset size"""
-        return self.tgt_dataset.valid_size(index, max_positions)
-
     def size(self, index):
         """Return an example's size as a float or tuple. This value is used
         when filtering a dataset with ``--max-positions``.
...
@@ -30,9 +30,10 @@ class CountingIterator(object):
         self.iterable = iterable
         self.count = 0
         self.itr = iter(self)
+        self.len = len(iterable)
 
     def __len__(self):
-        return len(self.iterable)
+        return self.len
 
     def __iter__(self):
         for x in self.iterable:
@@ -49,6 +50,7 @@ class CountingIterator(object):
     def skip(self, num_to_skip):
         """Fast-forward the iterator by skipping *num_to_skip* elements."""
         next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
+        self.len -= num_to_skip
        return self
...
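For context, a minimal standalone sketch (not the fairseq class itself) of why the length is now cached and decremented: previously __len__ always reported the full size of the underlying iterable, even after skip() had consumed elements, so a resumed iterator over-reported its remaining length.

    import itertools

    # Simplified sketch mirroring the diff above; names are illustrative only.
    class CountingIteratorSketch(object):
        def __init__(self, iterable):
            self.iterable = iterable
            self.count = 0
            self.itr = iter(self)
            self.len = len(iterable)      # cache the length up front

        def __len__(self):
            return self.len               # was: len(self.iterable)

        def __iter__(self):
            for x in self.iterable:
                self.count += 1
                yield x

        def skip(self, num_to_skip):
            # islice consumes num_to_skip elements from the underlying iterator
            next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
            self.len -= num_to_skip       # remaining length shrinks accordingly
            return self

    itr = CountingIteratorSketch(list(range(10)))
    itr.skip(3)
    assert len(itr) == 7                  # reflects elements left to yield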
@@ -104,13 +104,6 @@ class RoundRobinZipDatasets(FairseqDataset):
         """Ordered indices for batching."""
         return np.arange(len(self))
 
-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
-        return all(
-            dataset.valid_size(self._map_index(key, index), max_positions[key])
-            for key, dataset in self.datasets.items()
-        )
-
     @property
     def supports_prefetch(self):
         return all(
...
@@ -37,8 +37,12 @@ class FairseqDecoder(nn.Module):
         """Get normalized probabilities (or log probs) from a net's output."""
 
         if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
-            assert sample is not None and 'target' in sample
-            out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
+            if sample is not None:
+                assert 'target' in sample
+                target = sample['target']
+            else:
+                target = None
+            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
             return out.exp_() if not log_probs else out
 
         logits = net_output[0].float()
...
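A small hedged sketch of the new control flow, isolated from the fairseq classes: during training a sample dict with a 'target' tensor is available, while at generation time sample can be None, and the adaptive softmax is then queried with target=None. The helper name below is illustrative, not part of fairseq.

    # Illustrative helper: resolves the target that would be passed to
    # adaptive_softmax.get_log_prob(net_output[0], target=target).
    def resolve_target(sample):
        if sample is not None:
            assert 'target' in sample
            return sample['target']
        return None                       # generation time: no reference target

    assert resolve_target({'target': [1, 2, 3]}) == [1, 2, 3]
    assert resolve_target(None) is None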
@@ -67,8 +67,8 @@ class MultilingualTransformerModel(FairseqMultiModel):
         if not hasattr(args, 'max_target_positions'):
             args.max_target_positions = 1024
 
-        src_langs = [lang_pair.split('-')[0] for lang_pair in args.lang_pairs]
-        tgt_langs = [lang_pair.split('-')[1] for lang_pair in args.lang_pairs]
+        src_langs = [lang_pair.split('-')[0] for lang_pair in task.lang_pairs]
+        tgt_langs = [lang_pair.split('-')[1] for lang_pair in task.lang_pairs]
 
         if args.share_encoders:
             args.share_encoder_embeddings = True
@@ -158,12 +158,21 @@ class MultilingualTransformerModel(FairseqMultiModel):
             shared_decoder = get_decoder(tgt_langs[0])
 
         encoders, decoders = OrderedDict(), OrderedDict()
-        for lang_pair, src, tgt in zip(args.lang_pairs, src_langs, tgt_langs):
+        for lang_pair, src, tgt in zip(task.lang_pairs, src_langs, tgt_langs):
             encoders[lang_pair] = shared_encoder if shared_encoder is not None else get_encoder(src)
             decoders[lang_pair] = shared_decoder if shared_decoder is not None else get_decoder(tgt)
 
         return MultilingualTransformerModel(encoders, decoders)
 
+    def load_state_dict(self, state_dict, strict=True):
+        state_dict_subset = state_dict.copy()
+        for k, v in state_dict.items():
+            assert k.startswith('models.')
+            lang_pair = k.split('.')[1]
+            if lang_pair not in self.models:
+                del state_dict_subset[k]
+        super().load_state_dict(state_dict_subset, strict=strict)
+
 
 @register_model_architecture('multilingual_transformer', 'multilingual_transformer')
 def base_multilingual_architecture(args):
...
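The new load_state_dict override keeps only the parameters for language pairs that the current model instantiates, so a multilingual checkpoint can be loaded into a model built for a subset of its pairs. A standalone sketch of that filtering (the helper name and checkpoint keys are illustrative):

    from collections import OrderedDict

    # Drop checkpoint entries whose lang pair is not in this model, mirroring
    # the override above, which then defers to the parent load_state_dict.
    def filter_state_dict(state_dict, model_lang_pairs):
        state_dict_subset = state_dict.copy()
        for k in state_dict.keys():
            assert k.startswith('models.')
            lang_pair = k.split('.')[1]
            if lang_pair not in model_lang_pairs:
                del state_dict_subset[k]
        return state_dict_subset

    ckpt = OrderedDict([
        ('models.en-de.encoder.w', 1),
        ('models.en-fr.encoder.w', 2),
    ])
    assert list(filter_state_dict(ckpt, {'en-de'})) == ['models.en-de.encoder.w']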
@@ -171,10 +171,13 @@ class SequenceGenerator(object):
                 incremental_states[model] = None
 
             # compute the encoder output for each beam
-            encoder_out = model.encoder(**encoder_input)
-            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
-            new_order = new_order.to(src_tokens.device).long()
-            encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
+            if hasattr(model, 'encoder'):
+                encoder_out = model.encoder(**encoder_input)
+                new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+                new_order = new_order.to(src_tokens.device).long()
+                encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
+            else:
+                encoder_out = None
             encoder_outs.append(encoder_out)
 
         # initialize buffers
@@ -333,7 +336,8 @@ class SequenceGenerator(object):
                 for i, model in enumerate(self.models):
                     if isinstance(model.decoder, FairseqIncrementalDecoder):
                         model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
-                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
+                    if encoder_outs is not None and hasattr(model, 'encoder'):
+                        encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
 
             lprobs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
@@ -564,7 +568,7 @@ class SequenceGenerator(object):
             decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model]))
         else:
             decoder_out = list(model.decoder(tokens, encoder_out))
-        decoder_out[0] = decoder_out[0][:, -1, :]
+        decoder_out[0] = decoder_out[0][:, -1:, :]
         attn = decoder_out[1]
         if type(attn) is dict:
             attn = attn['attn']
@@ -573,4 +577,5 @@ class SequenceGenerator(object):
                 attn = attn['attn']
             attn = attn[:, -1, :]
         probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
+        probs = probs[:, -1, :]
         return probs, attn
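One detail worth calling out: the decoder-output slice changed from [:, -1, :] to [:, -1:, :], which keeps the time dimension, and the final position is now selected only after get_normalized_probs (the added probs = probs[:, -1, :] line). A quick shape check, assuming a (batch, time, vocab) tensor:

    import torch

    # Keeping the time dimension means the probability computation sees a
    # (batch, 1, vocab) tensor; the last step is sliced out afterwards.
    x = torch.randn(2, 5, 7)          # (batch, time, vocab)
    last_squeezed = x[:, -1, :]       # old behaviour: shape (2, 7)
    last_kept = x[:, -1:, :]          # new behaviour: shape (2, 1, 7)
    assert last_squeezed.shape == (2, 7)
    assert last_kept.shape == (2, 1, 7)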
@@ -76,6 +76,7 @@ class MultilingualTranslationTask(FairseqTask):
     def __init__(self, args, dicts, training):
         super().__init__(args)
         self.dicts = dicts
+        self.lang_pairs = args.lang_pairs
         self.langs = list(dicts.keys())
         self.training = training
@@ -132,11 +133,8 @@ class MultilingualTranslationTask(FairseqTask):
                     return IndexedCachedDataset(path, fix_lua_indexing=True)
             return None
 
-        def sort_lang_pair(lang_pair):
-            return '-'.join(sorted(lang_pair.split('-')))
-
         src_datasets, tgt_datasets = {}, {}
-        for lang_pair in set(map(sort_lang_pair, self.args.lang_pairs)):
+        for lang_pair in self.args.lang_pairs:
             src, tgt = lang_pair.split('-')
             if split_exists(split, src, tgt, src):
                 prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
@@ -153,11 +151,7 @@ class MultilingualTranslationTask(FairseqTask):
         def language_pair_dataset(lang_pair):
             src, tgt = lang_pair.split('-')
-            if lang_pair in src_datasets:
-                src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
-            else:
-                lang_pair = sort_lang_pair(lang_pair)
-                tgt_dataset, src_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
+            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
             return LanguagePairDataset(
                 src_dataset, src_dataset.sizes, self.dicts[src],
                 tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
@@ -172,7 +166,16 @@ class MultilingualTranslationTask(FairseqTask):
                 (lang_pair, language_pair_dataset(lang_pair))
                 for lang_pair in self.args.lang_pairs
             ]),
-            eval_key=None if self.training else self.args.lang_pairs[0],
+            eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
         )
 
+    def build_dataset_for_inference(self, src_tokens, src_lengths):
+        lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
+        return RoundRobinZipDatasets(
+            OrderedDict([
+                (lang_pair, LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary))
+            ]),
+            eval_key=lang_pair,
+        )
+
     def build_model(self, args):
...
@@ -192,6 +192,9 @@ class TranslationTask(FairseqTask):
             max_target_positions=self.args.max_target_positions,
         )
 
+    def build_dataset_for_inference(self, src_tokens, src_lengths):
+        return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)
+
     def max_positions(self):
         """Return the max sentence length allowed by the task."""
         return (self.args.max_source_positions, self.args.max_target_positions)
...
@@ -39,7 +39,10 @@ def main(args):
     print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
 
     # Set dictionaries
-    src_dict = task.source_dictionary
+    try:
+        src_dict = getattr(task, 'source_dictionary', None)
+    except NotImplementedError:
+        src_dict = None
     tgt_dict = task.target_dictionary
 
     # Load ensemble
@@ -121,12 +124,16 @@ def main(args):
                     src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                     target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                 else:
-                    src_str = src_dict.string(src_tokens, args.remove_bpe)
+                    if src_dict is not None:
+                        src_str = src_dict.string(src_tokens, args.remove_bpe)
+                    else:
+                        src_str = ""
                     if has_target:
                         target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
 
                 if not args.quiet:
-                    print('S-{}\t{}'.format(sample_id, src_str))
+                    if src_dict is not None:
+                        print('S-{}\t{}'.format(sample_id, src_str))
                     if has_target:
                         print('T-{}\t{}'.format(sample_id, target_str))
...
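The dictionary lookup in generate.py is now defensive: a task with no source side may expose no source_dictionary attribute at all, or may raise NotImplementedError from the property. A self-contained sketch of that pattern (the dummy task class is illustrative, not a fairseq class):

    # getattr's default only covers AttributeError, so NotImplementedError
    # raised by the property still needs the explicit try/except.
    class NoSourceDictTask:
        @property
        def source_dictionary(self):
            raise NotImplementedError

    task = NoSourceDictTask()
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    assert src_dict is None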
@@ -43,7 +43,7 @@ def make_batches(lines, args, task, max_positions):
     ]
     lengths = np.array([t.numel() for t in tokens])
     itr = task.get_batch_iterator(
-        dataset=data.LanguagePairDataset(tokens, lengths, task.source_dictionary),
+        dataset=task.build_dataset_for_inference(tokens, lengths),
         max_tokens=args.max_tokens,
         max_sentences=args.max_sentences,
         max_positions=max_positions,
...
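With this change interactive.py no longer hard-codes data.LanguagePairDataset; it asks the task via build_dataset_for_inference, so MultilingualTranslationTask can return a RoundRobinZipDatasets wrapper while TranslationTask keeps returning a plain LanguagePairDataset. A toy sketch of that dispatch (class names and return values are illustrative stand-ins, not fairseq classes):

    # Each task decides how to wrap raw tokens for inference; the caller in
    # interactive.py only sees the common build_dataset_for_inference hook.
    class PlainTranslationTaskSketch:
        def build_dataset_for_inference(self, src_tokens, src_lengths):
            return ('LanguagePairDataset', src_tokens, src_lengths)

    class MultilingualTaskSketch(PlainTranslationTaskSketch):
        def build_dataset_for_inference(self, src_tokens, src_lengths):
            inner = super().build_dataset_for_inference(src_tokens, src_lengths)
            return ('RoundRobinZipDatasets', inner)

    for task in (PlainTranslationTaskSketch(), MultilingualTaskSketch()):
        dataset = task.build_dataset_for_inference([[7, 8, 9]], [3])
        print(type(task).__name__, '->', dataset[0])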