Commit cf1c64a5 authored by Myle Ott

Nits

parent 6eda8e47
@@ -81,10 +81,7 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=Fal
         dst.copy_(src)

     for i, v in enumerate(values):
-        if left_pad:
-            copy_tensor(v, res[i][size - len(v):])
-        else:
-            copy_tensor(v, res[i][:len(v)])
+        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
     return res
...
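The rewritten `collate_tokens` folds the left/right-padding branch into a single expression. A minimal standalone sketch of the padding behavior (the name `pad_batch` is hypothetical; the real helper also supports moving EOS to the beginning, omitted here):

```python
import torch

def pad_batch(values, pad_idx, left_pad=False):
    """Pad a list of 1-D tensors to the length of the longest one."""
    size = max(len(v) for v in values)
    res = values[0].new(len(values), size).fill_(pad_idx)
    for i, v in enumerate(values):
        # left-padding writes into the tail of the row, right-padding into the head
        dst = res[i][size - len(v):] if left_pad else res[i][:len(v)]
        dst.copy_(v)
    return res

batch = [torch.tensor([4, 5, 6]), torch.tensor([7, 8])]
print(pad_batch(batch, pad_idx=1, left_pad=True))   # tensor([[4, 5, 6], [1, 7, 8]])
print(pad_batch(batch, pad_idx=1, left_pad=False))  # tensor([[4, 5, 6], [7, 8, 1]])
```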
@@ -5,9 +5,10 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+import numpy as np
 import os
 import struct
-import numpy as np
+
 import torch
 import torch.utils.data
@@ -197,10 +198,8 @@ class IndexedDatasetBuilder(object):
         index = open(index_file, 'wb')
         index.write(b'TNTIDX\x00\x00')
         index.write(struct.pack('<Q', 1))
-        index.write(struct.pack('<QQ', code(self.dtype),
-                                self.element_size))
-        index.write(struct.pack('<QQ', len(self.data_offsets) - 1,
-                                len(self.sizes)))
+        index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
+        index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
         write_longs(index, self.dim_offsets)
         write_longs(index, self.data_offsets)
         write_longs(index, self.sizes)
...
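For reference, the index header written above is a fixed little-endian layout: an 8-byte magic, a `<Q` version, a `<QQ` pair for dtype code and element size, and a `<QQ` pair of counts. A hedged sketch of reading it back (`read_index_header` is illustrative, not part of fairseq):

```python
import struct

def read_index_header(path):
    with open(path, 'rb') as f:
        magic = f.read(8)
        assert magic == b'TNTIDX\x00\x00', 'unexpected magic'
        version, = struct.unpack('<Q', f.read(8))
        dtype_code, element_size = struct.unpack('<QQ', f.read(16))
        num_items, num_sizes = struct.unpack('<QQ', f.read(16))
        # dim_offsets/data_offsets/sizes arrays follow; parsing omitted here
        return version, dtype_code, element_size, num_items, num_sizes
```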
@@ -12,7 +12,8 @@ from .fairseq_decoder import FairseqDecoder  # noqa: F401
 from .fairseq_encoder import FairseqEncoder  # noqa: F401
 from .fairseq_incremental_decoder import FairseqIncrementalDecoder  # noqa: F401
 from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel  # noqa: F401
 from .composite_encoder import CompositeEncoder  # noqa: F401
+
 MODEL_REGISTRY = {}
 ARCH_MODEL_REGISTRY = {}
...
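`MODEL_REGISTRY` and `ARCH_MODEL_REGISTRY` are populated by decorators such as `@register_model`. A minimal sketch of that registry pattern, assuming a simplified `register_model` (fairseq's real one also does validation and architecture wiring):

```python
MODEL_REGISTRY = {}

def register_model(name):
    def register(cls):
        if name in MODEL_REGISTRY:
            raise ValueError('Cannot register duplicate model ({})'.format(name))
        MODEL_REGISTRY[name] = cls
        return cls
    return register

@register_model('toy')
class ToyModel:
    pass

assert MODEL_REGISTRY['toy'] is ToyModel
```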
@@ -12,10 +12,15 @@ import torch.nn.functional as F

 from fairseq import options, utils
 from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
-from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution, AdaptiveSoftmax
+from fairseq.modules import (
+    AdaptiveSoftmax, BeamableMM, GradMultiply, LearnedPositionalEmbedding,
+    LinearizedConvolution,
+)

-from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, FairseqLanguageModel, register_model, \
-    register_model_architecture
+from . import (
+    FairseqEncoder, FairseqIncrementalDecoder, FairseqModel,
+    FairseqLanguageModel, register_model, register_model_architecture,
+)


 @register_model('fconv')
...
@@ -7,16 +7,23 @@
 #

 import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from fairseq.data import LanguagePairDataset
 from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
-from fairseq.modules import GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution, DownsampledMultiHeadAttention
+from fairseq.modules import (
+    DownsampledMultiHeadAttention, GradMultiply, LearnedPositionalEmbedding,
+    LinearizedConvolution,
+)
 from fairseq import utils

-from . import FairseqEncoder, CompositeEncoder, FairseqDecoder, FairseqModel, register_model, register_model_architecture
+from . import (
+    FairseqEncoder, CompositeEncoder, FairseqDecoder, FairseqModel,
+    register_model, register_model_architecture,
+)


 @register_model('fconv_self_att')
 class FConvModelSelfAtt(FairseqModel):
@@ -66,9 +73,9 @@ class FConvModelSelfAtt(FairseqModel):
         parser.add_argument('--downsample', type=str, metavar='EXPR', default='False',
                             help='Use downsampling in self-attention [True, ...]')
         parser.add_argument('--pretrained-checkpoint', metavar='DIR', default='',
                             help='path to load checkpoint from pretrained model')
         parser.add_argument('--pretrained', type=str, metavar='EXPR', default='False',
                             help='use pretrained model when training [True, ...]')

     @classmethod
     def build_model(cls, args, src_dict, dst_dict):
@@ -76,7 +83,6 @@ class FConvModelSelfAtt(FairseqModel):
         pretrained = eval(args.pretrained)
         if pretrained:
             print("| Loading pretrained model")
-            state = torch.load(args.pretrained_checkpoint)
             trained_model = utils.load_ensemble_for_inference(
                 # not actually for inference, but loads pretrained model parameters
                 filenames=[args.pretrained_checkpoint],
@@ -131,9 +137,11 @@ class FConvModelSelfAtt(FairseqModel):
 class FConvEncoder(FairseqEncoder):
     """Convolutional encoder"""

-    def __init__(self, dictionary, embed_dim=512, max_positions=1024,
-                 convolutions=((512, 3),) * 20, dropout=0.1, attention=False,
-                 attention_nheads=1):
+    def __init__(
+        self, dictionary, embed_dim=512, max_positions=1024,
+        convolutions=((512, 3),) * 20, dropout=0.1, attention=False,
+        attention_nheads=1,
+    ):
         super().__init__(dictionary)
         self.dropout = dropout
         self.num_attention_layers = None
@@ -163,20 +171,20 @@ class FConvEncoder(FairseqEncoder):
         self.attention = nn.ModuleList()
         self.attproj = nn.ModuleList()
         for i, (out_channels, kernel_size) in enumerate(convolutions):
-            self.projections.append(Linear(in_channels, out_channels)
-                                     if in_channels != out_channels else None)
+            self.projections.append(
+                Linear(in_channels, out_channels) if in_channels != out_channels else None
+            )
             self.convolutions.append(
-                ConvTBC(in_channels, out_channels * 2, kernel_size,
-                        dropout=dropout))
-            self.attention.append(SelfAttention(out_channels, embed_dim,
-                                                attention_nheads)
-                                  if attention[i] else None)
+                ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout)
+            )
+            self.attention.append(
+                SelfAttention(out_channels, embed_dim, attention_nheads) if attention[i] else None
+            )
             in_channels = out_channels

         self.fc2 = Linear(in_channels, embed_dim)

     def forward(self, src_tokens, src_lengths):
         # embed tokens and positions
         x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
@@ -226,18 +234,20 @@ class FConvEncoder(FairseqEncoder):
 class FConvDecoder(FairseqDecoder):
     """Convolutional decoder"""

-    def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
-                 max_positions=1024, convolutions=((512, 3),) * 8,
-                 attention=True, dropout=0.1, selfattention=False,
-                 attention_nheads=1, selfattention_nheads=1,
-                 project_input=False, gated_attention=False, downsample=False,
-                 pretrained=False, trained_decoder=None):
+    def __init__(
+        self, dictionary, embed_dim=512, out_embed_dim=256, max_positions=1024,
+        convolutions=((512, 3),) * 8, attention=True, dropout=0.1,
+        selfattention=False, attention_nheads=1, selfattention_nheads=1,
+        project_input=False, gated_attention=False, downsample=False,
+        pretrained=False, trained_decoder=None,
+    ):
         super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
         self.pretrained = pretrained
         self.pretrained_decoder = trained_decoder
         self.dropout = dropout
         in_channels = convolutions[0][0]

         def expand_bool_array(val):
             if isinstance(val, bool):
                 # expand True into [True, True, ...] and do the same with False
@@ -269,27 +279,33 @@ class FConvDecoder(FairseqDecoder):
         self.selfattention = nn.ModuleList()
         self.attproj = nn.ModuleList()
         for i, (out_channels, kernel_size) in enumerate(convolutions):
-            pad = kernel_size - 1
-            self.projections.append(Linear(in_channels, out_channels)
-                                     if in_channels != out_channels else None)
+            self.projections.append(
+                Linear(in_channels, out_channels) if in_channels != out_channels else None
+            )
             self.convolutions.append(
-                LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
-                                 padding=(kernel_size - 1), dropout=dropout))
-            self.attention.append(DownsampledMultiHeadAttention(out_channels, embed_dim,
-                                                                attention_nheads,
-                                                                project_input=project_input,
-                                                                gated=False, downsample=False)
-                                  if attention[i] else None)
-            self.attproj.append(Linear(out_channels, embed_dim, dropout=dropout)
-                                if attention[i] else None)
-            self.selfattention.append(SelfAttention(out_channels, embed_dim,
-                                                    selfattention_nheads,
-                                                    project_input=project_input,
-                                                    gated=gated_attention,
-                                                    downsample=downsample)
-                                      if selfattention[i] else None)
+                LinearizedConv1d(
+                    in_channels, out_channels * 2, kernel_size,
+                    padding=(kernel_size - 1), dropout=dropout,
+                )
+            )
+            self.attention.append(
+                DownsampledMultiHeadAttention(
+                    out_channels, embed_dim, attention_nheads,
+                    project_input=project_input, gated=False, downsample=False,
+                ) if attention[i] else None
+            )
+            self.attproj.append(
+                Linear(out_channels, embed_dim, dropout=dropout) if attention[i] else None
+            )
+            self.selfattention.append(
+                SelfAttention(
+                    out_channels, embed_dim, selfattention_nheads,
+                    project_input=project_input, gated=gated_attention,
+                    downsample=downsample,
+                ) if selfattention[i] else None
+            )
             in_channels = out_channels

         self.fc2 = Linear(in_channels, out_embed_dim)
@@ -301,24 +317,27 @@ class FConvDecoder(FairseqDecoder):
             self.gate1 = nn.Sequential(Linear(out_embed_dim*2, out_embed_dim), nn.Sigmoid())
             self.gate2 = nn.Sequential(Linear(out_embed_dim*2, out_embed_dim), nn.Sigmoid())

             # pretrained and trained models are joined
-            self.joining = nn.Sequential(Linear(out_embed_dim*2, out_embed_dim*2),
-                                         nn.LayerNorm(out_embed_dim*2),
-                                         nn.GLU(),
-                                         Linear(out_embed_dim, out_embed_dim*2),
-                                         nn.LayerNorm(out_embed_dim*2),
-                                         nn.GLU(),
-                                         Linear(out_embed_dim, out_embed_dim),
-                                         nn.LayerNorm(out_embed_dim))
+            self.joining = nn.Sequential(
+                Linear(out_embed_dim*2, out_embed_dim*2),
+                nn.LayerNorm(out_embed_dim*2),
+                nn.GLU(),
+                Linear(out_embed_dim, out_embed_dim*2),
+                nn.LayerNorm(out_embed_dim*2),
+                nn.GLU(),
+                Linear(out_embed_dim, out_embed_dim),
+                nn.LayerNorm(out_embed_dim)
+            )

             # pretrained model contains an output layer that is nhid -> vocab size
             # but the models are combined in their hidden state
             # the hook stores the output of the pretrained model forward
             self.pretrained_outputs = {}

             def save_output():
                 def hook(a, b, output):
                     self.pretrained_outputs["out"] = output
                 return hook

             self.pretrained_decoder.fc2.register_forward_hook(save_output())

     def forward(self, prev_output_tokens, encoder_out_dict):
         encoder_out = encoder_out_dict['encoder']['encoder_out']
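The `save_output` closure above is the standard PyTorch forward-hook pattern: the hook fires after `fc2` runs and stashes its output so the trained and pretrained decoders can be joined in hidden space. A self-contained sketch (module and dict names illustrative):

```python
import torch
import torch.nn as nn

captured = {}

def save_output():
    def hook(module, inputs, output):
        # called after every forward pass of the hooked module
        captured['out'] = output
    return hook

layer = nn.Linear(4, 2)
layer.register_forward_hook(save_output())
layer(torch.randn(3, 4))
print(captured['out'].shape)  # torch.Size([3, 2])
```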
@@ -342,11 +361,9 @@ class FConvDecoder(FairseqDecoder):
         # temporal convolutions
         avg_attn_scores = None
-        for proj, conv, attention, selfattention, attproj in zip(self.projections,
-                                                                 self.convolutions,
-                                                                 self.attention,
-                                                                 self.selfattention,
-                                                                 self.attproj):
+        for proj, conv, attention, selfattention, attproj in zip(
+            self.projections, self.convolutions, self.attention, self.selfattention, self.attproj
+        ):
             residual = x if proj is None else proj(x)

             x = F.dropout(x, p=self.dropout, training=self.training)
@@ -398,11 +415,14 @@ class FConvDecoder(FairseqDecoder):
     def reorder_encoder_out(self, encoder_out_dict, new_order):
         encoder_out_dict['encoder']['encoder_out'] = tuple(
-            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder']['encoder_out'])
+            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder']['encoder_out']
+        )
         if 'pretrained' in encoder_out_dict:
             encoder_out_dict['pretrained']['encoder']['encoder_out'] = tuple(
-                eo.index_select(0, new_order) for eo in encoder_out_dict['pretrained']['encoder']['encoder_out'])
+                eo.index_select(0, new_order)
+                for eo in encoder_out_dict['pretrained']['encoder']['encoder_out']
+            )
         return encoder_out_dict
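`reorder_encoder_out` exists because beam search reorders batch rows between decoding steps, and cached encoder states must follow via `index_select`. A tiny illustration with made-up shapes:

```python
import torch

encoder_out = torch.arange(6.).view(3, 2)   # 3 "sentences", 2 features each
new_order = torch.tensor([2, 0, 1])         # beam search's new row order
print(encoder_out.index_select(0, new_order))
# tensor([[4., 5.],
#         [0., 1.],
#         [2., 3.]])
```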
@@ -425,8 +445,10 @@ class SelfAttention(nn.Module):
     def __init__(self, out_channels, embed_dim, num_heads, project_input=False, gated=False, downsample=False):
         super().__init__()
-        self.attention = DownsampledMultiHeadAttention(out_channels, embed_dim, num_heads,
-            dropout=0, bias=True, project_input=project_input, gated=gated, downsample=downsample)
+        self.attention = DownsampledMultiHeadAttention(
+            out_channels, embed_dim, num_heads, dropout=0, bias=True,
+            project_input=project_input, gated=gated, downsample=downsample,
+        )
         self.in_proj_q = Linear(out_channels, embed_dim)
         self.in_proj_k = Linear(out_channels, embed_dim)
         self.in_proj_v = Linear(out_channels, embed_dim)

@@ -441,7 +463,6 @@ class SelfAttention(nn.Module):
         return self.ln(x + residual)

-
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
     m.weight.data.normal_(0, 0.1)
...
@@ -10,10 +10,13 @@ from torch.autograd import Variable
 import torch.nn as nn
 import torch.nn.functional as F

-from fairseq import utils
+from fairseq import options, utils
 from fairseq.data import consts

-from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, register_model_architecture
+from . import (
+    FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model,
+    register_model_architecture,
+)


 @register_model('lstm')

@@ -92,10 +95,6 @@ class LSTMModel(FairseqModel):
             bidirectional=args.encoder_bidirectional,
             pretrained_embed=pretrained_encoder_embed,
         )
-        try:
-            attention = bool(eval(args.decoder_attention))
-        except TypeError:
-            attention = bool(args.decoder_attention)
         decoder = LSTMDecoder(
             dictionary=dst_dict,
             embed_dim=args.decoder_embed_dim,

@@ -104,7 +103,7 @@ class LSTMModel(FairseqModel):
             num_layers=args.decoder_layers,
             dropout_in=args.decoder_dropout_in,
             dropout_out=args.decoder_dropout_out,
-            attention=attention,
+            attention=options.eval_bool(args.decoder_attention),
             encoder_embed_dim=args.encoder_embed_dim,
             encoder_output_units=encoder.output_units,
             pretrained_embed=pretrained_decoder_embed,
...
@@ -435,6 +435,7 @@ def transformer_vaswani_wmt_en_de_big(args):
     args.dropout = getattr(args, 'dropout', 0.3)
     base_architecture(args)

+
 @register_model_architecture('transformer', 'transformer_vaswani_wmt_en_fr_big')
 def transformer_vaswani_wmt_en_fr_big(args):
     args.dropout = getattr(args, 'dropout', 0.1)
...
@@ -11,8 +11,11 @@ from torch import nn

 class AdaptiveSoftmax(nn.Module):
-    """This is an implementation of the efficient softmax approximation for graphical processing units (GPU),
-    described in the paper "Efficient softmax approximation for GPUs" (http://arxiv.org/abs/1609.04309)."""
+    """
+    This is an implementation of the efficient softmax approximation for
+    graphical processing units (GPU), described in the paper "Efficient softmax
+    approximation for GPUs" (http://arxiv.org/abs/1609.04309).
+    """

     def __init__(self, vocab_size, input_dim, cutoff, dropout):
         super().__init__()

@@ -46,9 +49,12 @@ class AdaptiveSoftmax(nn.Module):
         self.apply(init_weights)

     def adapt_target(self, target):
-        """In order to be efficient, the AdaptiveSoftMax does not compute the scores for all the word of the
-        vocabulary for all the examples.It is thus necessary to call the method adapt_target of the AdaptiveSoftMax
-        layer inside each forward pass."""
+        """
+        In order to be efficient, the AdaptiveSoftMax does not compute the
+        scores for all the word of the vocabulary for all the examples. It is
+        thus necessary to call the method adapt_target of the AdaptiveSoftMax
+        layer inside each forward pass.
+        """

         target = target.view(-1)
         new_target = [target.clone()]

@@ -68,12 +74,17 @@ class AdaptiveSoftmax(nn.Module):
         return new_target, target_idxs

     def forward(self, input, target):
-        """ accepts input (b x t x d) and target (b x t) and returns
-        2 lists: output for each cutoff section and new targets by cut off """
+        """
+        Args:
+            input: (b x t x d)
+            target: (b x t)
+
+        Returns:
+            2 lists: output for each cutoff section and new targets by cut off
+        """

         input = input.contiguous().view(-1, input.size(-1))
         input = F.dropout(input, p=self.dropout, training=self.training)

         new_target, target_idxs = self.adapt_target(target)
         output = [self.head(input)]

@@ -86,7 +97,10 @@ class AdaptiveSoftmax(nn.Module):
         return output, new_target

     def get_log_prob(self, input, target):
-        """computes the log probabilities for all the words of the vocabulary, given a 2D tensor of hidden vectors"""
+        """
+        Computes the log probabilities for all the words of the vocabulary,
+        given a 2D tensor of hidden vectors.
+        """

         bsz, length, dim = input.size()
         input = input.contiguous().view(-1, dim)
...
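To make the cutoff machinery concrete: `adapt_target` keeps frequent word ids in the head and replaces rarer ids with a per-cluster placeholder, plus cluster-relative targets. A hedged sketch of that bucketing, simplified to a single tail cluster (the real method handles several clusters and tracks indices back into the batch):

```python
import torch

cutoff = [5, 12]          # head covers ids < 5; one tail cluster covers ids 5..11
target = torch.tensor([1, 7, 3, 10])

head_target = target.clone()
tail_mask = (target >= cutoff[0]) & (target < cutoff[1])
# in the head distribution, every tail word collapses to one cluster id
head_target[tail_mask] = cutoff[0]
print(head_target)                    # tensor([1, 5, 3, 5])
print(target[tail_mask] - cutoff[0])  # cluster-relative tail targets: tensor([2, 5])
```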
@@ -20,64 +20,51 @@ class SingleHeadAttention(nn.Module):
     Single-head attention that supports Gating and Downsampling
     """
     def __init__(
-            self,
-            out_channels,
-            embed_dim,
-            head_dim,
-            head_index,
-            dropout=0.,
-            bias=True,
-            project_input=True,
-            gated=False,
-            downsample=False,
-            num_heads=1
+        self, out_channels, embed_dim, head_dim, head_index, dropout=0.,
+        bias=True, project_input=True, gated=False, downsample=False,
+        num_heads=1,
     ):
         super().__init__()
         self.embed_dim = embed_dim
         self.dropout = dropout
         self.head_index = head_index
         self.head_dim = head_dim
         self.project_input = project_input
         self.gated = gated
         self.downsample = downsample
         self.num_heads = num_heads
         self.projection = None

         k_layers = []
         v_layers = []
         if self.downsample:
             k_layers.append(Downsample(self.head_index))
             v_layers.append(Downsample(self.head_index))
             out_proj_size = self.head_dim
         else:
             out_proj_size = self.head_dim * self.num_heads
         if self.gated:
             k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
             self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias)
             v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
         else:
             k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
             self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias)
             v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))

         self.in_proj_k = nn.Sequential(*k_layers)
         self.in_proj_v = nn.Sequential(*v_layers)

         if self.downsample:
             self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias)
         else:
             self.out_proj = Linear(out_proj_size, out_channels, bias=bias)

         self.scaling = self.head_dim**-0.5

     def forward(
-            self,
-            query,
-            key,
-            value,
-            mask_future_timesteps=False,
-            key_padding_mask=None,
-            use_scalar_bias=False
+        self, query, key, value, mask_future_timesteps=False,
+        key_padding_mask=None, use_scalar_bias=False,
     ):
         """Input shape: Time x Batch x Channel
         Self-attention can be implemented by passing in the same arguments for
@@ -168,15 +155,8 @@ class DownsampledMultiHeadAttention(nn.ModuleList):
     Multi-headed attention with Gating and Downsampling
     """
     def __init__(
-            self,
-            out_channels,
-            embed_dim,
-            num_heads,
-            dropout=0.,
-            bias=True,
-            project_input=True,
-            gated=False,
-            downsample=False
+        self, out_channels, embed_dim, num_heads, dropout=0., bias=True,
+        project_input=True, gated=False, downsample=False,
     ):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
@@ -190,24 +170,27 @@ class DownsampledMultiHeadAttention(nn.ModuleList):
         if self.downsample:
             attention_heads = []
             for index in range(self.num_heads):
-                attention_heads.append(SingleHeadAttention(out_channels, self.embed_dim, self.head_dim, index, self.dropout, bias, self.project_input, self.gated, self.downsample, self.num_heads))
+                attention_heads.append(
+                    SingleHeadAttention(
+                        out_channels, self.embed_dim, self.head_dim, index,
+                        self.dropout, bias, self.project_input, self.gated,
+                        self.downsample, self.num_heads,
+                    )
+                )
             super().__init__(modules=attention_heads)
             self.out_proj = Linear(embed_dim, out_channels, bias=bias)
         else:
             # either we have a list of attention heads, or just one attention head
             # if not being downsampled, we can do the heads with one linear layer instead of separate ones
             super().__init__()
-            self.attention_module = SingleHeadAttention(out_channels, self.embed_dim, self.head_dim, 1, self.dropout, bias, self.project_input, self.gated, self.downsample, self.num_heads)
+            self.attention_module = SingleHeadAttention(
+                out_channels, self.embed_dim, self.head_dim, 1, self.dropout,
+                bias, self.project_input, self.gated, self.downsample, self.num_heads,
+            )

     def forward(
-            self,
-            query,
-            key,
-            value,
-            mask_future_timesteps=False,
-            key_padding_mask=None,
-            use_scalar_bias=False
+        self, query, key, value, mask_future_timesteps=False,
+        key_padding_mask=None, use_scalar_bias=False,
     ):
         src_len, bsz, embed_dim = key.size()
         tgt_len = query.size(0)
@@ -224,14 +207,18 @@ class DownsampledMultiHeadAttention(nn.ModuleList):
         if self.downsample:
             for attention_head_number in range(self.num_heads):
                 # call the forward of each attention head
-                _attn, _attn_weight = self[attention_head_number](query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias)
+                _attn, _attn_weight = self[attention_head_number](
+                    query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
+                )
                 attn.append(_attn)
                 attn_weights.append(_attn_weight)
             full_attn = torch.cat(attn, dim=2)
             full_attn = self.out_proj(full_attn)
             return full_attn, attn_weights[0].clone()
         else:
-            _attn, _attn_weight = self.attention_module(query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias)
+            _attn, _attn_weight = self.attention_module(
+                query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
+            )
             attn.append(_attn)
             attn_weights.append(_attn_weight)
             full_attn = torch.cat(attn, dim=2)
@@ -264,9 +251,9 @@ def Linear(in_features, out_features, dropout=0., bias=True):

 def GatedLinear(in_features, out_features, dropout=0., bias=True):
     """Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units"""
     return nn.Sequential(
         Linear(in_features, out_features*4, dropout, bias),
         nn.GLU(),
         Linear(out_features*2, out_features*2, dropout, bias),
         nn.GLU(),
         Linear(out_features, out_features, dropout, bias)
     )
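The sizing in `GatedLinear` follows from `nn.GLU` halving the last dimension: `out_features*4 -> GLU -> out_features*2 -> GLU -> out_features`, which is why each intermediate layer is sized the way it is. A quick dimension check:

```python
import torch
import torch.nn as nn

glu = nn.GLU()                 # splits the last dim in half and gates one half
x = torch.randn(2, 7, 16)      # B x T x C
print(glu(x).shape)            # torch.Size([2, 7, 8])
```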
@@ -5,8 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-import torch.optim.lr_scheduler
-
 from . import FairseqLRScheduler, register_lr_scheduler
...
@@ -34,6 +34,13 @@ def get_generation_parser(interactive=False):
     return parser


+def get_eval_lm_parser():
+    parser = get_parser('Evaluate Language Model')
+    add_dataset_args(parser, gen=True)
+    add_eval_lm_args(parser)
+    return parser
+
+
 def eval_str_list(x, type=float):
     if x is None:
         return None
@@ -41,15 +48,17 @@ def eval_str_list(x, type=float):
     x = eval(x)
     try:
         return list(map(type, x))
-    except:
+    except TypeError:
         return [type(x)]


-def get_eval_lm_parser():
-    parser = get_parser('Evaluate Language Model')
-    add_dataset_args(parser, gen=True)
-    add_eval_lm_args(parser)
-    return parser
+def eval_bool(x, default=False):
+    if x is None:
+        return default
+    try:
+        return bool(eval(x))
+    except TypeError:
+        return default


 def parse_args_and_arch(parser, input_args=None):
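A usage sketch for the new `eval_bool` helper, covering string-valued flags such as `--decoder-attention` (the definition is repeated from the hunk above so the snippet runs standalone):

```python
def eval_bool(x, default=False):
    if x is None:
        return default
    try:
        return bool(eval(x))
    except TypeError:
        return default

assert eval_bool('True') is True
assert eval_bool('False') is False
assert eval_bool(None, default=True) is True  # missing flag falls back to default
```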
@@ -114,16 +123,17 @@ def add_dataset_args(parser, train=False, gen=False):
     group.add_argument('--max-target-positions', '--tokens-per-sample', default=1024, type=int, metavar='N',
                        help='max number of tokens in the target sequence')
     group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
-                       help='Ignore too long or too short lines in valid and test set')
+                       help='ignore too long or too short lines in valid and test set')
     group.add_argument('--max-tokens', type=int, metavar='N',
                        help='maximum number of tokens in a batch')
     group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                        help='maximum number of sentences in a batch')
     group.add_argument('--sample-break-mode', metavar='VAL',
                        choices=['none', 'complete', 'eos'],
-                       help='Used only for LM datasets. If omitted or "none", fills each sample with tokens-per-sample '
-                            'tokens. If set to "complete", splits samples only at the end of sentence, but may include '
-                            'multiple sentences per sample. If set to "eos", includes only one sentence per sample')
+                       help='If omitted or "none", fills each sample with tokens-per-sample'
+                            ' tokens. If set to "complete", splits samples only at the end'
+                            ' of sentence, but may include multiple sentences per sample.'
+                            ' If set to "eos", includes only one sentence per sample.')

     if train:
         group.add_argument('--train-subset', default='train', metavar='SPLIT',
@@ -131,7 +141,7 @@ def add_dataset_args(parser, train=False, gen=False):
                        help='data subset to use for training (train, valid, test)')
     group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                        help='comma separated list of data subsets to use for validation'
-                            ' (train, valid, valid1,test, test1)')
+                            ' (train, valid, valid1, test, test1)')
     group.add_argument('--max-sentences-valid', type=int, metavar='N',
                        help='maximum number of sentences in a validation batch'
                             ' (defaults to --max-sentences)')
@@ -216,12 +226,10 @@ def add_checkpoint_args(parser):
                        help='filename in save-dir from which to load checkpoint')
     group.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='save a checkpoint every N epochs')
-    group.add_argument('--save-interval-updates', type=int, metavar='N',
-                       help='if specified, saves best/last checkpoint every this many updates. '
-                            'will also validate before saving to determine if val loss is better')
-    group.add_argument('--keep-interval-updates', type=int, default=0, metavar='N',
-                       help='if --save-interval-updates is specified, keep the last this many checkpoints'
-                            ' created after specified number of updates (format is checkpoint_[epoch]_[numupd].pt')
+    group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
+                       help='save a checkpoint (and validate) every N updates')
+    group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
+                       help='keep last N checkpoints saved with --save-interval-updates')
     group.add_argument('--no-save', action='store_true',
                        help='don\'t save models or checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
...
@@ -13,10 +13,11 @@ from fairseq.models import FairseqIncrementalDecoder

 class SequenceGenerator(object):
-    def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
-                 stop_early=True, normalize_scores=True, len_penalty=1,
-                 unk_penalty=0, retain_dropout=False, sampling=False, sampling_topk=-1,
-                 sampling_temperature=1):
+    def __init__(
+        self, models, beam_size=1, minlen=1, maxlen=None, stop_early=True,
+        normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False,
+        sampling=False, sampling_topk=-1, sampling_temperature=1,
+    ):
         """Generates translations of a given source sentence.

         Args:
             min/maxlen: The length of the generated output will be bounded by

@@ -53,8 +54,10 @@ class SequenceGenerator(object):
             model.cuda()
         return self

-    def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
-                             cuda=False, timer=None, prefix_size=0):
+    def generate_batched_itr(
+        self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
+        cuda=False, timer=None, prefix_size=0,
+    ):
         """Iterate over a batched dataset and yield individual translations.

         Args:
             maxlen_a/b: generate sequences of maximum length ax + b,
...
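The `maxlen_a/b` convention documented above caps output length at `a * src_len + b` tokens. For example (illustrative values):

```python
# cap generated length relative to the source sentence length
maxlen_a, maxlen_b = 0.5, 10
src_len = 20
maxlen = int(maxlen_a * src_len + maxlen_b)
print(maxlen)  # 20
```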
@@ -68,8 +68,10 @@ class Trainer(object):
     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
         if distributed_utils.is_master(self.args):  # only save one checkpoint
-            utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
-                             self.lr_scheduler, self._num_updates, self._optim_history, extra_state)
+            utils.save_state(
+                filename, self.args, self.model, self.criterion, self.optimizer,
+                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
+            )

     def load_checkpoint(self, filename):
         """Load all training state from a checkpoint file."""
...
@@ -66,7 +66,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
 def load_model_state(filename, model):
     if not os.path.exists(filename):
         return None, [], None
-    state = torch.load(filename)
+    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
     state = _upgrade_state_dict(state)
     model.upgrade_state_dict(state['model'])
@@ -142,9 +142,9 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
     for filename in filenames:
         if not os.path.exists(filename):
             raise IOError('Model file not found: {}'.format(filename))
-        states.append(
-            torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
-        )
+        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
+        state = _upgrade_state_dict(state)
+        states.append(state)
     args = states[0]['args']
     if model_arg_overrides is not None:
         args = _override_model_args(args, model_arg_overrides)
@@ -157,6 +157,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
     ensemble = []
     for state in states:
         model = models.build_model(args, src_dict, dst_dict)
+        model.upgrade_state_dict(state['model'])
         model.load_state_dict(state['model'], strict=True)
         ensemble.append(model)
     return ensemble, args
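The `map_location` changes in this file make GPU-saved checkpoints loadable on CPU-only hosts by restoring every storage to CPU first. A sketch of the same call (the checkpoint path is hypothetical):

```python
import torch
from torch.serialization import default_restore_location

# without map_location, a checkpoint saved on GPU would try to
# deserialize its tensors straight onto the original CUDA device
state = torch.load(
    'checkpoint_last.pt',
    map_location=lambda s, l: default_restore_location(s, 'cpu'),
)
```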
@@ -278,7 +279,7 @@ def parse_embedding(embed_path):
     """
     embed_dict = {}
     with open(embed_path) as f_embed:
-        _ = next(f_embed)  # skip header
+        next(f_embed)  # skip header
         for line in f_embed:
             pieces = line.strip().split()
             embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
@@ -352,12 +353,7 @@ def buffered_arange(max):
     return buffered_arange.buf[:max]


-def convert_padding_direction(
-    src_tokens,
-    padding_idx,
-    right_to_left=False,
-    left_to_right=False,
-):
+def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
     assert right_to_left ^ left_to_right
     pad_mask = src_tokens.eq(padding_idx)
     if not pad_mask.any():
@@ -401,9 +397,12 @@ def fill_with_neg_inf(t):

 def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
-    """ retrieves all checkpoints found in `path` directory. checkpoints are identified by matching filename to
-    the specified pattern. if the pattern contains groups, the result will be sorted by the first group in descending
-    order """
+    """Retrieves all checkpoints found in `path` directory.
+
+    Checkpoints are identified by matching filename to the specified pattern. If
+    the pattern contains groups, the result will be sorted by the first group in
+    descending order.
+    """
     pt_regexp = re.compile(pattern)
     files = os.listdir(path)
...
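A small sketch of the group-based sorting that the revised `checkpoint_paths` docstring describes (the file list is illustrative; the real function scans a directory with `os.listdir`):

```python
import re

files = ['checkpoint1.pt', 'checkpoint10.pt', 'checkpoint2.pt', 'other.txt']
pt_regexp = re.compile(r'checkpoint(\d+)\.pt')

# keep only matching files, keyed by the first capture group as an integer
entries = [(int(m.group(1)), f) for f in files for m in [pt_regexp.fullmatch(f)] if m]
print([f for _, f in sorted(entries, reverse=True)])
# ['checkpoint10.pt', 'checkpoint2.pt', 'checkpoint1.pt']
```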
@@ -93,15 +93,15 @@ def main(args):
             # Process input and ground truth
             has_target = target_tokens is not None
             target_tokens = target_tokens.int().cpu() if has_target else None

             # Either retrieve the original sentences or regenerate them from tokens.
             if align_dict is not None:
                 src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
                 target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
             else:
                 src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
-                target_str = dataset.dst_dict.string(target_tokens,
-                                                     args.remove_bpe,
-                                                     escape_unk=True) if has_target else ''
+                if has_target:
+                    target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

             if not args.quiet:
                 print('S-{}\t{}'.format(sample_id, src_str))
@@ -146,7 +146,7 @@ def main(args):
             num_sentences += 1

     print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
-        num_sentences, gen_timer.n, gen_timer.sum, num_sentences/ gen_timer.sum, 1. / gen_timer.avg))
+        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
     if has_target:
         print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
...
@@ -7,7 +7,6 @@
 import unittest
-import itertools
 from unittest.mock import MagicMock, patch

 import train
...
@@ -163,7 +163,8 @@ def train(args, trainer, itr, epoch, dataset):
             trainer.get_meter('wps').reset()

         num_updates = trainer.get_num_updates()
-        if not args.no_save and (args.save_interval_updates or 0) > 0 and num_updates % args.save_interval_updates == 0:
+        if not args.no_save and (args.save_interval_updates or 0) > 0 and \
+                num_updates % args.save_interval_updates == 0:
             first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
             save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)
@@ -235,17 +236,15 @@ def validate(args, trainer, dataset, subset, epoch, num_updates):
     for sample in progress:
         log_output = trainer.valid_step(sample)

+        for k, v in log_output.items():
+            if k in ['loss', 'nll_loss', 'sample_size']:
+                continue
+            extra_meters[k].update(v)

     # log validation stats
     stats = get_valid_stats(trainer)
     for k, meter in extra_meters.items():
         stats[k] = meter.avg
-    if num_updates is not None:
-        stats['num_updates'] = num_updates
-    if hasattr(save_checkpoint, 'best'):
-        stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
     progress.print(stats)

     return stats['valid_loss']
@@ -260,6 +259,9 @@ def get_valid_stats(trainer):
     else:
         nll_loss = trainer.get_meter('valid_loss').avg
     stats['valid_ppl'] = get_perplexity(nll_loss)
+    stats['num_updates'] = trainer.get_num_updates()
+    if hasattr(save_checkpoint, 'best'):
+        stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
     return stats
...