Commit cf1c64a5 authored by Myle Ott

Nits

parent 6eda8e47
......@@ -81,10 +81,7 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=Fal
dst.copy_(src)
for i, v in enumerate(values):
if left_pad:
copy_tensor(v, res[i][size - len(v):])
else:
copy_tensor(v, res[i][:len(v)])
copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
return res
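The consolidated copy above selects the destination slice in a single expression. As a standalone illustration of what left_pad controls (a minimal sketch, not fairseq's actual collate_tokens), padding a batch of variable-length sequences looks like this:

    import torch

    def pad_batch(values, pad_idx, left_pad=False):
        # values: list of 1-D LongTensors of varying length
        size = max(len(v) for v in values)
        res = values[0].new(len(values), size).fill_(pad_idx)
        for i, v in enumerate(values):
            # left_pad puts the padding before the tokens, otherwise after
            dst = res[i][size - len(v):] if left_pad else res[i][:len(v)]
            dst.copy_(v)
        return res

    batch = [torch.LongTensor([4, 5, 6]), torch.LongTensor([7, 8])]
    print(pad_batch(batch, pad_idx=1, left_pad=True))   # [[4, 5, 6], [1, 7, 8]]
    print(pad_batch(batch, pad_idx=1, left_pad=False))  # [[4, 5, 6], [7, 8, 1]]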
......
......@@ -5,9 +5,10 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import os
import struct
import numpy as np
import torch
import torch.utils.data
......@@ -197,10 +198,8 @@ class IndexedDatasetBuilder(object):
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype),
self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1,
len(self.sizes)))
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
......
......@@ -12,6 +12,7 @@ from .fairseq_decoder import FairseqDecoder # noqa: F401
from .fairseq_encoder import FairseqEncoder # noqa: F401
from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel # noqa: F401
from .composite_encoder import CompositeEncoder # noqa: F401
MODEL_REGISTRY = {}
......
......@@ -12,10 +12,15 @@ import torch.nn.functional as F
from fairseq import options, utils
from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution, AdaptiveSoftmax
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, FairseqLanguageModel, register_model, \
register_model_architecture
from fairseq.modules import (
AdaptiveSoftmax, BeamableMM, GradMultiply, LearnedPositionalEmbedding,
LinearizedConvolution,
)
from . import (
FairseqEncoder, FairseqIncrementalDecoder, FairseqModel,
FairseqLanguageModel, register_model, register_model_architecture,
)
@register_model('fconv')
......
......@@ -7,16 +7,23 @@
#
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import LanguagePairDataset
from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
from fairseq.modules import GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution, DownsampledMultiHeadAttention
from fairseq.modules import (
DownsampledMultiHeadAttention, GradMultiply, LearnedPositionalEmbedding,
LinearizedConvolution,
)
from fairseq import utils
from . import FairseqEncoder, CompositeEncoder, FairseqDecoder, FairseqModel, register_model, register_model_architecture
from . import (
FairseqEncoder, CompositeEncoder, FairseqDecoder, FairseqModel,
register_model, register_model_architecture,
)
@register_model('fconv_self_att')
class FConvModelSelfAtt(FairseqModel):
......@@ -76,7 +83,6 @@ class FConvModelSelfAtt(FairseqModel):
pretrained = eval(args.pretrained)
if pretrained:
print("| Loading pretrained model")
state = torch.load(args.pretrained_checkpoint)
trained_model = utils.load_ensemble_for_inference(
# not actually for inference, but loads pretrained model parameters
filenames=[args.pretrained_checkpoint],
......@@ -131,9 +137,11 @@ class FConvModelSelfAtt(FairseqModel):
class FConvEncoder(FairseqEncoder):
"""Convolutional encoder"""
def __init__(self, dictionary, embed_dim=512, max_positions=1024,
def __init__(
self, dictionary, embed_dim=512, max_positions=1024,
convolutions=((512, 3),) * 20, dropout=0.1, attention=False,
attention_nheads=1):
attention_nheads=1,
):
super().__init__(dictionary)
self.dropout = dropout
self.num_attention_layers = None
......@@ -163,20 +171,20 @@ class FConvEncoder(FairseqEncoder):
self.attention = nn.ModuleList()
self.attproj = nn.ModuleList()
for i, (out_channels, kernel_size) in enumerate(convolutions):
self.projections.append(Linear(in_channels, out_channels)
if in_channels != out_channels else None)
self.projections.append(
Linear(in_channels, out_channels) if in_channels != out_channels else None
)
self.convolutions.append(
ConvTBC(in_channels, out_channels * 2, kernel_size,
dropout=dropout))
ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout)
)
self.attention.append(SelfAttention(out_channels, embed_dim,
attention_nheads)
if attention[i] else None)
self.attention.append(
SelfAttention(out_channels, embed_dim, attention_nheads) if attention[i] else None
)
in_channels = out_channels
self.fc2 = Linear(in_channels, embed_dim)
def forward(self, src_tokens, src_lengths):
# embed tokens and positions
x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
......@@ -226,18 +234,20 @@ class FConvEncoder(FairseqEncoder):
class FConvDecoder(FairseqDecoder):
"""Convolutional decoder"""
def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
max_positions=1024, convolutions=((512, 3),) * 8,
attention=True, dropout=0.1, selfattention=False,
attention_nheads=1, selfattention_nheads=1,
def __init__(
self, dictionary, embed_dim=512, out_embed_dim=256, max_positions=1024,
convolutions=((512, 3),) * 8, attention=True, dropout=0.1,
selfattention=False, attention_nheads=1, selfattention_nheads=1,
project_input=False, gated_attention=False, downsample=False,
pretrained=False, trained_decoder=None):
pretrained=False, trained_decoder=None,
):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([2]))
self.pretrained = pretrained
self.pretrained_decoder = trained_decoder
self.dropout = dropout
in_channels = convolutions[0][0]
def expand_bool_array(val):
if isinstance(val, bool):
# expand True into [True, True, ...] and do the same with False
......@@ -269,27 +279,33 @@ class FConvDecoder(FairseqDecoder):
self.selfattention = nn.ModuleList()
self.attproj = nn.ModuleList()
for i, (out_channels, kernel_size) in enumerate(convolutions):
pad = kernel_size - 1
self.projections.append(Linear(in_channels, out_channels)
if in_channels != out_channels else None)
self.projections.append(
Linear(in_channels, out_channels) if in_channels != out_channels else None
)
self.convolutions.append(
LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
padding=(kernel_size - 1), dropout=dropout))
self.attention.append(DownsampledMultiHeadAttention(out_channels, embed_dim,
attention_nheads,
project_input=project_input,
gated=False, downsample=False)
if attention[i] else None)
self.attproj.append(Linear(out_channels, embed_dim, dropout=dropout)
if attention[i] else None)
self.selfattention.append(SelfAttention(out_channels, embed_dim,
selfattention_nheads,
project_input=project_input,
gated=gated_attention,
downsample=downsample)
if selfattention[i] else None)
LinearizedConv1d(
in_channels, out_channels * 2, kernel_size,
padding=(kernel_size - 1), dropout=dropout,
)
)
self.attention.append(
DownsampledMultiHeadAttention(
out_channels, embed_dim, attention_nheads,
project_input=project_input, gated=False, downsample=False,
) if attention[i] else None
)
self.attproj.append(
Linear(out_channels, embed_dim, dropout=dropout) if attention[i] else None
)
self.selfattention.append(
SelfAttention(
out_channels, embed_dim, selfattention_nheads,
project_input=project_input, gated=gated_attention,
downsample=downsample,
) if selfattention[i] else None
)
in_channels = out_channels
self.fc2 = Linear(in_channels, out_embed_dim)
......@@ -301,24 +317,27 @@ class FConvDecoder(FairseqDecoder):
self.gate1 = nn.Sequential(Linear(out_embed_dim*2, out_embed_dim), nn.Sigmoid())
self.gate2 = nn.Sequential(Linear(out_embed_dim*2, out_embed_dim), nn.Sigmoid())
# pretrained and trained models are joined
self.joining = nn.Sequential(Linear(out_embed_dim*2, out_embed_dim*2),
self.joining = nn.Sequential(
Linear(out_embed_dim*2, out_embed_dim*2),
nn.LayerNorm(out_embed_dim*2),
nn.GLU(),
Linear(out_embed_dim, out_embed_dim*2),
nn.LayerNorm(out_embed_dim*2),
nn.GLU(),
Linear(out_embed_dim, out_embed_dim),
nn.LayerNorm(out_embed_dim))
nn.LayerNorm(out_embed_dim)
)
# pretrained model contains an output layer that is nhid -> vocab size
# but the models are combined in their hidden state
# the hook stores the output of the pretrained model forward
self.pretrained_outputs = {}
def save_output():
def hook(a, b, output):
self.pretrained_outputs["out"] = output
return hook
self.pretrained_decoder.fc2.register_forward_hook(save_output())
self.pretrained_decoder.fc2.register_forward_hook(save_output())
def forward(self, prev_output_tokens, encoder_out_dict):
encoder_out = encoder_out_dict['encoder']['encoder_out']
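The comment and hook above exist because the pretrained decoder's final projection (fc2) output is needed when the pretrained and trained models are joined in hidden-state space. A minimal sketch of the same forward-hook pattern, using a throwaway nn.Linear rather than fairseq modules:

    import torch
    import torch.nn as nn

    outputs = {}

    def save_output():
        def hook(module, inputs, output):
            # runs after every forward() of the hooked module
            outputs['out'] = output
        return hook

    layer = nn.Linear(8, 4)
    layer.register_forward_hook(save_output())

    y = layer(torch.randn(2, 8))
    assert torch.equal(outputs['out'], y)  # the hook saw the output the caller received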
......@@ -342,11 +361,9 @@ class FConvDecoder(FairseqDecoder):
# temporal convolutions
avg_attn_scores = None
for proj, conv, attention, selfattention, attproj in zip(self.projections,
self.convolutions,
self.attention,
self.selfattention,
self.attproj):
for proj, conv, attention, selfattention, attproj in zip(
self.projections, self.convolutions, self.attention, self.selfattention, self.attproj
):
residual = x if proj is None else proj(x)
x = F.dropout(x, p=self.dropout, training=self.training)
......@@ -398,11 +415,14 @@ class FConvDecoder(FairseqDecoder):
def reorder_encoder_out(self, encoder_out_dict, new_order):
encoder_out_dict['encoder']['encoder_out'] = tuple(
eo.index_select(0, new_order) for eo in encoder_out_dict['encoder']['encoder_out'])
eo.index_select(0, new_order) for eo in encoder_out_dict['encoder']['encoder_out']
)
if 'pretrained' in encoder_out_dict:
encoder_out_dict['pretrained']['encoder']['encoder_out'] = tuple(
eo.index_select(0, new_order) for eo in encoder_out_dict['pretrained']['encoder']['encoder_out'])
eo.index_select(0, new_order)
for eo in encoder_out_dict['pretrained']['encoder']['encoder_out']
)
return encoder_out_dict
......@@ -425,8 +445,10 @@ class SelfAttention(nn.Module):
def __init__(self, out_channels, embed_dim, num_heads, project_input=False, gated=False, downsample=False):
super().__init__()
self.attention = DownsampledMultiHeadAttention(out_channels, embed_dim, num_heads,
dropout=0, bias=True, project_input=project_input, gated=gated, downsample=downsample)
self.attention = DownsampledMultiHeadAttention(
out_channels, embed_dim, num_heads, dropout=0, bias=True,
project_input=project_input, gated=gated, downsample=downsample,
)
self.in_proj_q = Linear(out_channels, embed_dim)
self.in_proj_k = Linear(out_channels, embed_dim)
self.in_proj_v = Linear(out_channels, embed_dim)
......@@ -441,7 +463,6 @@ class SelfAttention(nn.Module):
return self.ln(x + residual)
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
m.weight.data.normal_(0, 0.1)
......
......@@ -10,10 +10,13 @@ from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq import options, utils
from fairseq.data import consts
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, register_model_architecture
from . import (
FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model,
register_model_architecture,
)
@register_model('lstm')
......@@ -92,10 +95,6 @@ class LSTMModel(FairseqModel):
bidirectional=args.encoder_bidirectional,
pretrained_embed=pretrained_encoder_embed,
)
try:
attention = bool(eval(args.decoder_attention))
except TypeError:
attention = bool(args.decoder_attention)
decoder = LSTMDecoder(
dictionary=dst_dict,
embed_dim=args.decoder_embed_dim,
......@@ -104,7 +103,7 @@ class LSTMModel(FairseqModel):
num_layers=args.decoder_layers,
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
attention=attention,
attention=options.eval_bool(args.decoder_attention),
encoder_embed_dim=args.encoder_embed_dim,
encoder_output_units=encoder.output_units,
pretrained_embed=pretrained_decoder_embed,
......
......@@ -435,6 +435,7 @@ def transformer_vaswani_wmt_en_de_big(args):
args.dropout = getattr(args, 'dropout', 0.3)
base_architecture(args)
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_fr_big')
def transformer_vaswani_wmt_en_fr_big(args):
args.dropout = getattr(args, 'dropout', 0.1)
......
......@@ -11,8 +11,11 @@ from torch import nn
class AdaptiveSoftmax(nn.Module):
"""This is an implementation of the efficient softmax approximation for graphical processing units (GPU),
described in the paper "Efficient softmax approximation for GPUs" (http://arxiv.org/abs/1609.04309)."""
"""
This is an implementation of the efficient softmax approximation for
graphical processing units (GPU), described in the paper "Efficient softmax
approximation for GPUs" (http://arxiv.org/abs/1609.04309).
"""
def __init__(self, vocab_size, input_dim, cutoff, dropout):
super().__init__()
......@@ -46,9 +49,12 @@ class AdaptiveSoftmax(nn.Module):
self.apply(init_weights)
def adapt_target(self, target):
"""In order to be efficient, the AdaptiveSoftMax does not compute the scores for all the word of the
vocabulary for all the examples.It is thus necessary to call the method adapt_target of the AdaptiveSoftMax
layer inside each forward pass."""
"""
In order to be efficient, the AdaptiveSoftMax does not compute the
scores for all the word of the vocabulary for all the examples. It is
thus necessary to call the method adapt_target of the AdaptiveSoftMax
layer inside each forward pass.
"""
target = target.view(-1)
new_target = [target.clone()]
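The reflowed docstring above describes why adapt_target exists: only the head and the tail clusters that actually occur in the batch get scored. A toy sketch of that target remapping under assumed cutoffs (an illustration of the idea, not the fairseq implementation):

    import torch

    def adapt_target_sketch(target, cutoff):
        # cutoff = [head_size, ..., vocab_size]; ids >= head_size live in tail clusters
        target = target.view(-1)
        new_target = [target.clone()]
        target_idxs = []
        for i in range(len(cutoff) - 1):
            mask = target.ge(cutoff[i]) & target.lt(cutoff[i + 1])
            # in the head, every word of cluster i collapses to one placeholder id
            new_target[0][mask] = cutoff[0] + i
            if mask.any():
                target_idxs.append(mask.nonzero().squeeze(1))
                # inside cluster i, ids are re-indexed relative to the cluster start
                new_target.append(target[mask] - cutoff[i])
            else:
                target_idxs.append(None)
                new_target.append(None)
        return new_target, target_idxs

    t = torch.LongTensor([3, 12, 7, 25])
    print(adapt_target_sketch(t, cutoff=[10, 20, 30]))
    # head targets become [3, 10, 7, 11]; the two clusters see [2] and [5]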
......@@ -68,8 +74,13 @@ class AdaptiveSoftmax(nn.Module):
return new_target, target_idxs
def forward(self, input, target):
""" accepts input (b x t x d) and target (b x t) and returns
2 lists: output for each cutoff section and new targets by cut off """
"""
Args:
input: (b x t x d)
target: (b x t)
Returns:
2 lists: output for each cutoff section and new targets by cut off
"""
input = input.contiguous().view(-1, input.size(-1))
input = F.dropout(input, p=self.dropout, training=self.training)
......@@ -86,7 +97,10 @@ class AdaptiveSoftmax(nn.Module):
return output, new_target
def get_log_prob(self, input, target):
"""computes the log probabilities for all the words of the vocabulary, given a 2D tensor of hidden vectors"""
"""
Computes the log probabilities for all the words of the vocabulary,
given a 2D tensor of hidden vectors.
"""
bsz, length, dim = input.size()
input = input.contiguous().view(-1, dim)
......
......@@ -20,17 +20,9 @@ class SingleHeadAttention(nn.Module):
Single-head attention that supports Gating and Downsampling
"""
def __init__(
self,
out_channels,
embed_dim,
head_dim,
head_index,
dropout=0.,
bias=True,
project_input=True,
gated=False,
downsample=False,
num_heads=1
self, out_channels, embed_dim, head_dim, head_index, dropout=0.,
bias=True, project_input=True, gated=False, downsample=False,
num_heads=1,
):
super().__init__()
self.embed_dim = embed_dim
......@@ -71,13 +63,8 @@ class SingleHeadAttention(nn.Module):
self.scaling = self.head_dim**-0.5
def forward(
self,
query,
key,
value,
mask_future_timesteps=False,
key_padding_mask=None,
use_scalar_bias=False
self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None, use_scalar_bias=False,
):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
......@@ -168,15 +155,8 @@ class DownsampledMultiHeadAttention(nn.ModuleList):
Multi-headed attention with Gating and Downsampling
"""
def __init__(
self,
out_channels,
embed_dim,
num_heads,
dropout=0.,
bias=True,
project_input=True,
gated=False,
downsample=False
self, out_channels, embed_dim, num_heads, dropout=0., bias=True,
project_input=True, gated=False, downsample=False,
):
self.embed_dim = embed_dim
self.num_heads = num_heads
......@@ -190,24 +170,27 @@ class DownsampledMultiHeadAttention(nn.ModuleList):
if self.downsample:
attention_heads = []
for index in range(self.num_heads):
attention_heads.append(SingleHeadAttention(out_channels, self.embed_dim, self.head_dim, index, self.dropout, bias, self.project_input, self.gated, self.downsample, self.num_heads))
attention_heads.append(
SingleHeadAttention(
out_channels, self.embed_dim, self.head_dim, index,
self.dropout, bias, self.project_input, self.gated,
self.downsample, self.num_heads,
)
)
super().__init__(modules=attention_heads)
self.out_proj = Linear(embed_dim, out_channels, bias=bias)
else:
# either we have a list of attention heads, or just one attention head
# if not being downsampled, we can do the heads with one linear layer instead of separate ones
super().__init__()
self.attention_module = SingleHeadAttention(out_channels, self.embed_dim, self.head_dim, 1, self.dropout, bias, self.project_input, self.gated, self.downsample, self.num_heads)
self.attention_module = SingleHeadAttention(
out_channels, self.embed_dim, self.head_dim, 1, self.dropout,
bias, self.project_input, self.gated, self.downsample, self.num_heads,
)
def forward(
self,
query,
key,
value,
mask_future_timesteps=False,
key_padding_mask=None,
use_scalar_bias=False
self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None, use_scalar_bias=False,
):
src_len, bsz, embed_dim = key.size()
tgt_len = query.size(0)
......@@ -224,14 +207,18 @@ class DownsampledMultiHeadAttention(nn.ModuleList):
if self.downsample:
for attention_head_number in range(self.num_heads):
# call the forward of each attention head
_attn, _attn_weight = self[attention_head_number](query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias)
_attn, _attn_weight = self[attention_head_number](
query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
)
attn.append(_attn)
attn_weights.append(_attn_weight)
full_attn = torch.cat(attn, dim=2)
full_attn = self.out_proj(full_attn)
return full_attn, attn_weights[0].clone()
else:
_attn, _attn_weight = self.attention_module(query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias)
_attn, _attn_weight = self.attention_module(
query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
)
attn.append(_attn)
attn_weights.append(_attn_weight)
full_attn = torch.cat(attn, dim=2)
......
......@@ -5,8 +5,6 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim.lr_scheduler
from . import FairseqLRScheduler, register_lr_scheduler
......
......@@ -34,6 +34,13 @@ def get_generation_parser(interactive=False):
return parser
def get_eval_lm_parser():
parser = get_parser('Evaluate Language Model')
add_dataset_args(parser, gen=True)
add_eval_lm_args(parser)
return parser
def eval_str_list(x, type=float):
if x is None:
return None
......@@ -41,15 +48,17 @@ def eval_str_list(x, type=float):
x = eval(x)
try:
return list(map(type, x))
except:
except TypeError:
return [type(x)]
def get_eval_lm_parser():
parser = get_parser('Evaluate Language Model')
add_dataset_args(parser, gen=True)
add_eval_lm_args(parser)
return parser
def eval_bool(x, default=False):
if x is None:
return default
try:
return bool(eval(x))
except TypeError:
return default
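The new options.eval_bool helper replaces the ad-hoc try/except around eval(args.decoder_attention) in the LSTM model above. Assuming fairseq is importable, usage looks like this (a quick sketch of the behaviour defined above):

    from fairseq import options

    options.eval_bool('True')               # True
    options.eval_bool('0')                  # False (eval('0') -> 0 -> bool -> False)
    options.eval_bool(None)                 # False, the default
    options.eval_bool(None, default=True)   # True, explicit default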
def parse_args_and_arch(parser, input_args=None):
......@@ -114,16 +123,17 @@ def add_dataset_args(parser, train=False, gen=False):
group.add_argument('--max-target-positions', '--tokens-per-sample', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='Ignore too long or too short lines in valid and test set')
help='ignore too long or too short lines in valid and test set')
group.add_argument('--max-tokens', type=int, metavar='N',
help='maximum number of tokens in a batch')
group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
help='maximum number of sentences in a batch')
group.add_argument('--sample-break-mode', metavar='VAL',
choices=['none', 'complete', 'eos'],
help='Used only for LM datasets. If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end of sentence, but may include '
'multiple sentences per sample. If set to "eos", includes only one sentence per sample')
help='If omitted or "none", fills each sample with tokens-per-sample'
' tokens. If set to "complete", splits samples only at the end'
' of sentence, but may include multiple sentences per sample.'
' If set to "eos", includes only one sentence per sample.')
if train:
group.add_argument('--train-subset', default='train', metavar='SPLIT',
......@@ -131,7 +141,7 @@ def add_dataset_args(parser, train=False, gen=False):
help='data subset to use for training (train, valid, test)')
group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
help='comma separated list of data subsets to use for validation'
' (train, valid, valid1,test, test1)')
' (train, valid, valid1, test, test1)')
group.add_argument('--max-sentences-valid', type=int, metavar='N',
help='maximum number of sentences in a validation batch'
' (defaults to --max-sentences)')
......@@ -216,12 +226,10 @@ def add_checkpoint_args(parser):
help='filename in save-dir from which to load checkpoint')
group.add_argument('--save-interval', type=int, default=1, metavar='N',
help='save a checkpoint every N epochs')
group.add_argument('--save-interval-updates', type=int, metavar='N',
help='if specified, saves best/last checkpoint every this many updates. '
'will also validate before saving to determine if val loss is better')
group.add_argument('--keep-interval-updates', type=int, default=0, metavar='N',
help='if --save-interval-updates is specified, keep the last this many checkpoints'
' created after specified number of updates (format is checkpoint_[epoch]_[numupd].pt')
group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
help='save a checkpoint (and validate) every N updates')
group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
help='keep last N checkpoints saved with --save-interval-updates')
group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
......
......@@ -13,10 +13,11 @@ from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
stop_early=True, normalize_scores=True, len_penalty=1,
unk_penalty=0, retain_dropout=False, sampling=False, sampling_topk=-1,
sampling_temperature=1):
def __init__(
self, models, beam_size=1, minlen=1, maxlen=None, stop_early=True,
normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False,
sampling=False, sampling_topk=-1, sampling_temperature=1,
):
"""Generates translations of a given source sentence.
Args:
min/maxlen: The length of the generated output will be bounded by
......@@ -53,8 +54,10 @@ class SequenceGenerator(object):
model.cuda()
return self
def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
cuda=False, timer=None, prefix_size=0):
def generate_batched_itr(
self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
cuda=False, timer=None, prefix_size=0,
):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b: generate sequences of maximum length ax + b,
......
......@@ -68,8 +68,10 @@ class Trainer(object):
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
if distributed_utils.is_master(self.args): # only save one checkpoint
utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state)
utils.save_state(
filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
)
def load_checkpoint(self, filename):
"""Load all training state from a checkpoint file."""
......
......@@ -66,7 +66,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
def load_model_state(filename, model):
if not os.path.exists(filename):
return None, [], None
state = torch.load(filename)
state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = _upgrade_state_dict(state)
model.upgrade_state_dict(state['model'])
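Mapping storages to CPU at load time, as the change above does, lets a checkpoint saved on any GPU be loaded on a machine with a different (or no) GPU layout; the model can be moved back with .cuda() afterwards. The same pattern in isolation (placeholder filename):

    import torch
    from torch.serialization import default_restore_location

    state = torch.load(
        'checkpoint_last.pt',  # placeholder path
        map_location=lambda s, l: default_restore_location(s, 'cpu'),
    )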
......@@ -142,9 +142,9 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
for filename in filenames:
if not os.path.exists(filename):
raise IOError('Model file not found: {}'.format(filename))
states.append(
torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
)
state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = _upgrade_state_dict(state)
states.append(state)
args = states[0]['args']
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
......@@ -157,6 +157,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
ensemble = []
for state in states:
model = models.build_model(args, src_dict, dst_dict)
model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True)
ensemble.append(model)
return ensemble, args
......@@ -278,7 +279,7 @@ def parse_embedding(embed_path):
"""
embed_dict = {}
with open(embed_path) as f_embed:
_ = next(f_embed) # skip header
next(f_embed) # skip header
for line in f_embed:
pieces = line.strip().split()
embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
......@@ -352,12 +353,7 @@ def buffered_arange(max):
return buffered_arange.buf[:max]
def convert_padding_direction(
src_tokens,
padding_idx,
right_to_left=False,
left_to_right=False,
):
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx)
if not pad_mask.any():
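convert_padding_direction moves padding from one side of each row to the other (the assert requires exactly one direction to be requested). A naive per-row equivalent for illustration, not the vectorized fairseq implementation:

    import torch

    def to_right_padding(src_tokens, padding_idx):
        # rewrite each row so its padding sits on the right instead of the left
        out = torch.full_like(src_tokens, padding_idx)
        for i, row in enumerate(src_tokens):
            toks = row[row.ne(padding_idx)]
            out[i, :len(toks)] = toks
        return out

    left_padded = torch.LongTensor([[1, 1, 5, 6],
                                    [4, 5, 6, 7]])
    print(to_right_padding(left_padded, padding_idx=1))
    # tensor([[5, 6, 1, 1],
    #         [4, 5, 6, 7]])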
......@@ -401,9 +397,12 @@ def fill_with_neg_inf(t):
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
""" retrieves all checkpoints found in `path` directory. checkpoints are identified by matching filename to
the specified pattern. if the pattern contains groups, the result will be sorted by the first group in descending
order """
"""Retrieves all checkpoints found in `path` directory.
Checkpoints are identified by matching filename to the specified pattern. If
the pattern contains groups, the result will be sorted by the first group in
descending order.
"""
pt_regexp = re.compile(pattern)
files = os.listdir(path)
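The cleaned-up docstring above describes the contract; a self-contained sketch with the same behaviour (hypothetical helper name, not the fairseq function itself):

    import os
    import re

    def checkpoint_paths_sketch(path, pattern=r'checkpoint(\d+)\.pt'):
        # keep filenames matching the pattern, sorted by the first group, newest first
        pt_regexp = re.compile(pattern)
        entries = []
        for fname in os.listdir(path):
            m = pt_regexp.fullmatch(fname)
            if m is not None:
                entries.append((int(m.group(1)), os.path.join(path, fname)))
        return [p for _, p in sorted(entries, reverse=True)]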
......
......@@ -93,15 +93,15 @@ def main(args):
# Process input and ground truth
has_target = target_tokens is not None
target_tokens = target_tokens.int().cpu() if has_target else None
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
else:
src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
target_str = dataset.dst_dict.string(target_tokens,
args.remove_bpe,
escape_unk=True) if has_target else ''
if has_target:
target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
if not args.quiet:
print('S-{}\t{}'.format(sample_id, src_str))
......@@ -146,7 +146,7 @@ def main(args):
num_sentences += 1
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, num_sentences/ gen_timer.sum, 1. / gen_timer.avg))
num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
if has_target:
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
......
......@@ -7,7 +7,6 @@
import unittest
import itertools
from unittest.mock import MagicMock, patch
import train
......
......@@ -163,7 +163,8 @@ def train(args, trainer, itr, epoch, dataset):
trainer.get_meter('wps').reset()
num_updates = trainer.get_num_updates()
if not args.no_save and (args.save_interval_updates or 0) > 0 and num_updates % args.save_interval_updates == 0:
if not args.no_save and (args.save_interval_updates or 0) > 0 and \
num_updates % args.save_interval_updates == 0:
first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)
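The wrapped condition above only fires when mid-epoch checkpointing is enabled (a non-zero --save-interval-updates, which now defaults to 0) and the update counter lands on a save boundary. The guard in isolation, with hypothetical numbers:

    def should_save_mid_epoch(no_save, save_interval_updates, num_updates):
        # mirrors the guard in train(): mid-epoch saves are opt-in and periodic
        return (not no_save
                and (save_interval_updates or 0) > 0
                and num_updates % save_interval_updates == 0)

    print(should_save_mid_epoch(False, 0, 500))      # False: disabled by default
    print(should_save_mid_epoch(False, 1000, 500))   # False: not at a save boundary yet
    print(should_save_mid_epoch(False, 1000, 2000))  # True: every 1000 updates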
......@@ -235,17 +236,15 @@ def validate(args, trainer, dataset, subset, epoch, num_updates):
for sample in progress:
log_output = trainer.valid_step(sample)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
continue
extra_meters[k].update(v)
# log validation stats
stats = get_valid_stats(trainer)
for k, meter in extra_meters.items():
stats[k] = meter.avg
if num_updates is not None:
stats['num_updates'] = num_updates
if hasattr(save_checkpoint, 'best'):
stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
progress.print(stats)
return stats['valid_loss']
......@@ -260,6 +259,9 @@ def get_valid_stats(trainer):
else:
nll_loss = trainer.get_meter('valid_loss').avg
stats['valid_ppl'] = get_perplexity(nll_loss)
stats['num_updates'] = trainer.get_num_updates()
if hasattr(save_checkpoint, 'best'):
stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
return stats
......