Commit ff68a9ef authored by Myle Ott

Add FairseqTask

A Task defines the data format, stores shared state (e.g., dictionaries) and provides helpers for building the model/criterion and calculating the loss.

Changes:
- Add TranslationTask and LanguageModelingTask. New tasks can be registered with the @register_task decorator (see the sketch below).
- Add EpochBatchIterator to encapsulate batching and saving/restoring dataloader position
- Remove LEFT_PAD_* constants and make them configurable per task
parent 2de93532
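For illustration, a minimal sketch of registering a custom task against the new API. The task name 'dummy_lm' and its body are hypothetical; the hooks mirror LanguageModelingTask in the diff below:

    import os

    from fairseq.data import Dictionary
    from fairseq.tasks import FairseqTask, register_task


    @register_task('dummy_lm')  # hypothetical task name
    class DummyLanguageModelingTask(FairseqTask):

        @staticmethod
        def add_args(parser):
            """Add task-specific arguments to the parser."""
            parser.add_argument('data', metavar='DIR', help='path to data directory')

        def __init__(self, args, dictionary):
            super().__init__(args)
            self.dictionary = dictionary

        @classmethod
        def setup_task(cls, args, **kwargs):
            # shared state (here: the dictionary) is loaded once and reused
            # by build_model/build_criterion
            dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
            return cls(args, dictionary)

        def load_dataset(self, split):
            # populate self.datasets[split] with a FairseqDataset
            raise NotImplementedError

        @property
        def target_dictionary(self):
            return self.dictionary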
@@ -15,13 +15,14 @@ from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel
 from .composite_encoder import CompositeEncoder  # noqa: F401

 MODEL_REGISTRY = {}
 ARCH_MODEL_REGISTRY = {}
 ARCH_CONFIG_REGISTRY = {}


-def build_model(args, src_dict, dst_dict):
-    return ARCH_MODEL_REGISTRY[args.arch].build_model(args, src_dict, dst_dict)
+def build_model(args, task):
+    return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)


 def register_model(name):
...
@@ -23,10 +23,27 @@ class BaseFairseqModel(nn.Module):
         """Add model-specific arguments to the parser."""
         pass

+    @classmethod
+    def build_model(cls, args, task):
+        """Build a new model instance."""
+        raise NotImplementedError
+
     def get_targets(self, sample, net_output):
         """Get targets from either the sample or the net's output."""
         return sample['target']

+    def get_normalized_probs(self, net_output, log_probs, sample=None):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        return self.decoder.get_normalized_probs(net_output, log_probs, sample)
+
+    def max_positions(self):
+        """Maximum length supported by the model."""
+        raise NotImplementedError
+
+    def max_decoder_positions(self):
+        """Maximum length supported by the decoder."""
+        return self.decoder.max_positions()
+
     def load_state_dict(self, state_dict, strict=True):
         """Copies parameters and buffers from state_dict into this module and
         its descendants.
@@ -87,33 +104,14 @@ class FairseqModel(BaseFairseqModel):
         assert isinstance(self.encoder, FairseqEncoder)
         assert isinstance(self.decoder, FairseqDecoder)

-        self.src_dict = encoder.dictionary
-        self.dst_dict = decoder.dictionary
-        assert self.src_dict.pad() == self.dst_dict.pad()
-        assert self.src_dict.eos() == self.dst_dict.eos()
-        assert self.src_dict.unk() == self.dst_dict.unk()
-
-    @classmethod
-    def build_model(cls, args, src_dict, dst_dict):
-        """Build a new model instance."""
-        raise NotImplementedError
-
     def forward(self, src_tokens, src_lengths, prev_output_tokens):
         encoder_out = self.encoder(src_tokens, src_lengths)
         decoder_out = self.decoder(prev_output_tokens, encoder_out)
         return decoder_out

-    def get_normalized_probs(self, net_output, log_probs, sample=None):
-        """Get normalized probabilities (or log probs) from a net's output."""
-        return self.decoder.get_normalized_probs(net_output, log_probs, sample)
-
-    def max_encoder_positions(self):
-        """Maximum input length supported by the encoder."""
-        return self.encoder.max_positions()
-
-    def max_decoder_positions(self):
-        """Maximum output length supported by the decoder."""
-        return self.decoder.max_positions()
+    def max_positions(self):
+        """Maximum length supported by the model."""
+        return (self.encoder.max_positions(), self.decoder.max_positions())


 class FairseqLanguageModel(BaseFairseqModel):
@@ -124,16 +122,9 @@ class FairseqLanguageModel(BaseFairseqModel):
         self.decoder = decoder
         assert isinstance(self.decoder, FairseqDecoder)

-    def forward(self, src_tokens, **unused):
+    def forward(self, src_tokens):
         return self.decoder(src_tokens)

-    def get_normalized_probs(self, net_output, log_probs, sample=None):
-        """Get normalized probabilities (or log probs) from a net's output."""
-        return self.decoder.get_normalized_probs(net_output, log_probs, sample)
-
-    def max_decoder_positions(self):
-        """Maximum output length supported by the decoder."""
+    def max_positions(self):
+        """Maximum length supported by the model."""
         return self.decoder.max_positions()
-
-    def max_encoder_positions(self):
-        return self.max_decoder_positions()
@@ -11,7 +11,6 @@ import torch.nn as nn
 import torch.nn.functional as F

 from fairseq import options, utils
-from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
 from fairseq.modules import (
     AdaptiveSoftmax, BeamableMM, GradMultiply, LearnedPositionalEmbedding,
     LinearizedConvolution,
@@ -58,26 +57,23 @@ class FConvModel(FairseqModel):
                                  ' to be equal)')

     @classmethod
-    def build_model(cls, args, src_dict, dst_dict):
+    def build_model(cls, args, task):
         """Build a new model instance."""
         # make sure that all args are properly defaulted (in case there are any new ones)
         base_architecture(args)

-        if not hasattr(args, 'max_source_positions'):
-            args.max_source_positions = args.max_positions
-            args.max_target_positions = args.max_positions
-
         encoder_embed_dict = None
         if args.encoder_embed_path:
             encoder_embed_dict = utils.parse_embedding(args.encoder_embed_path)
-            utils.print_embed_overlap(encoder_embed_dict, src_dict)
+            utils.print_embed_overlap(encoder_embed_dict, task.source_dictionary)

         decoder_embed_dict = None
         if args.decoder_embed_path:
             decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path)
-            utils.print_embed_overlap(decoder_embed_dict, dst_dict)
+            utils.print_embed_overlap(decoder_embed_dict, task.target_dictionary)

         encoder = FConvEncoder(
-            src_dict,
+            dictionary=task.source_dictionary,
             embed_dim=args.encoder_embed_dim,
             embed_dict=encoder_embed_dict,
             convolutions=eval(args.encoder_layers),
@@ -86,7 +82,7 @@ class FConvModel(FairseqModel):
             normalization_constant=args.normalization_constant,
         )
         decoder = FConvDecoder(
-            dst_dict,
+            dictionary=task.target_dictionary,
             embed_dim=args.decoder_embed_dim,
             embed_dict=decoder_embed_dict,
             convolutions=eval(args.decoder_layers),
@@ -125,27 +121,28 @@ class FConvLanguageModel(FairseqLanguageModel):
                             help='multiplies the result of the residual block by sqrt(value)')

     @classmethod
-    def build_model(cls, args, dict, *_):
+    def build_model(cls, args, task):
         """Build a new model instance."""
-        if not hasattr(args, 'max_source_positions'):
-            args.max_source_positions = args.max_positions
-            args.max_target_positions = args.max_positions
-
         # make sure all arguments are present in older models
         base_lm_architecture(args)

+        if hasattr(args, 'max_target_positions'):
+            args.tokens_per_sample = args.max_target_positions
+
         decoder = FConvDecoder(
-            dict,
+            dictionary=task.target_dictionary,
             embed_dim=args.decoder_embed_dim,
             convolutions=eval(args.decoder_layers),
             out_embed_dim=args.decoder_embed_dim,
             attention=eval(args.decoder_attention),
             dropout=args.dropout,
-            max_positions=args.max_target_positions,
+            max_positions=args.tokens_per_sample,
             share_embed=False,
             positional_embeddings=False,
-            adaptive_softmax_cutoff=options.eval_str_list(args.adaptive_softmax_cutoff,
-                                                          type=int) if args.criterion == 'adaptive_loss' else None,
+            adaptive_softmax_cutoff=(
+                options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
+                if args.criterion == 'adaptive_loss' else None
+            ),
             normalization_constant=args.normalization_constant,
         )
         return FConvLanguageModel(decoder)
@@ -154,12 +151,15 @@ class FConvLanguageModel(FairseqLanguageModel):
 class FConvEncoder(FairseqEncoder):
     """Convolutional encoder"""

-    def __init__(self, dictionary, embed_dim=512, embed_dict=None,
-                 max_positions=1024, convolutions=((512, 3),) * 20, dropout=0.1,
-                 normalization_constant=0.5):
+    def __init__(
+        self, dictionary, embed_dim=512, embed_dict=None, max_positions=1024,
+        convolutions=((512, 3),) * 20, dropout=0.1, normalization_constant=0.5,
+        left_pad=True,
+    ):
         super().__init__(dictionary)
         self.dropout = dropout
         self.normalization_constant = normalization_constant
+        self.left_pad = left_pad
         self.num_attention_layers = None

         num_embeddings = len(dictionary)
@@ -172,7 +172,7 @@ class FConvEncoder(FairseqEncoder):
             max_positions,
             embed_dim,
             self.padding_idx,
-            left_pad=LEFT_PAD_SOURCE,
+            left_pad=self.left_pad,
         )

         convolutions = extend_conv_spec(convolutions)
@@ -329,14 +329,18 @@ class AttentionLayer(nn.Module):
 class FConvDecoder(FairseqIncrementalDecoder):
     """Convolutional decoder"""

-    def __init__(self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
-                 max_positions=1024, convolutions=((512, 3),) * 20,
-                 attention=True, dropout=0.1, share_embed=False, positional_embeddings=True,
-                 adaptive_softmax_cutoff=None, normalization_constant=0.5):
+    def __init__(
+        self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
+        max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
+        dropout=0.1, share_embed=False, positional_embeddings=True,
+        adaptive_softmax_cutoff=None, normalization_constant=0.5,
+        left_pad=False,
+    ):
         super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
         self.dropout = dropout
         self.normalization_constant = normalization_constant
+        self.left_pad = left_pad

         convolutions = extend_conv_spec(convolutions)
         in_channels = convolutions[0][0]
@@ -357,7 +361,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
             max_positions,
             embed_dim,
             padding_idx,
-            left_pad=LEFT_PAD_TARGET,
+            left_pad=self.left_pad,
         ) if positional_embeddings else None

         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
@@ -609,6 +613,22 @@ def base_lm_architecture(args):
     args.normalization_constant = getattr(args, 'normalization_constant', 0.5)


+@register_model_architecture('fconv_lm', 'fconv_lm_dauphin_wikitext103')
+def fconv_lm_dauphin_wikitext103(args):
+    layers = '[(850, 6)] * 3'
+    layers += ' + [(850, 1)] * 1'
+    layers += ' + [(850, 5)] * 4'
+    layers += ' + [(850, 1)] * 1'
+    layers += ' + [(850, 4)] * 3'
+    layers += ' + [(1024, 4)] * 1'
+    layers += ' + [(2048, 4)] * 1'
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 280)
+    args.decoder_layers = getattr(args, 'decoder_layers', layers)
+    args.decoder_attention = getattr(args, 'decoder_attention', 'False')
+    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,20000,200000')
+    base_lm_architecture(args)
+
+
 @register_model_architecture('fconv', 'fconv')
 def base_architecture(args):
     args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
...
@@ -12,7 +12,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
 from fairseq.modules import (
     DownsampledMultiHeadAttention, GradMultiply, LearnedPositionalEmbedding,
     LinearizedConvolution,
@@ -78,7 +77,7 @@ class FConvModelSelfAtt(FairseqModel):
                             help='use pretrained model when training [True, ...]')

     @classmethod
-    def build_model(cls, args, src_dict, dst_dict):
+    def build_model(cls, args, task):
         trained_encoder, trained_decoder = None, None
         pretrained = eval(args.pretrained)
         if pretrained:
@@ -86,8 +85,7 @@ class FConvModelSelfAtt(FairseqModel):
             trained_model = utils.load_ensemble_for_inference(
                 # not actually for inference, but loads pretrained model parameters
                 filenames=[args.pretrained_checkpoint],
-                src_dict=src_dict,
-                dst_dict=dst_dict,
+                task=task,
             )[0][0]
             trained_decoder = list(trained_model.children())[1]
             trained_encoder = list(trained_model.children())[0]
@@ -100,7 +98,7 @@ class FConvModelSelfAtt(FairseqModel):
         """Build a new model instance."""
         encoder = FConvEncoder(
-            src_dict,
+            task.source_dictionary,
             embed_dim=args.encoder_embed_dim,
             convolutions=eval(args.encoder_layers),
             dropout=args.dropout,
@@ -110,7 +108,7 @@ class FConvModelSelfAtt(FairseqModel):
         )
         decoder = FConvDecoder(
-            dst_dict,
+            task.target_dictionary,
             embed_dim=args.decoder_embed_dim,
             convolutions=eval(args.decoder_layers),
             out_embed_dim=args.decoder_out_embed_dim,
@@ -140,11 +138,12 @@ class FConvEncoder(FairseqEncoder):
     def __init__(
         self, dictionary, embed_dim=512, max_positions=1024,
         convolutions=((512, 3),) * 20, dropout=0.1, attention=False,
-        attention_nheads=1,
+        attention_nheads=1, left_pad=True,
     ):
         super().__init__(dictionary)
         self.dropout = dropout
         self.num_attention_layers = None
+        self.left_pad = left_pad

         num_embeddings = len(dictionary)
         self.padding_idx = dictionary.pad()
@@ -153,7 +152,7 @@ class FConvEncoder(FairseqEncoder):
             max_positions,
             embed_dim,
             self.padding_idx,
-            left_pad=LEFT_PAD_SOURCE,
+            left_pad=self.left_pad,
         )

         def expand_bool_array(val):
@@ -239,13 +238,14 @@ class FConvDecoder(FairseqDecoder):
         convolutions=((512, 3),) * 8, attention=True, dropout=0.1,
         selfattention=False, attention_nheads=1, selfattention_nheads=1,
         project_input=False, gated_attention=False, downsample=False,
-        pretrained=False, trained_decoder=None,
+        pretrained=False, trained_decoder=None, left_pad=False,
     ):
         super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
         self.pretrained = pretrained
         self.pretrained_decoder = trained_decoder
         self.dropout = dropout
+        self.left_pad = left_pad
         in_channels = convolutions[0][0]

         def expand_bool_array(val):
@@ -269,7 +269,7 @@ class FConvDecoder(FairseqDecoder):
             max_positions,
             embed_dim,
             padding_idx,
-            left_pad=LEFT_PAD_TARGET,
+            left_pad=self.left_pad,
         )

         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
...
@@ -11,7 +11,6 @@ import torch.nn as nn
 import torch.nn.functional as F

 from fairseq import options, utils
-from fairseq.data import consts

 from . import (
     FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model,
@@ -63,7 +62,7 @@ class LSTMModel(FairseqModel):
                             help='dropout probability for decoder output')

     @classmethod
-    def build_model(cls, args, src_dict, dst_dict):
+    def build_model(cls, args, task):
         """Build a new model instance."""
         # make sure that all args are properly defaulted (in case there are any new ones)
         base_architecture(args)
@@ -79,14 +78,14 @@ class LSTMModel(FairseqModel):
         pretrained_encoder_embed = None
         if args.encoder_embed_path:
             pretrained_encoder_embed = load_pretrained_embedding_from_file(
-                args.encoder_embed_path, src_dict, args.encoder_embed_dim)
+                args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim)
         pretrained_decoder_embed = None
         if args.decoder_embed_path:
             pretrained_decoder_embed = load_pretrained_embedding_from_file(
-                args.decoder_embed_path, dst_dict, args.decoder_embed_dim)
+                args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim)

         encoder = LSTMEncoder(
-            dictionary=src_dict,
+            dictionary=task.source_dictionary,
             embed_dim=args.encoder_embed_dim,
             hidden_size=args.encoder_hidden_size,
             num_layers=args.encoder_layers,
@@ -96,7 +95,7 @@ class LSTMModel(FairseqModel):
             pretrained_embed=pretrained_encoder_embed,
         )
         decoder = LSTMDecoder(
-            dictionary=dst_dict,
+            dictionary=task.target_dictionary,
             embed_dim=args.decoder_embed_dim,
             hidden_size=args.decoder_hidden_size,
             out_embed_dim=args.decoder_out_embed_dim,
@@ -114,11 +113,9 @@ class LSTMModel(FairseqModel):
 class LSTMEncoder(FairseqEncoder):
     """LSTM encoder."""
     def __init__(
         self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
         dropout_in=0.1, dropout_out=0.1, bidirectional=False,
-        left_pad_source=consts.LEFT_PAD_SOURCE,
-        pretrained_embed=None,
-        padding_value=0.,
+        left_pad=True, pretrained_embed=None, padding_value=0.,
     ):
         super().__init__(dictionary)
         self.num_layers = num_layers
@@ -141,7 +138,7 @@ class LSTMEncoder(FairseqEncoder):
             dropout=self.dropout_out,
             bidirectional=bidirectional,
         )
-        self.left_pad_source = left_pad_source
+        self.left_pad = left_pad
         self.padding_value = padding_value

         self.output_units = hidden_size
@@ -149,7 +146,7 @@ class LSTMEncoder(FairseqEncoder):
             self.output_units *= 2

     def forward(self, src_tokens, src_lengths):
-        if self.left_pad_source:
+        if self.left_pad:
             # convert left-padding to right-padding
             src_tokens = utils.convert_padding_direction(
                 src_tokens,
@@ -248,10 +245,9 @@ class AttentionLayer(nn.Module):
 class LSTMDecoder(FairseqIncrementalDecoder):
     """LSTM decoder."""
     def __init__(
         self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
         num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
-        encoder_embed_dim=512, encoder_output_units=512,
-        pretrained_embed=None,
+        encoder_embed_dim=512, encoder_output_units=512, pretrained_embed=None,
     ):
         super().__init__(dictionary)
         self.dropout_in = dropout_in
...
@@ -11,7 +11,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
 from fairseq.modules import (
     LearnedPositionalEmbedding, MultiheadAttention,
     SinusoidalPositionalEmbedding,
@@ -68,8 +67,9 @@ class TransformerModel(FairseqModel):
                                  ' (requires shared dictionary and embed dim)')

     @classmethod
-    def build_model(cls, args, src_dict, dst_dict):
+    def build_model(cls, args, task):
         """Build a new model instance."""
+        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

         def build_embedding(dictionary, embed_dim):
             num_embeddings = len(dictionary)
@@ -77,7 +77,7 @@ class TransformerModel(FairseqModel):
             return Embedding(num_embeddings, embed_dim, padding_idx)

         if args.share_all_embeddings:
-            if src_dict != dst_dict:
+            if src_dict != tgt_dict:
                 raise RuntimeError('--share-all-embeddings requires a joined dictionary')
             if args.encoder_embed_dim != args.decoder_embed_dim:
                 raise RuntimeError(
@@ -87,17 +87,17 @@ class TransformerModel(FairseqModel):
             args.share_decoder_input_output_embed = True
         else:
             encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
-            decoder_embed_tokens = build_embedding(dst_dict, args.decoder_embed_dim)
+            decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim)

         encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
-        decoder = TransformerDecoder(args, dst_dict, decoder_embed_tokens)
+        decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
         return TransformerModel(encoder, decoder)


 class TransformerEncoder(FairseqEncoder):
     """Transformer encoder."""

-    def __init__(self, args, dictionary, embed_tokens):
+    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
         super().__init__(dictionary)
         self.dropout = args.dropout
@@ -108,7 +108,7 @@ class TransformerEncoder(FairseqEncoder):
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, self.padding_idx,
-            left_pad=LEFT_PAD_SOURCE,
+            left_pad=left_pad,
             learned=args.encoder_learned_pos,
         )
@@ -157,7 +157,7 @@ class TransformerEncoder(FairseqEncoder):
 class TransformerDecoder(FairseqIncrementalDecoder):
     """Transformer decoder."""

-    def __init__(self, args, dictionary, embed_tokens):
+    def __init__(self, args, dictionary, embed_tokens, left_pad=False):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
@@ -169,7 +169,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, padding_idx,
-            left_pad=LEFT_PAD_TARGET,
+            left_pad=left_pad,
             learned=args.decoder_learned_pos,
         )
...
@@ -48,7 +48,7 @@ class FixedSchedule(FairseqLRScheduler):
     def step_update(self, num_updates):
         """Update the learning rate after each update."""
-        if num_updates <= self.args.warmup_updates:
+        if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
             self.warmup_factor = num_updates / float(self.args.warmup_updates)
             self.optimizer.set_lr(self.warmup_factor * self.lr)
         return self.optimizer.get_lr()
@@ -13,10 +13,11 @@ from fairseq.criterions import CRITERION_REGISTRY
 from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
 from fairseq.optim import OPTIMIZER_REGISTRY
 from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
+from fairseq.tasks import TASK_REGISTRY


-def get_training_parser():
-    parser = get_parser('Trainer')
+def get_training_parser(default_task='translation'):
+    parser = get_parser('Trainer', default_task)
     add_dataset_args(parser, train=True)
     add_distributed_training_args(parser)
     add_model_args(parser)
@@ -25,8 +26,8 @@ def get_training_parser():
     return parser


-def get_generation_parser(interactive=False):
-    parser = get_parser('Generation')
+def get_generation_parser(interactive=False, default_task='translation'):
+    parser = get_parser('Generation', default_task)
     add_dataset_args(parser, gen=True)
     add_generation_args(parser)
     if interactive:
@@ -34,8 +35,8 @@ def get_generation_parser(interactive=False):
     return parser


-def get_eval_lm_parser():
-    parser = get_parser('Evaluate Language Model')
+def get_eval_lm_parser(default_task='language_modeling'):
+    parser = get_parser('Evaluate Language Model', default_task)
     add_dataset_args(parser, gen=True)
     add_eval_lm_args(parser)
     return parser
@@ -85,6 +86,8 @@ def parse_args_and_arch(parser, input_args=None):
         OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
     if hasattr(args, 'lr_scheduler'):
         LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
+    if hasattr(args, 'task'):
+        TASK_REGISTRY[args.task].add_args(parser)

     # Parse a second time.
     args = parser.parse_args(input_args)
@@ -104,7 +107,7 @@ def parse_args_and_arch(parser, input_args=None):
     return args


-def get_parser(desc):
+def get_parser(desc, default_task='translation'):
     parser = argparse.ArgumentParser(
         description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
     parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
@@ -114,34 +117,24 @@ def get_parser(desc):
                         choices=['json', 'none', 'simple', 'tqdm'])
     parser.add_argument('--seed', default=1, type=int, metavar='N',
                         help='pseudo random number generator seed')
+    # Task definitions can be found under fairseq/tasks/
+    parser.add_argument(
+        '--task', metavar='TASK', default=default_task, choices=TASK_REGISTRY.keys(),
+        help='task: {} (default: {})'.format(', '.join(TASK_REGISTRY.keys()), default_task)
+    )
     return parser


 def add_dataset_args(parser, train=False, gen=False):
     group = parser.add_argument_group('Dataset and data loading')
-    group.add_argument('data', metavar='DIR',
-                       help='path to data directory')
-    group.add_argument('-s', '--source-lang', default=None, metavar='SRC',
-                       help='source language')
-    group.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
-                       help='target language')
-    group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
-                       help='max number of tokens in the source sequence')
-    group.add_argument('--max-target-positions', '--tokens-per-sample', default=1024, type=int, metavar='N',
-                       help='max number of tokens in the target sequence')
     group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                        help='ignore too long or too short lines in valid and test set')
     group.add_argument('--max-tokens', type=int, metavar='N',
                        help='maximum number of tokens in a batch')
     group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                        help='maximum number of sentences in a batch')
-    group.add_argument('--sample-break-mode', metavar='VAL',
-                       choices=['none', 'complete', 'eos'],
-                       help='If omitted or "none", fills each sample with tokens-per-sample'
-                            ' tokens. If set to "complete", splits samples only at the end'
-                            ' of sentence, but may include multiple sentences per sample.'
-                            ' If set to "eos", includes only one sentence per sample.')
     if train:
         group.add_argument('--train-subset', default='train', metavar='SPLIT',
                            choices=['train', 'valid', 'test'],
@@ -152,10 +145,6 @@ def add_dataset_args(parser, train=False, gen=False):
         group.add_argument('--max-sentences-valid', type=int, metavar='N',
                            help='maximum number of sentences in a validation batch'
                                 ' (defaults to --max-sentences)')
-        group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
-                           help='If bigger than 0, use that number of mini-batches for each epoch,'
-                                ' where each sample is drawn randomly without replacement from the'
-                                ' dataset')
     if gen:
         group.add_argument('--gen-subset', default='test', metavar='SPLIT',
                            help='data subset to generate (train, valid, test)')
...
@@ -14,7 +14,7 @@ from fairseq.models import FairseqIncrementalDecoder
 class SequenceGenerator(object):
     def __init__(
-        self, models, beam_size=1, minlen=1, maxlen=None, stop_early=True,
+        self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
         normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False,
         sampling=False, sampling_topk=-1, sampling_temperature=1,
     ):
@@ -28,13 +28,10 @@ class SequenceGenerator(object):
             normalize_scores: Normalize scores by the length of the output.
         """
         self.models = models
-        self.pad = models[0].dst_dict.pad()
-        self.unk = models[0].dst_dict.unk()
-        self.eos = models[0].dst_dict.eos()
-        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
-        assert all(m.dst_dict.unk() == self.unk for m in self.models[1:])
-        assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
-        self.vocab_size = len(models[0].dst_dict)
+        self.pad = tgt_dict.pad()
+        self.unk = tgt_dict.unk()
+        self.eos = tgt_dict.eos()
+        self.vocab_size = len(tgt_dict)
         self.beam_size = beam_size
         self.minlen = minlen
         max_decoder_len = min(m.max_decoder_positions() for m in self.models)
@@ -70,6 +67,8 @@ class SequenceGenerator(object):
         for sample in data_itr:
             s = utils.make_variable(sample, volatile=True, cuda=cuda)
+            if 'net_input' not in s:
+                continue
             input = s['net_input']
             srclen = input['src_tokens'].size(1)
             if timer is not None:
...
@@ -11,10 +11,9 @@ from fairseq import utils
 class SequenceScorer(object):
     """Scores the target for a given source sentence."""

-    def __init__(self, models):
+    def __init__(self, models, tgt_dict):
         self.models = models
-        self.pad = models[0].dst_dict.pad()
-        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
+        self.pad = tgt_dict.pad()

     def cuda(self):
         for model in self.models:
...
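Net effect of the two hunks above: callers now construct SequenceGenerator and SequenceScorer with an explicit target dictionary instead of having it read from models[0].dst_dict. A minimal sketch of the new call pattern, assuming args, models, and task are already loaded as in the generation script at the end of this diff:

    generator = SequenceGenerator(models, task.target_dictionary, beam_size=args.beam)
    scorer = SequenceScorer(models, task.target_dictionary)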
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import importlib
import os

from .fairseq_task import FairseqTask

TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()


def setup_task(args):
    return TASK_REGISTRY[args.task].setup_task(args)


def register_task(name):
    """Decorator to register a new task."""

    def register_task_cls(cls):
        if name in TASK_REGISTRY:
            raise ValueError('Cannot register duplicate task ({})'.format(name))
        if not issubclass(cls, FairseqTask):
            raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__))
        if cls.__name__ in TASK_CLASS_NAMES:
            raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__))
        TASK_REGISTRY[name] = cls
        TASK_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_task_cls


# automatically import any Python files in the tasks/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        module = file[:file.find('.py')]
        importlib.import_module('fairseq.tasks.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from fairseq import criterions, models
from fairseq.data import FairseqDataset


class FairseqTask(object):
    """
    A Task defines the data format, stores shared state (e.g., dictionaries) and
    provides helpers for building the model/criterion and calculating the loss.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        pass

    def __init__(self, args):
        self.args = args
        self.datasets = {}

    @classmethod
    def setup_task(cls, args, **kwargs):
        raise NotImplementedError

    def load_dataset(self, split):
        raise NotImplementedError

    def dataset(self, split):
        """Return a dataset split."""
        if split not in self.datasets:
            raise KeyError('Dataset not loaded: ' + split)
        if not isinstance(self.datasets[split], FairseqDataset):
            raise TypeError('Datasets are expected to be of type FairseqDataset')
        return self.datasets[split]

    def build_model(self, args):
        return models.build_model(args, self)

    def build_criterion(self, args):
        return criterions.build_criterion(args, self)

    def get_loss(self, model, criterion, sample):
        return criterion(model, sample)

    @property
    def source_dictionary(self):
        raise NotImplementedError

    @property
    def target_dictionary(self):
        raise NotImplementedError
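To make the new control flow concrete, a rough sketch of how a caller is expected to drive this interface after the change (argument parsing details and the batch iterator are elided; the Trainer signature matches the hunk further down):

    from fairseq import options, tasks
    from fairseq.trainer import Trainer

    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    task = tasks.setup_task(args)          # e.g. TranslationTask.setup_task(args)
    task.load_dataset(args.train_subset)   # populates task.datasets['train']

    model = task.build_model(args)          # was: models.build_model(args, src_dict, dst_dict)
    criterion = task.build_criterion(args)
    trainer = Trainer(args, task, model, criterion)

    # inside the training loop the loss is now computed through the task:
    #   loss, sample_size, logging_output = task.get_loss(model, criterion, sample)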
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os

from fairseq.data import (
    Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
    MonolingualDataset, TokenBlockDataset,
)

from . import FairseqTask, register_task


@register_task('language_modeling')
class LanguageModelingTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('data', metavar='DIR', help='path to data directory')
        parser.add_argument('--sample-break-mode', metavar='VAL',
                            choices=['none', 'complete', 'eos'],
                            help='If omitted or "none", fills each sample with tokens-per-sample '
                                 'tokens. If set to "complete", splits samples only at the end '
                                 'of sentence, but may include multiple sentences per sample. '
                                 'If set to "eos", includes only one sentence per sample.')
        parser.add_argument('--tokens-per-sample', default=1024, type=int, metavar='N',
                            help='max number of tokens per sample for LM dataset')
        parser.add_argument('--raw-text', default=False, action='store_true',
                            help='load raw text dataset')

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
        print('| dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)

    def load_dataset(self, split):
        """Load a dataset split."""
        path = os.path.join(self.args.data, split)
        if self.args.raw_text and IndexedRawTextDataset.exists(path):
            ds = IndexedRawTextDataset(path, self.dictionary)
            tokens = ds.tokens_list
        elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
            ds = IndexedInMemoryDataset(path)
            tokens = ds.buffer
        else:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

        dataset = TokenBlockDataset(
            tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
            include_targets=True,  # return next tokens as targets
        )
        self.datasets[split] = MonolingualDataset(dataset, dataset.sizes, self.dictionary, shuffle=False)

    @property
    def target_dictionary(self):
        return self.dictionary
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os

from fairseq import options
from fairseq.data import (
    data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
    IndexedRawTextDataset,
)

from . import FairseqTask, register_task


@register_task('translation')
class TranslationTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('data', metavar='DIR', help='path to data directory')
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language')
        parser.add_argument('--raw-text', action='store_true',
                            help='load raw text dataset')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left (default: True)')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left (default: False)')
        parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the source sequence')
        parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the target sequence')

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data)
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)

    def load_dataset(self, split):
        """Load a dataset split."""

        def split_exists(src, tgt, lang):
            filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
                return True
            return False

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang
        if split_exists(src, tgt, src):
            prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
        elif split_exists(tgt, src, src):
            prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
        else:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
            elif IndexedInMemoryDataset.exists(path):
                return IndexedInMemoryDataset(path)
            return None

        src_dataset = indexed_dataset(prefix + src, self.src_dict)
        tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict)

        self.datasets[split] = LanguagePairDataset(
            src_dataset, src_dataset.sizes, self.src_dict,
            tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
        )

    @property
    def source_dictionary(self):
        return self.src_dict

    @property
    def target_dictionary(self):
        return self.tgt_dict
@@ -27,7 +27,7 @@ class Trainer(object):
     torch.distributed.all_reduce.
     """

-    def __init__(self, args, model, criterion):
+    def __init__(self, args, task, model, criterion):

         if not torch.cuda.is_available():
             raise NotImplementedError('Training on CPU is not supported')
@@ -35,6 +35,7 @@ class Trainer(object):
         self.args = args

         # copy model and criterion to current device
+        self.task = task
         self.model = model.cuda()
         self.criterion = criterion.cuda()
@@ -67,6 +68,7 @@ class Trainer(object):
     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
         if distributed_utils.is_master(self.args):  # only save one checkpoint
+            extra_state['train_meters'] = self.meters
             utils.save_state(
                 filename, self.args, self.model, self.criterion, self.optimizer,
                 self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
@@ -90,6 +92,10 @@ class Trainer(object):
             self._num_updates = last_optim['num_updates']

+        if 'train_meters' in extra_state:
+            self.meters = extra_state['train_meters']
+            del extra_state['train_meters']
+
         return extra_state

     def train_step(self, sample, update_params=True):
@@ -99,9 +105,14 @@ class Trainer(object):
             # initialize optimizer and LR scheduler if hasn't been loaded from the checkpoint
             self._build_optimizer()

-        sample = self._prepare_sample(sample, volatile=False)
+        # Set seed based on args.seed and the update number so that we get
+        # reproducible results when resuming from checkpoints
+        seed = self.args.seed + self.get_num_updates()
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)

         # forward and backward pass
+        sample = self._prepare_sample(sample, volatile=False)
         loss, sample_size, logging_output, oom_fwd = self._forward(sample)
         oom_bwd = self._backward(loss)
@@ -182,7 +193,7 @@ class Trainer(object):
         try:
             with utils.maybe_no_grad(eval):
                 # calculate loss and sample size
-                loss, sample_size, logging_output_ = self.criterion(self.model, sample)
+                loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
                 logging_output.update(logging_output_)
         except RuntimeError as e:
             if not eval and 'out of memory' in str(e):
@@ -311,6 +322,10 @@ class Trainer(object):
         """Adjust the learning rate based on the validation loss."""
         return self.lr_scheduler.step(epoch, val_loss)

+    def lr_step_update(self, num_updates):
+        """Update the learning rate after each update."""
+        return self.lr_scheduler.step_update(num_updates)
+
     def get_lr(self):
         """Get the current learning rate."""
         return self.optimizer.get_lr()
...
...@@ -72,7 +72,7 @@ def load_model_state(filename, model): ...@@ -72,7 +72,7 @@ def load_model_state(filename, model):
# load model parameters # load model parameters
try: try:
model.load_state_dict(state['model']) model.load_state_dict(state['model'], strict=True)
except Exception: except Exception:
raise Exception('Cannot load model parameters from checkpoint, ' raise Exception('Cannot load model parameters from checkpoint, '
'please ensure that the architectures match') 'please ensure that the architectures match')
...@@ -120,23 +120,26 @@ def _upgrade_state_dict(state): ...@@ -120,23 +120,26 @@ def _upgrade_state_dict(state):
# keep track of number of updates # keep track of number of updates
if 'num_updates' not in state['optimizer_history'][-1]: if 'num_updates' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['num_updates'] = 0 state['optimizer_history'][-1]['num_updates'] = 0
# old model checkpoints may not have separate source/target positions
if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
state['args'].max_source_positions = state['args'].max_positions
state['args'].max_target_positions = state['args'].max_positions
# use stateful training data iterator
if 'train_iterator' not in state['extra_state']:
state['extra_state']['train_iterator'] = {
'epoch': state['extra_state']['epoch'],
'iterations_in_epoch': 0,
}
return state return state
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
data_dir=None, model_arg_overrides=None):
"""Load an ensemble of models for inference. """Load an ensemble of models for inference.
The source and target dictionaries can be given explicitly, or loaded from
the `data_dir` directory.
model_arg_overrides allows you to pass a dictionary model_arg_overrides -- model_arg_overrides allows you to pass a dictionary model_arg_overrides --
{'arg_name': arg} -- to override model args that were used during model {'arg_name': arg} -- to override model args that were used during model
training training
""" """
from fairseq import models
from fairseq.data import data_utils
# load model architectures and weights # load model architectures and weights
states = [] states = []
for filename in filenames: for filename in filenames:
...@@ -149,14 +152,10 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, ...@@ -149,14 +152,10 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
if model_arg_overrides is not None: if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides) args = _override_model_args(args, model_arg_overrides)
if src_dict is None or dst_dict is None:
assert data_dir is not None
src_dict, dst_dict = data_utils.load_dictionaries(data_dir, args.source_lang, args.target_lang)
# build ensemble # build ensemble
ensemble = [] ensemble = []
for state in states: for state in states:
model = models.build_model(args, src_dict, dst_dict) model = task.build_model(args)
model.upgrade_state_dict(state['model']) model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True) model.load_state_dict(state['model'], strict=True)
ensemble.append(model) ensemble.append(model)
...@@ -308,15 +307,15 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk): ...@@ -308,15 +307,15 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
return ' '.join(hypo_tokens) return ' '.join(hypo_tokens)
def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict, remove_bpe): def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
from fairseq import tokenizer from fairseq import tokenizer
hypo_str = dst_dict.string(hypo_tokens, remove_bpe) hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
if align_dict is not None: if align_dict is not None:
hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, dst_dict.unk_string()) hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
if align_dict is not None or remove_bpe is not None: if align_dict is not None or remove_bpe is not None:
# Convert back to tokens for evaluating with unk replacement or without BPE # Convert back to tokens for evaluating with unk replacement or without BPE
# Note that the dictionary can be modified inside the method. # Note that the dictionary can be modified inside the method.
hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, dst_dict, add_if_not_exist=True) hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
return hypo_tokens, hypo_str, alignment return hypo_tokens, hypo_str, alignment
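
The dst_dict-to-tgt_dict rename is mechanical; only the keyword changes for callers. A sketch mirroring the generate.py call site further down (the hypo dict and surrounding variables are hypothetical here):

hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
    hypo_tokens=hypo['tokens'].int().cpu(),
    src_str=src_str,
    alignment=hypo['alignment'].int().cpu(),
    align_dict=align_dict,      # None disables unk replacement
    tgt_dict=tgt_dict,          # formerly dst_dict
    remove_bpe=args.remove_bpe,
)
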
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
import torch import torch
from fairseq import bleu, options, progress_bar, tokenizer, utils from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils
from fairseq.data import data_utils, data_loaders
from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer from fairseq.sequence_scorer import SequenceScorer
...@@ -17,65 +16,67 @@ from fairseq.sequence_scorer import SequenceScorer ...@@ -17,65 +16,67 @@ from fairseq.sequence_scorer import SequenceScorer
def main(args): def main(args):
assert args.path is not None, '--path required for generation!' assert args.path is not None, '--path required for generation!'
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert args.replace_unk is None or args.raw_text, \
'--replace-unk requires a raw text dataset (--raw-text)'
if args.max_tokens is None and args.max_sentences is None: if args.max_tokens is None and args.max_sentences is None:
args.max_tokens = 12000 args.max_tokens = 12000
print(args) print(args)
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
use_cuda = torch.cuda.is_available() and not args.cpu use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset # Load dataset splits
dataset = data_loaders.load_dataset(args, [args.gen_subset], args.replace_unk is not None) task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Load ensemble # Load ensemble
print('| loading model(s) from {}'.format(args.path)) print('| loading model(s) from {}'.format(args.path))
model_paths = args.path.split(',') models, _ = utils.load_ensemble_for_inference([args.path], task)
models, _ = utils.load_ensemble_for_inference(model_paths, dataset.src_dict, dataset.dst_dict)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
# Optimize ensemble for generation # Optimize ensemble for generation
for model in models: for model in models:
model.make_generation_fast_( model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
)
# Load alignment dictionary for unknown word replacement # Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary) # (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk) align_dict = utils.load_align_dict(args.replace_unk)
# Load dataset (possibly sharded)
max_positions = min(model.max_encoder_positions() for model in models)
itr = dataset.eval_dataloader( # Load dataset (possibly sharded)
args.gen_subset, itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
max_sentences=args.max_sentences, max_sentences=args.max_sentences,
max_positions=max_positions, max_positions=models[0].max_positions(),
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test, ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
) required_batch_size_multiple=8,
itr = data_utils.ShardedIterator(itr, args.num_shards, args.shard_id) num_shards=args.num_shards,
shard_id=args.shard_id,
).next_epoch_itr(shuffle=False)
# Initialize generator # Initialize generator
gen_timer = StopwatchMeter() gen_timer = StopwatchMeter()
if args.score_reference: if args.score_reference:
translator = SequenceScorer(models) translator = SequenceScorer(models, task.target_dictionary)
else: else:
translator = SequenceGenerator( translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop), models, task.target_dictionary, beam_size=args.beam,
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen, stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk, len_penalty=args.lenpen, unk_penalty=args.unkpen,
minlen=args.min_len) sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
)
if use_cuda: if use_cuda:
translator.cuda() translator.cuda()
# Generate and compute BLEU score # Generate and compute BLEU score
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk()) scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
num_sentences = 0 num_sentences = 0
has_target = True has_target = True
with progress_bar.build_progress_bar(args, itr) as t: with progress_bar.build_progress_bar(args, itr) as t:
...@@ -84,7 +85,9 @@ def main(args): ...@@ -84,7 +85,9 @@ def main(args):
else: else:
translations = translator.generate_batched_itr( translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b, t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size) cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
)
wps_meter = TimeMeter() wps_meter = TimeMeter()
for sample_id, src_tokens, target_tokens, hypos in translations: for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth # Process input and ground truth
...@@ -93,12 +96,12 @@ def main(args): ...@@ -93,12 +96,12 @@ def main(args):
# Either retrieve the original sentences or regenerate them from tokens. # Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None: if align_dict is not None:
src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id) src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id) target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
else: else:
src_str = dataset.src_dict.string(src_tokens, args.remove_bpe) src_str = src_dict.string(src_tokens, args.remove_bpe)
if has_target: if has_target:
target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True) target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
if not args.quiet: if not args.quiet:
print('S-{}\t{}'.format(sample_id, src_str)) print('S-{}\t{}'.format(sample_id, src_str))
...@@ -112,7 +115,7 @@ def main(args): ...@@ -112,7 +115,7 @@ def main(args):
src_str=src_str, src_str=src_str,
alignment=hypo['alignment'].int().cpu(), alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict, align_dict=align_dict,
dst_dict=dataset.dst_dict, tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe, remove_bpe=args.remove_bpe,
) )
...@@ -135,7 +138,7 @@ def main(args): ...@@ -135,7 +138,7 @@ def main(args):
if align_dict is not None or args.remove_bpe is not None: if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE # Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize( target_tokens = tokenizer.Tokenizer.tokenize(
target_str, dataset.dst_dict, add_if_not_exist=True) target_str, tgt_dict, add_if_not_exist=True)
scorer.add(target_tokens, hypo_tokens) scorer.add(target_tokens, hypo_tokens)
wps_meter.update(src_tokens.size(0)) wps_meter.update(src_tokens.size(0))
......
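
Taken together, the generate.py changes mean dictionaries, datasets, and batching all flow from the task rather than a monolithic dataset object. A hypothetical invocation, assuming the default translation task and placeholder paths (flags as used in this diff and the tests below):

python generate.py data-bin/iwslt14.tokenized.de-en \
    --task translation \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 --batch-size 64 --remove-bpe
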
...@@ -6,17 +6,17 @@ ...@@ -6,17 +6,17 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
from collections import namedtuple
import numpy as np import numpy as np
import sys import sys
import torch import torch
from collections import namedtuple
from torch.autograd import Variable from torch.autograd import Variable
from fairseq import options, tokenizer, utils from fairseq import data, options, tasks, tokenizer, utils
from fairseq.data.data_utils import collate_tokens
from fairseq.data.consts import LEFT_PAD_SOURCE
from fairseq.sequence_generator import SequenceGenerator from fairseq.sequence_generator import SequenceGenerator
Batch = namedtuple('Batch', 'srcs tokens lengths') Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos alignments') Translation = namedtuple('Translation', 'src_str hypos alignments')
...@@ -33,44 +33,52 @@ def buffered_read(buffer_size): ...@@ -33,44 +33,52 @@ def buffered_read(buffer_size):
yield buffer yield buffer
def make_batches(lines, batch_size, src_dict): def make_batches(lines, args, src_dict, max_positions):
tokens = [tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long() for src_str in lines] tokens = [
lengths = [t.numel() for t in tokens] tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
for src_str in lines
indices = np.argsort(lengths) ]
num_batches = np.ceil(len(indices) / batch_size) lengths = np.array([t.numel() for t in tokens])
batches = np.array_split(indices, num_batches) itr = data.EpochBatchIterator(
for batch_idxs in batches: dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
batch_toks = [tokens[i] for i in batch_idxs] max_tokens=args.max_tokens,
batch_toks = collate_tokens(batch_toks, src_dict.pad(), src_dict.eos(), LEFT_PAD_SOURCE, max_sentences=args.max_sentences,
move_eos_to_beginning=False) max_positions=max_positions,
).next_epoch_itr(shuffle=False)
for batch in itr:
yield Batch( yield Batch(
srcs=[lines[i] for i in batch_idxs], srcs=[lines[i] for i in batch['id']],
tokens=batch_toks, tokens=batch['net_input']['src_tokens'],
lengths=tokens[0].new([lengths[i] for i in batch_idxs]), lengths=batch['net_input']['src_lengths'],
), batch_idxs ), batch['id']
def main(args): def main(args):
print(args) if args.buffer_size < 1:
args.buffer_size = 1
if args.max_tokens is None and args.max_sentences is None:
args.max_sentences = 1
assert not args.sampling or args.nbest == args.beam, \ assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam' '--sampling requires --nbest to be equal to --beam'
assert not args.max_sentences or args.max_sentences <= args.buffer_size, \ assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
'--max-sentences/--batch-size cannot be larger than --buffer-size' '--max-sentences/--batch-size cannot be larger than --buffer-size'
if args.buffer_size < 1: print(args)
args.buffer_size = 1
use_cuda = torch.cuda.is_available() and not args.cpu use_cuda = torch.cuda.is_available() and not args.cpu
# Setup task, e.g., translation
task = tasks.setup_task(args)
# Load ensemble # Load ensemble
print('| loading model(s) from {}'.format(args.path)) print('| loading model(s) from {}'.format(args.path))
model_paths = args.path.split(',') model_paths = args.path.split(',')
models, model_args = utils.load_ensemble_for_inference(model_paths, data_dir=args.data) models, model_args = utils.load_ensemble_for_inference(model_paths, task)
src_dict, dst_dict = models[0].src_dict, models[0].dst_dict
print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict))) # Set dictionaries
print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict))) src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Optimize ensemble for generation # Optimize ensemble for generation
for model in models: for model in models:
...@@ -80,10 +88,11 @@ def main(args): ...@@ -80,10 +88,11 @@ def main(args):
# Initialize generator # Initialize generator
translator = SequenceGenerator( translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop), models, tgt_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen, normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk, unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk,
minlen=args.min_len) minlen=args.min_len,
)
if use_cuda: if use_cuda:
translator.cuda() translator.cuda()
...@@ -106,7 +115,7 @@ def main(args): ...@@ -106,7 +115,7 @@ def main(args):
src_str=src_str, src_str=src_str,
alignment=hypo['alignment'].int().cpu(), alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict, align_dict=align_dict,
dst_dict=dst_dict, tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe, remove_bpe=args.remove_bpe,
) )
result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str)) result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
...@@ -135,7 +144,7 @@ def main(args): ...@@ -135,7 +144,7 @@ def main(args):
for inputs in buffered_read(args.buffer_size): for inputs in buffered_read(args.buffer_size):
indices = [] indices = []
results = [] results = []
for batch, batch_indices in make_batches(inputs, max(1, args.max_sentences or 1), src_dict): for batch, batch_indices in make_batches(inputs, args, src_dict, models[0].max_positions()):
indices.extend(batch_indices) indices.extend(batch_indices)
results += process_batch(batch) results += process_batch(batch)
......
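
Both generate.py and the rewritten make_batches above now share one batching path. Condensed from the diffs (budgets hypothetical, variable names as in make_batches), the contract each consumer relies on is:

itr = data.EpochBatchIterator(
    dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
    max_tokens=args.max_tokens,
    max_sentences=args.max_sentences,
    max_positions=models[0].max_positions(),
).next_epoch_itr(shuffle=False)

for batch in itr:
    batch['id']                        # indices into the original input order
    batch['net_input']['src_tokens']   # padded LongTensor of source tokens
    batch['net_input']['src_lengths']  # per-sentence lengths before padding
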
...@@ -148,6 +148,7 @@ def train_translation_model(data_dir, arch, extra_flags=None): ...@@ -148,6 +148,7 @@ def train_translation_model(data_dir, arch, extra_flags=None):
train_args = options.parse_args_and_arch( train_args = options.parse_args_and_arch(
train_parser, train_parser,
[ [
'--task', 'translation',
data_dir, data_dir,
'--save-dir', data_dir, '--save-dir', data_dir,
'--arch', arch, '--arch', arch,
...@@ -166,15 +167,18 @@ def train_translation_model(data_dir, arch, extra_flags=None): ...@@ -166,15 +167,18 @@ def train_translation_model(data_dir, arch, extra_flags=None):
def generate_main(data_dir): def generate_main(data_dir):
generate_parser = options.get_generation_parser() generate_parser = options.get_generation_parser()
generate_args = generate_parser.parse_args([ generate_args = options.parse_args_and_arch(
data_dir, generate_parser,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'), [
'--beam', '3', data_dir,
'--batch-size', '64', '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--max-len-b', '5', '--beam', '3',
'--gen-subset', 'valid', '--batch-size', '64',
'--no-progress-bar', '--max-len-b', '5',
]) '--gen-subset', 'valid',
'--no-progress-bar',
],
)
# evaluate model in batch mode # evaluate model in batch mode
generate.main(generate_args) generate.main(generate_args)
...@@ -205,6 +209,7 @@ def train_language_model(data_dir, arch): ...@@ -205,6 +209,7 @@ def train_language_model(data_dir, arch):
train_args = options.parse_args_and_arch( train_args = options.parse_args_and_arch(
train_parser, train_parser,
[ [
'--task', 'language_modeling',
data_dir, data_dir,
'--arch', arch, '--arch', arch,
'--optimizer', 'nag', '--optimizer', 'nag',
...@@ -214,7 +219,7 @@ def train_language_model(data_dir, arch): ...@@ -214,7 +219,7 @@ def train_language_model(data_dir, arch):
'--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]', '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
'--decoder-embed-dim', '280', '--decoder-embed-dim', '280',
'--max-tokens', '500', '--max-tokens', '500',
'--max-target-positions', '500', '--tokens-per-sample', '500',
'--save-dir', data_dir, '--save-dir', data_dir,
'--max-epoch', '1', '--max-epoch', '1',
'--no-progress-bar', '--no-progress-bar',
...@@ -226,11 +231,14 @@ def train_language_model(data_dir, arch): ...@@ -226,11 +231,14 @@ def train_language_model(data_dir, arch):
def eval_lm_main(data_dir): def eval_lm_main(data_dir):
eval_lm_parser = options.get_eval_lm_parser() eval_lm_parser = options.get_eval_lm_parser()
eval_lm_args = eval_lm_parser.parse_args([ eval_lm_args = options.parse_args_and_arch(
data_dir, eval_lm_parser,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'), [
'--no-progress-bar', data_dir,
]) '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--no-progress-bar',
],
)
eval_lm.main(eval_lm_args) eval_lm.main(eval_lm_args)
......
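
Both test helpers now select a registered task by name ('translation', 'language_modeling') and route all argument parsing through options.parse_args_and_arch so task defaults apply. Per the commit description, new tasks hook into the same registry via the @register_task decorator; a rough sketch of the shape (class name and body are hypothetical, and the exact FairseqTask hooks may differ):

from fairseq.tasks import FairseqTask, register_task

@register_task('my_translation')
class MyTranslationTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        # task-specific CLI flags (hypothetical)
        parser.add_argument('data', help='path to data directory')

    @classmethod
    def setup_task(cls, args, **kwargs):
        # load dictionaries, then construct the task instance
        ...

    def load_dataset(self, split):
        # populate self.datasets[split] with a dataset for this split
        ...
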
...@@ -5,6 +5,25 @@ ...@@ -5,6 +5,25 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
# padding constants import unittest
LEFT_PAD_SOURCE = True
LEFT_PAD_TARGET = False from fairseq.data import data_utils
class TestDataUtils(unittest.TestCase):
def test_counting_iterator(self):
x = list(range(10))
itr = data_utils.CountingIterator(x)
self.assertTrue(itr.has_next())
self.assertEqual(next(itr), 0)
self.assertEqual(next(itr), 1)
itr.skip(3)
self.assertEqual(next(itr), 5)
itr.skip(3)
self.assertEqual(next(itr), 9)
self.assertFalse(itr.has_next())
if __name__ == '__main__':
unittest.main()
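
The new test pins down CountingIterator's observable contract: sequential next(), skip(n), and has_next(). A minimal sketch of an iterator that satisfies exactly this test (not the library's implementation):

class CountingIterator(object):
    """Iterator wrapper that counts how many elements have been consumed (sketch)."""

    def __init__(self, iterable):
        self.iterable = iterable
        self.count = 0
        self.itr = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        self.count += 1
        return next(self.itr)

    def has_next(self):
        # assumes the wrapped iterable has a length, as the test's list does
        return self.count < len(self.iterable)

    def skip(self, n):
        # advance past the next n elements
        for _ in range(n):
            next(self)
        return self
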