Unverified Commit fe4e185a authored by Myle Ott, committed by GitHub

Merge internal changes

Changelog:
- `f472d141`: Support tied embeddings in LSTM encoder/decoder
- `89e19d42`: Don't print alignment by default (use `--print-alignment` to re-enable it)
- `d2e2a1d4`: Add Transformer-based language model
- `c2794070`: Add new Transformer configuration for IWSLT
- `2fbfda0d`: Misc changes for pytorch-translate
- Miscellaneous bug fixes
parents 7358296b 2fbfda0d
......@@ -37,11 +37,13 @@ def main(args):
if args.fp16:
model.half()
assert len(models) > 0
itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences or 4,
max_positions=model.max_positions(),
max_tokens=args.max_tokens or 36000,
max_sentences=args.max_sentences,
max_positions=models[0].max_positions(),
num_shards=args.num_shards,
shard_id=args.shard_id,
ignore_invalid_inputs=True,
......@@ -54,19 +56,51 @@ def main(args):
score_sum = 0.
count = 0
if args.remove_bpe is not None:
bpe_cont = args.remove_bpe.rstrip()
bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
bpe_len = len(bpe_cont)
else:
bpe_toks = None
bpe_len = 0
with progress_bar.build_progress_bar(args, itr) as t:
results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
wps_meter = TimeMeter()
for _, src_tokens, __, hypos in results:
for hypo in hypos:
pos_scores = hypo['positional_scores']
skipped_toks = 0
if bpe_toks is not None:
for i in range(len(hypo['tokens']) - 1):
if hypo['tokens'][i].item() in bpe_toks:
skipped_toks += 1
pos_scores[i + 1] += pos_scores[i]
pos_scores[i] = 0
inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
if inf_scores.any():
print('| Skipping tokens with inf scores:',
task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum()
count += pos_scores.numel()
count += pos_scores.numel() - skipped_toks
if args.output_word_probs:
w = ''
word_prob = []
for i in range(len(hypo['tokens'])):
w_ind = hypo['tokens'][i].item()
w += task.dictionary[w_ind]
if bpe_toks is not None and w_ind in bpe_toks:
w = w[:-bpe_len]
else:
word_prob.append((w, pos_scores[i].item()))
w = ''
print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))
wps_meter.update(src_tokens.size(0))
t.log({'wps': round(wps_meter.avg)})
......
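The eval_lm.py hunk above merges BPE continuation pieces back into whole words before reporting scores: a continuation piece's positional score is folded into the piece that follows it, and the continuation itself is excluded from the token count, so perplexity is computed over whole words. A minimal standalone sketch of the equivalent word-level aggregation (the function name and the trailing `@@` continuation marker are illustrative, matching the common BPE convention):

```python
def fold_bpe_scores(pieces, scores):
    """Merge subword pieces into words, summing their log-probabilities."""
    words, word_scores = [], []
    w, s = '', 0.0
    for piece, score in zip(pieces, scores):
        s += score
        if piece.endswith('@@'):
            w += piece[:-2]          # continuation: keep accumulating
        else:
            words.append(w + piece)  # word boundary: emit the merged word
            word_scores.append(s)
            w, s = '', 0.0
    return words, word_scores

print(fold_bpe_scores(['quick@@', 'ly', 'run'], [-0.5, -0.25, -1.0]))
# -> (['quickly', 'run'], [-0.75, -1.0])
```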
......@@ -36,6 +36,31 @@ $ python generate.py data-bin/iwslt14.tokenized.de-en \
```
To train a Transformer model on IWSLT'14 German to English:
```
# Preparation steps are the same as for the fconv model.
# Train the model (this configuration is better suited to a single GPU):
$ mkdir -p checkpoints/transformer
$ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
-a transformer_iwslt_de_en --optimizer adam --lr 0.0005 -s de -t en \
--label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 \
--min-lr '1e-09' --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --max-update 50000 \
--warmup-updates 4000 --warmup-init-lr '1e-07' \
--adam-betas '(0.9, 0.98)' --save-dir checkpoints/transformer
# Average the 10 latest checkpoints:
$ python scripts/average_checkpoints.py --inputs checkpoints/transformer \
--num-epoch-checkpoints 10 --output checkpoints/transformer/model.pt
# Generate:
$ python generate.py data-bin/iwslt14.tokenized.de-en \
--path checkpoints/transformer/model.pt \
--batch-size 128 --beam 5 --remove-bpe
```
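Checkpoint averaging often buys a small BLEU gain at no extra training cost. A simplified sketch of what scripts/average_checkpoints.py does conceptually, assuming fairseq checkpoints store parameters under the 'model' key as they do here (the real script also selects the last N epoch checkpoints and preserves the full checkpoint structure):

```python
import torch

def average_checkpoints(paths):
    """Element-wise mean of model parameters across several checkpoints."""
    avg = None
    for path in paths:
        state = torch.load(path, map_location='cpu')['model']
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / len(paths) for k, v in avg.items()}
```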
### prepare-wmt14en2de.sh
......
......@@ -7,7 +7,7 @@
from .dictionary import Dictionary
from .fairseq_dataset import FairseqDataset
from .indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset # noqa: F401
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
......
......@@ -47,7 +47,7 @@ class TokenBlockDataset(torch.utils.data.Dataset):
self.slice_indices = [block_at(i) for i in range(length)]
elif break_mode == 'complete':
assert sizes is not None and sum(sizes) == len(tokens)
assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
tok_idx = 0
sz_idx = 0
curr_size = 0
......@@ -62,7 +62,7 @@ class TokenBlockDataset(torch.utils.data.Dataset):
if curr_size > 0:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
elif break_mode == 'eos':
assert sizes is not None and sum(sizes) == len(tokens)
assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
curr = 0
for sz in sizes:
# skip one-token samples (which would consist of just the eos token)
......@@ -76,13 +76,17 @@ class TokenBlockDataset(torch.utils.data.Dataset):
def __getitem__(self, index):
s, e = self.slice_indices[index]
item = torch.LongTensor(self.tokens[s:e])
if self.include_targets:
if e == self.total_size:
return item[:-1], item[1:]
else:
return item, torch.LongTensor(self.tokens[s + 1:e + 1])
# target is the sentence itself; the source is the same block shifted back one token (wrapping at the start, so the first source token is eos)
if s == 0:
source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
else:
source = self.tokens[s - 1:e - 1]
return torch.LongTensor(source), item
return item
def __len__(self):
......
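The new __getitem__ above builds language-model (source, target) pairs without materializing a shifted copy of the whole token stream: the target is the block itself, and the source is the same block shifted back by one position, wrapping at the start so the very first source token is the stream-final eos. A small check of that invariant, with toy token values:

```python
import numpy as np

tokens = np.array([10, 11, 2, 12, 13, 2])   # toy stream; 2 = eos
s, e = 0, 3                                  # first block: [10, 11, 2]
target = tokens[s:e]
source = np.concatenate([tokens[-1:], tokens[0:e - 1]])  # wraps: starts with eos
assert source[0] == 2 and (source[1:] == target[:-1]).all()
```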
......@@ -19,8 +19,14 @@ class FairseqDecoder(nn.Module):
def forward(self, prev_output_tokens, encoder_out):
raise NotImplementedError
def get_normalized_probs(self, net_output, log_probs, _):
def get_normalized_probs(self, net_output, log_probs, sample):
"""Get normalized probabilities (or log probs) from a net's output."""
if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
assert sample is not None and 'target' in sample
out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
return out.exp_() if not log_probs else out
logits = net_output[0].float()
if log_probs:
return F.log_softmax(logits, dim=-1)
......
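get_normalized_probs now has two paths: when the decoder carries an adaptive softmax, log-probabilities come from adaptive_softmax.get_log_prob on the net output and the sample's target; otherwise the logits go through a standard (log-)softmax. A sketch of the default path, with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 100)            # B x T x V (hypothetical sizes)
log_probs = F.log_softmax(logits.float(), dim=-1)
probs = F.softmax(logits.float(), dim=-1)
assert torch.allclose(log_probs.exp(), probs, atol=1e-6)
```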
......@@ -268,16 +268,16 @@ class FConvEncoder(FairseqEncoder):
'encoder_padding_mask': encoder_padding_mask, # B x T
}
def reorder_encoder_out(self, encoder_out_dict, new_order):
if encoder_out_dict['encoder_out'] is not None:
encoder_out_dict['encoder_out'] = (
encoder_out_dict['encoder_out'][0].index_select(0, new_order),
encoder_out_dict['encoder_out'][1].index_select(0, new_order),
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out['encoder_out'] is not None:
encoder_out['encoder_out'] = (
encoder_out['encoder_out'][0].index_select(0, new_order),
encoder_out['encoder_out'][1].index_select(0, new_order),
)
if encoder_out_dict['encoder_padding_mask'] is not None:
encoder_out_dict['encoder_padding_mask'] = \
encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
return encoder_out_dict
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = \
encoder_out['encoder_padding_mask'].index_select(0, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
......@@ -352,6 +352,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
self.dropout = dropout
self.normalization_constant = normalization_constant
self.left_pad = left_pad
self.need_attn = True
convolutions = extend_conv_spec(convolutions)
in_channels = convolutions[0][0]
......@@ -466,6 +467,8 @@ class FConvDecoder(FairseqIncrementalDecoder):
x = self._transpose_if_training(x, incremental_state)
x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
if not self.training and self.need_attn:
attn_scores = attn_scores / num_attn_layers
if avg_attn_scores is None:
avg_attn_scores = attn_scores
......@@ -490,16 +493,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
return x, avg_attn_scores
def get_normalized_probs(self, net_output, log_probs, sample):
"""Get normalized probabilities (or log probs) from a net's output."""
if self.adaptive_softmax is not None:
assert sample is not None and 'target' in sample
out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
return out.exp_() if not log_probs else out
else:
return super().get_normalized_probs(net_output, log_probs, sample)
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
......@@ -521,6 +514,9 @@ class FConvDecoder(FairseqIncrementalDecoder):
state_dict['decoder.version'] = torch.Tensor([1])
return state_dict
def make_generation_fast_(self, need_attn=False, **kwargs):
self.need_attn = need_attn
def _embed_tokens(self, tokens, incremental_state):
if incremental_state is not None:
# keep only the last token for incremental forward pass
......
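The need_attn flag added above lets generation skip attention averaging entirely unless something downstream (e.g. --print-alignment) asks for it; make_generation_fast_ is where the generator communicates that. A toy sketch of the gating, not the fairseq class itself:

```python
class ToyDecoder:
    """Illustrates the need_attn gating added in the hunk above."""
    def __init__(self):
        self.training = False
        self.need_attn = True              # default: keep attention maps

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn         # generator toggles this

    def wants_attn(self):
        return not self.training and self.need_attn

d = ToyDecoder()
d.make_generation_fast_()                  # generation's new default: no attn
assert not d.wants_attn()                  # attention averaging is skipped
```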
......@@ -226,18 +226,18 @@ class FConvEncoder(FairseqEncoder):
'encoder_out': (x, y),
}
def reorder_encoder_out(self, encoder_out_dict, new_order):
encoder_out_dict['encoder_out'] = tuple(
eo.index_select(0, new_order) for eo in encoder_out_dict['encoder_out']
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out['encoder_out'] = tuple(
eo.index_select(0, new_order) for eo in encoder_out['encoder_out']
)
if 'pretrained' in encoder_out_dict:
encoder_out_dict['pretrained']['encoder_out'] = tuple(
if 'pretrained' in encoder_out:
encoder_out['pretrained']['encoder_out'] = tuple(
eo.index_select(0, new_order)
for eo in encoder_out_dict['pretrained']['encoder_out']
for eo in encoder_out['pretrained']['encoder_out']
)
return encoder_out_dict
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
......@@ -259,6 +259,7 @@ class FConvDecoder(FairseqDecoder):
self.pretrained_decoder = trained_decoder
self.dropout = dropout
self.left_pad = left_pad
self.need_attn = True
in_channels = convolutions[0][0]
def expand_bool_array(val):
......@@ -388,6 +389,7 @@ class FConvDecoder(FairseqDecoder):
r = x
x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
x = x + r
if not self.training and self.need_attn:
if avg_attn_scores is None:
avg_attn_scores = attn_scores
else:
......@@ -426,6 +428,9 @@ class FConvDecoder(FairseqDecoder):
"""Maximum output length supported by the decoder."""
return self.embed_positions.max_positions()
def make_generation_fast_(self, need_attn=False, **kwargs):
self.need_attn = need_attn
def _split_encoder_out(self, encoder_out):
"""Split and transpose encoder outputs."""
# transpose only once to speed up attention layers
......
......@@ -59,6 +59,12 @@ class LSTMModel(FairseqModel):
help='dropout probability for decoder input embedding')
parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
help='dropout probability for decoder output')
parser.add_argument('--share-decoder-input-output-embed', default=False,
action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', default=False, action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
@classmethod
def build_model(cls, args, task):
......@@ -74,14 +80,47 @@ class LSTMModel(FairseqModel):
utils.print_embed_overlap(embed_dict, dictionary)
return utils.load_embedding(embed_dict, dictionary, embed_tokens)
pretrained_encoder_embed = None
if args.encoder_embed_path:
pretrained_encoder_embed = load_pretrained_embedding_from_file(
args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim)
else:
num_embeddings = len(task.source_dictionary)
pretrained_encoder_embed = Embedding(
num_embeddings, args.encoder_embed_dim, task.source_dictionary.pad()
)
if args.share_all_embeddings:
# double-check that all parameter combinations are valid
if task.source_dictionary != task.target_dictionary:
raise RuntimeError('--share-all-embeddings requires a joint dictionary')
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path):
raise RuntimeError(
'--share-all-embeddings is not compatible with --decoder-embed-path'
)
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
'--share-all-embeddings requires --encoder-embed-dim to '
'match --decoder-embed-dim'
)
pretrained_decoder_embed = pretrained_encoder_embed
args.share_decoder_input_output_embed = True
else:
# separate decoder input embeddings
pretrained_decoder_embed = None
if args.decoder_embed_path:
pretrained_decoder_embed = load_pretrained_embedding_from_file(
args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim)
args.decoder_embed_path,
task.target_dictionary,
args.decoder_embed_dim
)
# one last double-check of parameter combinations
if args.share_decoder_input_output_embed and (
args.decoder_embed_dim != args.decoder_out_embed_dim):
raise RuntimeError(
'--share-decoder-input-output-embed requires '
'--decoder-embed-dim to match --decoder-out-embed-dim'
)
encoder = LSTMEncoder(
dictionary=task.source_dictionary,
......@@ -105,6 +144,7 @@ class LSTMModel(FairseqModel):
encoder_embed_dim=args.encoder_embed_dim,
encoder_output_units=encoder.output_units,
pretrained_embed=pretrained_decoder_embed,
share_input_output_embed=args.share_decoder_input_output_embed,
)
return cls(encoder, decoder)
......@@ -197,15 +237,15 @@ class LSTMEncoder(FairseqEncoder):
'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
}
def reorder_encoder_out(self, encoder_out_dict, new_order):
encoder_out_dict['encoder_out'] = tuple(
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out['encoder_out'] = tuple(
eo.index_select(1, new_order)
for eo in encoder_out_dict['encoder_out']
for eo in encoder_out['encoder_out']
)
if encoder_out_dict['encoder_padding_mask'] is not None:
encoder_out_dict['encoder_padding_mask'] = \
encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
return encoder_out_dict
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = \
encoder_out['encoder_padding_mask'].index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
......@@ -251,11 +291,14 @@ class LSTMDecoder(FairseqIncrementalDecoder):
self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
encoder_embed_dim=512, encoder_output_units=512, pretrained_embed=None,
share_input_output_embed=False,
):
super().__init__(dictionary)
self.dropout_in = dropout_in
self.dropout_out = dropout_out
self.hidden_size = hidden_size
self.share_input_output_embed = share_input_output_embed
self.need_attn = True
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
......@@ -279,6 +322,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
self.attention = AttentionLayer(encoder_output_units, hidden_size) if attention else None
if hidden_size != out_embed_dim:
self.additional_fc = Linear(hidden_size, out_embed_dim)
if not self.share_input_output_embed:
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
......@@ -352,12 +396,18 @@ class LSTMDecoder(FairseqIncrementalDecoder):
x = x.transpose(1, 0)
# srclen x tgtlen x bsz -> bsz x tgtlen x srclen
if not self.training and self.need_attn:
attn_scores = attn_scores.transpose(0, 2)
else:
attn_scores = None
# project back to size of vocabulary
if hasattr(self, 'additional_fc'):
x = self.additional_fc(x)
x = F.dropout(x, p=self.dropout_out, training=self.training)
if self.share_input_output_embed:
x = F.linear(x, self.embed_tokens.weight)
else:
x = self.fc_out(x)
return x, attn_scores
......@@ -380,6 +430,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
"""Maximum output length supported by the decoder."""
return int(1e5) # an arbitrary large number
def make_generation_fast_(self, need_attn=False, **kwargs):
self.need_attn = need_attn
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
......@@ -405,7 +458,7 @@ def LSTMCell(input_size, hidden_size, **kwargs):
def Linear(in_features, out_features, bias=True, dropout=0):
"""Weight-normalized Linear layer (input: N x T x C)"""
"""Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
m.weight.data.uniform_(-0.1, 0.1)
if bias:
......@@ -431,6 +484,8 @@ def base_architecture(args):
args.decoder_attention = getattr(args, 'decoder_attention', '1')
args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
......
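Weight tying in the LSTM decoder replaces the separate fc_out projection with F.linear against the input embedding matrix, which is why the checks above require the embedding dimensions to match. A minimal sketch with hypothetical sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab, dim = 1000, 512                    # hypothetical sizes
embed = nn.Embedding(vocab, dim)          # decoder input embedding
x = torch.randn(4, 7, dim)                # B x T x C decoder states
logits = F.linear(x, embed.weight)        # reuse embedding as output projection
assert logits.shape == (4, 7, vocab)
```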
......@@ -30,7 +30,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
embedding_dim,
padding_idx,
)
self.register_buffer('_float_tensor', torch.FloatTensor())
self.register_buffer('_float_tensor', torch.FloatTensor(1))
@staticmethod
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
......
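The _float_tensor buffer above holds no data of its own; it exists so the module can follow device and dtype moves (the recomputed sinusoidal table is cast with type_as against it). Registering a 1-element tensor instead of an empty one is presumably a workaround for code paths that mishandle zero-element tensors. A hedged sketch of the pattern:

```python
import torch

buf = torch.FloatTensor(1)          # tracks the module's device/dtype
weights = torch.randn(10, 16)       # e.g. a freshly recomputed embedding table
weights = weights.type_as(buf)      # follows wherever the module was moved
```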
......@@ -16,7 +16,7 @@ class FixedSchedule(FairseqLRScheduler):
super().__init__(args, optimizer)
# set defaults
args.warmup_updates = getattr(args, 'warmup_updates', 0)
args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
self.lr = args.lr[0]
if args.warmup_updates > 0:
......
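The added `or 0` covers the case where warmup_updates exists on args but was explicitly set to None; getattr's default only applies when the attribute is missing entirely:

```python
class Args:
    warmup_updates = None             # attribute present, but unset

assert getattr(Args, 'warmup_updates', 0) is None      # default not used
assert (getattr(Args, 'warmup_updates', 0) or 0) == 0  # normalized by `or 0`
```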
......@@ -62,7 +62,7 @@ def eval_bool(x, default=False):
return default
def parse_args_and_arch(parser, input_args=None):
def parse_args_and_arch(parser, input_args=None, parse_known=False):
# The parser doesn't know about model/criterion/optimizer-specific args, so
# we parse twice. First we parse the model/criterion/optimizer, then we
# parse a second time after adding the *-specific arguments.
......@@ -90,7 +90,11 @@ def parse_args_and_arch(parser, input_args=None):
TASK_REGISTRY[args.task].add_args(parser)
# Parse a second time.
if parse_known:
args, extra = parser.parse_known_args(input_args)
else:
args = parser.parse_args(input_args)
extra = None
# Post-process args.
if hasattr(args, 'lr'):
......@@ -104,6 +108,9 @@ def parse_args_and_arch(parser, input_args=None):
if hasattr(args, 'arch'):
ARCH_CONFIG_REGISTRY[args.arch](args)
if parse_known:
return args, extra
else:
return args
......@@ -249,6 +256,8 @@ def add_common_eval_args(group):
def add_eval_lm_args(parser):
group = parser.add_argument_group('LM Evaluation')
add_common_eval_args(group)
group.add_argument('--output-word-probs', action='store_true',
help='if set, outputs words and their predicted log probabilities to standard output')
def add_generation_args(parser):
......@@ -290,6 +299,8 @@ def add_generation_args(parser):
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
help='temperature for random sampling')
group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override, at generation time, model args that were set during training')
return group
......
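parse_known=True exposes argparse's parse_known_args through fairseq's option parsing, so wrappers such as pytorch-translate can accept their own extra flags without tripping fairseq's parser. The underlying argparse behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float)
args, extra = parser.parse_known_args(['--lr', '0.1', '--my-flag', '7'])
assert args.lr == 0.1 and extra == ['--my-flag', '7']  # unknowns returned, not fatal
```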
......@@ -126,8 +126,8 @@ class SequenceGenerator(object):
tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos
attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
attn_buf = attn.clone()
attn, attn_buf = None, None
nonpad_idxs = None
# list of completed sentences
finalized = [[] for i in range(bsz)]
......@@ -191,7 +191,7 @@ class SequenceGenerator(object):
tokens_clone = tokens.index_select(0, bbsz_idx)
tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
tokens_clone[:, step] = self.eos
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
# compute scores per token position
pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
......@@ -220,10 +220,13 @@ class SequenceGenerator(object):
def get_hypo():
if attn_clone is not None:
# remove padding tokens from attn scores
nonpad_idxs = src_tokens[sent].ne(self.pad)
hypo_attn = attn_clone[i][nonpad_idxs]
hypo_attn = attn_clone[i][nonpad_idxs[sent]]
_, alignment = hypo_attn.max(dim=0)
else:
hypo_attn = None
alignment = None
return {
'tokens': tokens_clone[i],
......@@ -270,8 +273,7 @@ class SequenceGenerator(object):
model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
probs, avg_attn_scores = self._decode(
tokens[:, :step + 1], encoder_outs, incremental_states)
probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
......@@ -286,6 +288,11 @@ class SequenceGenerator(object):
probs[:, self.unk] -= self.unk_penalty # apply unk penalty
# Record attention scores
if avg_attn_scores is not None:
if attn is None:
attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
attn_buf = attn.clone()
nonpad_idxs = src_tokens.ne(self.pad)
attn[:, :, step + 1].copy_(avg_attn_scores)
cand_scores = buffer('cand_scores', type_of=scores)
......@@ -417,6 +424,7 @@ class SequenceGenerator(object):
scores_buf.resize_as_(scores)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens_buf.resize_as_(tokens)
if attn is not None:
attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
attn_buf.resize_as_(attn)
bsz = new_bsz
......@@ -473,6 +481,7 @@ class SequenceGenerator(object):
)
# copy attention for active hypotheses
if attn is not None:
torch.index_select(
attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
out=attn_buf[:, :, :step + 2],
......@@ -481,6 +490,7 @@ class SequenceGenerator(object):
# swap buffers
tokens, tokens_buf = tokens_buf, tokens
scores, scores_buf = scores_buf, scores
if attn is not None:
attn, attn_buf = attn_buf, attn
# reorder incremental state in decoder
......@@ -518,7 +528,7 @@ class SequenceGenerator(object):
def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
with torch.no_grad():
if incremental_states[model] is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model]))
else:
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1, :]
......
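The generator previously allocated a bsz*beam x srclen x (maxlen+2) attention buffer unconditionally; now allocation is deferred until a model actually returns attention scores, which no longer happens by default (see --print-alignment). A sketch of the lazy-allocation pattern:

```python
import torch

attn = None
steps = [None, torch.ones(8, 5), torch.ones(8, 5)]  # per-step attention (or None)
for step, avg_attn_scores in enumerate(steps):
    if avg_attn_scores is None:
        continue                       # nothing to record, buffer never allocated
    if attn is None:                   # first real attention: allocate lazily
        attn = torch.zeros(8, 5, 10)   # bsz*beam x srclen x maxlen+2
    attn[:, :, step + 1].copy_(avg_attn_scores)
```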
......@@ -5,9 +5,6 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import criterions, models
from fairseq.data import FairseqDataset
class FairseqTask(object):
"""
......@@ -28,11 +25,12 @@ class FairseqTask(object):
def setup_task(cls, args, **kwargs):
raise NotImplementedError
def load_dataset(self, split):
def load_dataset(self, split, combine=False):
raise NotImplementedError
def dataset(self, split):
"""Return a dataset split."""
from fairseq.data import FairseqDataset
if split not in self.datasets:
raise KeyError('Dataset not loaded: ' + split)
if not isinstance(self.datasets[split], FairseqDataset):
......@@ -40,9 +38,11 @@ class FairseqTask(object):
return self.datasets[split]
def build_model(self, args):
from fairseq import models
return models.build_model(args, self)
def build_criterion(self, args):
from fairseq import criterions
return criterions.build_criterion(args, self)
def get_loss(self, model, criterion, sample):
......
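Moving the fairseq.models and fairseq.criterions imports from module level into the methods is the standard fix for a circular import: the name is resolved at call time, after both modules have finished loading. The shape of the fix, mirroring the hunk above:

```python
def build_model(args, task):
    # Deferred import: resolved on first call, after both modules have
    # finished loading, so the import cycle never triggers.
    from fairseq import models
    return models.build_model(args, task)
```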
......@@ -5,8 +5,12 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os
from torch.utils.data import ConcatDataset
from fairseq.data import (
Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
MonolingualDataset, TokenBlockDataset,
......@@ -43,23 +47,46 @@ class LanguageModelingTask(FairseqTask):
print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split):
def load_dataset(self, split, combine=False):
"""Load a dataset split."""
path = os.path.join(self.args.data, split)
loaded_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
path = os.path.join(self.args.data, split_k)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
tokens = ds.tokens_list
tokens = [t for l in ds.tokens_list for t in l]
elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
tokens = ds.buffer
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
dataset = TokenBlockDataset(
loaded_datasets.append(
TokenBlockDataset(
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True, # return next tokens as targets
)
self.datasets[split] = MonolingualDataset(dataset, dataset.sizes, self.dictionary, shuffle=False)
include_targets=True
))
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
if not combine:
break
if len(loaded_datasets) == 1:
dataset = loaded_datasets[0]
sizes = dataset.sizes
else:
dataset = ConcatDataset(loaded_datasets)
sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False)
@property
def target_dictionary(self):
......
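load_dataset(combine=True) now concatenates sharded splits: it probes train, train1, train2, ... and stops at the first missing shard, raising only if the base split itself is absent. The discovery loop, isolated into a small sketch:

```python
import itertools
import os

def shard_names(data_dir, split, exists=os.path.exists):
    """Collect existing shard paths: split, split1, split2, ..."""
    names = []
    for k in itertools.count():
        path = os.path.join(data_dir, split + (str(k) if k > 0 else ''))
        if not exists(path):
            if k == 0:
                raise FileNotFoundError('Dataset not found: {}'.format(split))
            break
        names.append(path)
    return names
```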
......@@ -5,8 +5,12 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os
from torch.utils.data import ConcatDataset
from fairseq import options
from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
......@@ -65,10 +69,10 @@ class TranslationTask(FairseqTask):
return cls(args, src_dict, tgt_dict)
def load_dataset(self, split):
def load_dataset(self, split, combine=False):
"""Load a dataset split."""
def split_exists(src, tgt, lang):
def split_exists(split, src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
......@@ -76,15 +80,6 @@ class TranslationTask(FairseqTask):
return True
return False
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
......@@ -92,11 +87,48 @@ class TranslationTask(FairseqTask):
return IndexedInMemoryDataset(path, fix_lua_indexing=True)
return None
src_dataset = indexed_dataset(prefix + src, self.src_dict)
tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict)
src_datasets = []
tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(split_k, src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
print('| {} {} {} examples'.format(self.args.data, split_k, len(src_datasets[-1])))
if not combine:
break
assert len(src_datasets) == len(tgt_datasets)
if len(src_datasets) == 1:
src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
src_sizes = src_dataset.sizes
tgt_sizes = tgt_dataset.sizes
else:
src_dataset = ConcatDataset(src_datasets)
tgt_dataset = ConcatDataset(tgt_datasets)
src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])
self.datasets[split] = LanguagePairDataset(
src_dataset, src_dataset.sizes, self.src_dict,
tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
src_dataset, src_sizes, self.src_dict,
tgt_dataset, tgt_sizes, self.tgt_dict,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
......
......@@ -140,6 +140,11 @@ class Trainer(object):
ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd)
if ooms_fwd == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch')
self.zero_grad()
return None
# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
......@@ -178,11 +183,6 @@ class Trainer(object):
return None # buffering updates
def _forward(self, sample, eval=False):
# prepare model and optimizer
if eval:
self.model.eval()
else:
self.model.train()
loss = None
sample_size = 0
logging_output = {
......@@ -190,8 +190,15 @@ class Trainer(object):
'nsentences': sample['target'].size(0) if sample is not None else 0,
}
oom = 0
if sample is not None:
try:
# prepare model and optimizer
if eval:
self.model.eval()
else:
self.model.train()
self.optimizer.zero_grad()
if sample is not None:
with torch.no_grad() if eval else contextlib.ExitStack():
# calculate loss and sample size
loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
......
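Moving model.train()/zero_grad into the try block means a CUDA OOM raised while preparing the step is caught too, and when every worker reports a forward OOM the whole batch is skipped rather than letting the workers desynchronize. The all-worker check in isolation:

```python
ooms_fwd = [1, 1, 1, 1]               # per-worker forward-OOM flags, aggregated
world_size = len(ooms_fwd)            # stands in for args.distributed_world_size
if sum(ooms_fwd) == world_size:       # skip only when *every* worker OOMed
    print('| WARNING: OOM in all workers, skipping batch')
```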
......@@ -42,7 +42,10 @@ def main(args):
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if args.fp16:
model.half()
......@@ -115,7 +118,7 @@ def main(args):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu(),
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
......@@ -130,6 +133,8 @@ def main(args):
hypo['positional_scores'].tolist(),
))
))
if args.print_alignment:
print('A-{}\t{}'.format(
sample_id,
' '.join(map(lambda x: str(utils.item(x)), alignment))
......
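With --print-alignment set, generation keeps attention and derives a hard alignment by taking, for each target position, the argmax over the non-pad source positions (the hypo_attn.max(dim=0) call in the SequenceGenerator hunk above). In miniature:

```python
import torch

hypo_attn = torch.rand(6, 4)           # srclen x tgtlen attention weights
_, alignment = hypo_attn.max(dim=0)    # best source index per target token
print('A-0\t' + ' '.join(str(int(i)) for i in alignment))
```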
......@@ -17,7 +17,7 @@ from fairseq.sequence_generator import SequenceGenerator
Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos alignments')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
def buffered_read(buffer_size):
......@@ -81,7 +81,10 @@ def main(args):
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if args.fp16:
model.half()
......@@ -104,6 +107,7 @@ def main(args):
result = Translation(
src_str='O\t{}'.format(src_str),
hypos=[],
pos_scores=[],
alignments=[],
)
......@@ -112,13 +116,22 @@ def main(args):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu(),
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
result.alignments.append('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))))
result.pos_scores.append('P\t{}'.format(
' '.join(map(
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
))
result.alignments.append(
'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
if args.print_alignment else None
)
return result
def process_batch(batch):
......@@ -150,8 +163,10 @@ def main(args):
for i in np.argsort(indices):
result = results[i]
print(result.src_str)
for hypo, align in zip(result.hypos, result.alignments):
for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
print(hypo)
print(pos_scores)
if align is not None:
print(align)
......
......@@ -203,6 +203,7 @@ def generate_main(data_dir, extra_flags=None):
'--max-len-b', '5',
'--gen-subset', 'valid',
'--no-progress-bar',
'--print-alignment',
] + (extra_flags or []),
)
......