Commit 84b82dc6 authored by Myle Ott

Simplify deps of build_model to only depend on dict (instead of dataset)

parent d646a4a8
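In effect, `build_model` and everything downstream now take `(args, src_dict, dst_dict)` instead of a whole dataset object, and each model records its dictionaries as `src_dict`/`dst_dict` attributes. A minimal sketch of the new call pattern, where the data directory and language pair are illustrative placeholders and `args` is assumed to come from `parse_args_and_arch()`:

# Sketch of the call pattern introduced by this commit; only the
# dictionaries flow from the dataset into model construction.
from fairseq import data, utils

dataset = data.load('data-bin/iwslt14', ['train', 'valid'], 'de', 'en')
model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)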
fairseq/models/fconv.py

@@ -15,12 +15,18 @@ from fairseq.modules import BeamableMM, LinearizedConvolution

 class FConvModel(nn.Module):
-    def __init__(self, encoder, decoder, padding_idx=1):
+    def __init__(self, encoder, decoder):
         super(FConvModel, self).__init__()
         self.encoder = encoder
         self.decoder = decoder
+        self.src_dict = encoder.dictionary
+        self.dst_dict = decoder.dictionary
+        assert self.src_dict.pad() == self.dst_dict.pad()
+        assert self.src_dict.eos() == self.dst_dict.eos()
+        assert self.src_dict.unk() == self.dst_dict.unk()
         self.encoder.num_attention_layers = sum([layer is not None for layer in decoder.attention])
-        self.padding_idx = padding_idx
         self._is_generation_fast = False

     def forward(self, src_tokens, src_positions, input_tokens, input_positions):

@@ -67,11 +73,15 @@ class FConvModel(nn.Module):

 class Encoder(nn.Module):
     """Convolutional encoder"""
-    def __init__(self, num_embeddings, embed_dim=512, max_positions=1024,
-                 convolutions=((512, 3),) * 20, dropout=0.1, padding_idx=1):
+    def __init__(self, dictionary, embed_dim=512, max_positions=1024,
+                 convolutions=((512, 3),) * 20, dropout=0.1):
         super(Encoder, self).__init__()
+        self.dictionary = dictionary
         self.dropout = dropout
         self.num_attention_layers = None
+
+        num_embeddings = len(dictionary)
+        padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
         self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)

@@ -160,10 +170,11 @@ class AttentionLayer(nn.Module):

 class Decoder(nn.Module):
     """Convolutional decoder"""
-    def __init__(self, num_embeddings, embed_dim=512, out_embed_dim=256,
+    def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                  max_positions=1024, convolutions=((512, 3),) * 20,
-                 attention=True, dropout=0.1, padding_idx=1):
+                 attention=True, dropout=0.1):
         super(Decoder, self).__init__()
+        self.dictionary = dictionary
         self.dropout = dropout
         in_channels = convolutions[0][0]

@@ -171,8 +182,11 @@ class Decoder(nn.Module):
             # expand True into [True, True, ...] and do the same with False
             attention = [attention] * len(convolutions)
+
+        num_embeddings = len(dictionary)
+        padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
         self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
         self.projections = nn.ModuleList()
         self.convolutions = nn.ModuleList()

@@ -503,24 +517,21 @@ def parse_arch(args):
     return args


-def build_model(args, dataset):
-    padding_idx = dataset.dst_dict.pad()
+def build_model(args, src_dict, dst_dict):
     encoder = Encoder(
-        len(dataset.src_dict),
+        src_dict,
         embed_dim=args.encoder_embed_dim,
         convolutions=eval(args.encoder_layers),
         dropout=args.dropout,
-        padding_idx=padding_idx,
         max_positions=args.max_positions,
     )
     decoder = Decoder(
-        len(dataset.dst_dict),
+        dst_dict,
         embed_dim=args.decoder_embed_dim,
         convolutions=eval(args.decoder_layers),
         out_embed_dim=args.decoder_out_embed_dim,
         attention=eval(args.decoder_attention),
         dropout=args.dropout,
-        padding_idx=padding_idx,
         max_positions=args.max_positions,
     )
-    return FConvModel(encoder, decoder, padding_idx)
+    return FConvModel(encoder, decoder)
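Encoder and Decoder now derive `num_embeddings` and `padding_idx` from the dictionary they receive, so callers pass the dictionary itself. A minimal construction sketch, assuming the module lives at fairseq/models/fconv.py as in the fairseq-py layout of this era; `src_dict` and `dst_dict` stand in for already-loaded dictionaries and the layer settings are illustrative, not any architecture's defaults:

from fairseq.models.fconv import Encoder, Decoder, FConvModel

# Vocabulary size and padding index are read from the dictionary internally.
encoder = Encoder(src_dict, embed_dim=256,
                  convolutions=((256, 3),) * 4, dropout=0.1)
decoder = Decoder(dst_dict, embed_dim=256, out_embed_dim=256,
                  convolutions=((256, 3),) * 3, attention=True, dropout=0.1)
# The model records both dictionaries and asserts pad/eos/unk indices match.
model = FConvModel(encoder, decoder)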
fairseq/sequence_generator.py

@@ -16,7 +16,7 @@ from fairseq import utils

 class SequenceGenerator(object):
-    def __init__(self, models, dst_dict, beam_size=1, minlen=1, maxlen=200,
+    def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
                  stop_early=True, normalize_scores=True, len_penalty=1):
         """Generates translations of a given source sentence.

@@ -29,13 +29,14 @@ class SequenceGenerator(object):
             normalize_scores: Normalize scores by the length of the output.
         """
         self.models = models
-        self.dict = dst_dict
-        self.pad = dst_dict.pad()
-        self.eos = dst_dict.eos()
-        self.vocab_size = len(dst_dict)
+        self.pad = models[0].dst_dict.pad()
+        self.eos = models[0].dst_dict.eos()
+        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
+        assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
+        self.vocab_size = len(models[0].dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        self.maxlen = min(maxlen, *(m.decoder.max_positions() - self.pad - 2 for m in self.models))
+        self.maxlen = min(maxlen, *[m.decoder.max_positions() - self.pad - 2 for m in self.models])
         self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
         self.decoder_context = models[0].decoder.context_size()
         self.stop_early = stop_early

@@ -91,7 +92,7 @@ class SequenceGenerator(object):
         # the max beam size is the dictionary size - 1, since we never select pad
         beam_size = beam_size if beam_size is not None else self.beam_size
-        beam_size = min(beam_size, len(self.dict) - 1)
+        beam_size = min(beam_size, self.vocab_size - 1)
         encoder_outs = []
         for model in self.models:
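Since each model now exposes `dst_dict`, the generator reads `pad()`, `eos()`, and the vocabulary size from `models[0]` and asserts that the rest of the ensemble agrees, so the `dst_dict` constructor argument disappears. A sketch of the new call, assuming `models` is a list of loaded models:

from fairseq.sequence_generator import SequenceGenerator

translator = SequenceGenerator(
    models,  # each model carries its own dst_dict after this commit
    beam_size=5,
    stop_early=True,
    normalize_scores=True,
    len_penalty=1,
)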
fairseq/utils.py

@@ -14,7 +14,7 @@ import traceback

 from torch.autograd import Variable
 from torch.serialization import default_restore_location

-from fairseq import criterions, data, models
+from fairseq import criterions, models


 def parse_args_and_arch(parser):

@@ -24,13 +24,13 @@ def parse_args_and_arch(parser):
     return args


-def build_model(args, dataset):
+def build_model(args, src_dict, dst_dict):
     assert hasattr(models, args.model), 'Missing model type'
-    return getattr(models, args.model).build_model(args, dataset)
+    return getattr(models, args.model).build_model(args, src_dict, dst_dict)


-def build_criterion(args, dataset):
-    padding_idx = dataset.dst_dict.pad()
+def build_criterion(args, src_dict, dst_dict):
+    padding_idx = dst_dict.pad()
     if args.label_smoothing > 0:
         return criterions.LabelSmoothedCrossEntropyCriterion(args.label_smoothing, padding_idx)
     else:

@@ -117,7 +117,7 @@ def _upgrade_state_dict(state):
     return state


-def load_ensemble_for_inference(filenames, data_path, split):
+def load_ensemble_for_inference(filenames, src_dict, dst_dict):
     # load model architectures and weights
     states = []
     for filename in filenames:

@@ -126,19 +126,15 @@ def load_ensemble_for_inference(filenames, data_path, split):
         states.append(
             torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
         )
-
-    # load dataset
     args = states[0]['args']
-    dataset = data.load(data_path, [split], args.source_lang, args.target_lang)

-    # build models
+    # build ensemble
     ensemble = []
     for state in states:
-        model = build_model(args, dataset)
+        model = build_model(args, src_dict, dst_dict)
         model.load_state_dict(state['model'])
         ensemble.append(model)
+    return ensemble

-    return ensemble, dataset


 def prepare_sample(sample, volatile=False, cuda_device=None):
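For callers, the checkpoint-loading change condenses to the following before/after, where `paths` is a list of checkpoint filenames and `dataset` is loaded by the caller (both placeholders):

# before this commit: the helper also loaded (and returned) the dataset
# models, dataset = utils.load_ensemble_for_inference(paths, 'data-bin', 'test')

# after: the caller loads the dataset and passes in the dictionaries
models = utils.load_ensemble_for_inference(paths, dataset.src_dict, dataset.dst_dict)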
generate.py

@@ -11,7 +11,7 @@ import sys

 import torch
 from torch.autograd import Variable

-from fairseq import bleu, options, tokenizer, utils
+from fairseq import bleu, data, options, tokenizer, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.progress_bar import progress_bar
 from fairseq.sequence_generator import SequenceGenerator

@@ -37,9 +37,15 @@ def main():
         progress_bar.enabled = False
     use_cuda = torch.cuda.is_available() and not args.cpu

-    # Load model and dataset
+    # Load dataset
+    dataset = data.load_with_check(args.data, [args.gen_subset], args.source_lang, args.target_lang)
+    if args.source_lang is None or args.target_lang is None:
+        # record inferred languages in args
+        args.source_lang, args.target_lang = dataset.src, dataset.dst
+
+    # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models, dataset = utils.load_ensemble_for_inference(args.path, args.data, args.gen_subset)
+    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))

@@ -50,13 +56,13 @@ def main():
     # ignore too long sentences
     args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))

-    # Optimize model for generation
+    # Optimize ensemble for generation
     for model in models:
         model.make_generation_fast_(not args.no_beamable_mm)

     # Initialize generator
     translator = SequenceGenerator(
-        models, dataset.dst_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
+        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
         normalize_scores=(not args.unnormalized), len_penalty=args.lenpen
     )
     if use_cuda:
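This inverts the earlier dependency: generate.py now owns dataset loading, including inferring the language pair when it was not given on the command line, and checkpoint loading needs only the two dictionaries rather than a data path and split name.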
train.py

@@ -12,11 +12,10 @@ import os

 import torch
 import math

-from fairseq import bleu, data, options, utils
+from fairseq import data, options, utils
 from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 from fairseq.multiprocessing_trainer import MultiprocessingTrainer
 from fairseq.progress_bar import progress_bar
-from fairseq.sequence_generator import SequenceGenerator


 def main():

@@ -53,7 +52,7 @@ def main():
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
-    for split in dataset.splits:
+    for split in ['train', 'valid']:
         print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

     if not torch.cuda.is_available():

@@ -63,8 +62,8 @@ def main():
     print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))

     # Build model and criterion
-    model = utils.build_model(args, dataset)
-    criterion = utils.build_criterion(args, dataset)
+    model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
+    criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

     # Start multiprocessing
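Training follows the same pattern: build_model and build_criterion receive dataset.src_dict and dataset.dst_dict, and since train.py no longer constructs a SequenceGenerator, its bleu and SequenceGenerator imports are dropped.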