Commit 84b82dc6 authored by Myle Ott

Simplify deps of build_model to only depend on dict (instead of dataset)

parent d646a4a8
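Note: the refactor's surface area is small but consistent. Every constructor and helper that previously took a dataset object (or a raw num_embeddings/padding_idx pair) now takes dictionary objects and derives what it needs from them. A condensed before/after view, paraphrased from the hunks below:

# before                                           after
# FConvModel(encoder, decoder, padding_idx=1)  ->  FConvModel(encoder, decoder)
# Encoder(num_embeddings, ..., padding_idx=1)  ->  Encoder(dictionary, ...)
# build_model(args, dataset)                   ->  build_model(args, src_dict, dst_dict)
# load_ensemble_for_inference(
#     filenames, data_path, split)             ->  load_ensemble_for_inference(filenames, src_dict, dst_dict)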
--- a/fairseq/models/fconv.py
+++ b/fairseq/models/fconv.py
@@ -15,12 +15,18 @@ from fairseq.modules import BeamableMM, LinearizedConvolution
 class FConvModel(nn.Module):
-    def __init__(self, encoder, decoder, padding_idx=1):
+    def __init__(self, encoder, decoder):
         super(FConvModel, self).__init__()
         self.encoder = encoder
         self.decoder = decoder
+        self.src_dict = encoder.dictionary
+        self.dst_dict = decoder.dictionary
+        assert self.src_dict.pad() == self.dst_dict.pad()
+        assert self.src_dict.eos() == self.dst_dict.eos()
+        assert self.src_dict.unk() == self.dst_dict.unk()
         self.encoder.num_attention_layers = sum([layer is not None for layer in decoder.attention])
-        self.padding_idx = padding_idx
         self._is_generation_fast = False

     def forward(self, src_tokens, src_positions, input_tokens, input_positions):
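Note: FConvModel now recovers the padding index from the dictionaries its encoder and decoder carry, and the new asserts guarantee both sides agree on the special symbols. A minimal runnable sketch of that invariant; ToyDict is a hypothetical stand-in for fairseq's Dictionary class:

class ToyDict:
    """Hypothetical stand-in exposing the same special-symbol accessors."""
    def __init__(self, symbols, pad_idx=1, eos_idx=2, unk_idx=3):
        self.symbols = symbols
        self._pad, self._eos, self._unk = pad_idx, eos_idx, unk_idx
    def __len__(self):
        return len(self.symbols)
    def pad(self):
        return self._pad
    def eos(self):
        return self._eos
    def unk(self):
        return self._unk

src_dict = ToyDict(['<sos>', '<pad>', '</s>', '<unk>', 'hallo'])
dst_dict = ToyDict(['<sos>', '<pad>', '</s>', '<unk>', 'hello'])

# the checks FConvModel.__init__ now performs: once these hold, a single
# dictionary is enough to recover the padding index
assert src_dict.pad() == dst_dict.pad()
assert src_dict.eos() == dst_dict.eos()
assert src_dict.unk() == dst_dict.unk()
padding_idx = dst_dict.pad()
print(padding_idx)  # 1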
@@ -67,11 +73,15 @@ class FConvModel(nn.Module):
 class Encoder(nn.Module):
     """Convolutional encoder"""
-    def __init__(self, num_embeddings, embed_dim=512, max_positions=1024,
-                 convolutions=((512, 3),) * 20, dropout=0.1, padding_idx=1):
+    def __init__(self, dictionary, embed_dim=512, max_positions=1024,
+                 convolutions=((512, 3),) * 20, dropout=0.1):
         super(Encoder, self).__init__()
+        self.dictionary = dictionary
         self.dropout = dropout
         self.num_attention_layers = None

+        num_embeddings = len(dictionary)
+        padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
         self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
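Note: deriving num_embeddings and padding_idx from the stored dictionary keeps vocabulary size and padding index consistent by construction, instead of trusting two independently passed arguments. A sketch of the pattern using plain nn.Embedding (requires PyTorch; fconv's Embedding helper additionally initializes weights):

import torch.nn as nn

class ToyDict:
    """Illustrative dictionary exposing only what the module consumes."""
    def __init__(self, symbols):
        self.symbols = symbols
    def __len__(self):
        return len(self.symbols)
    def pad(self):
        return 1

dictionary = ToyDict(['<sos>', '<pad>', '</s>', '<unk>', 'a', 'b'])
num_embeddings = len(dictionary)    # vocabulary size from the dict
padding_idx = dictionary.pad()      # padding index from the dict
embed_tokens = nn.Embedding(num_embeddings, 512, padding_idx=padding_idx)
print(embed_tokens.weight.shape)    # torch.Size([6, 512])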
@@ -160,10 +170,11 @@ class AttentionLayer(nn.Module):
 class Decoder(nn.Module):
     """Convolutional decoder"""
-    def __init__(self, num_embeddings, embed_dim=512, out_embed_dim=256,
+    def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                  max_positions=1024, convolutions=((512, 3),) * 20,
-                 attention=True, dropout=0.1, padding_idx=1):
+                 attention=True, dropout=0.1):
         super(Decoder, self).__init__()
+        self.dictionary = dictionary
         self.dropout = dropout
         in_channels = convolutions[0][0]
@@ -171,8 +182,11 @@ class Decoder(nn.Module):
             # expand True into [True, True, ...] and do the same with False
             attention = [attention] * len(convolutions)

+        num_embeddings = len(dictionary)
+        padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
         self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
+
         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
         self.projections = nn.ModuleList()
         self.convolutions = nn.ModuleList()
@@ -503,24 +517,21 @@ def parse_arch(args):
     return args


-def build_model(args, dataset):
-    padding_idx = dataset.dst_dict.pad()
+def build_model(args, src_dict, dst_dict):
     encoder = Encoder(
-        len(dataset.src_dict),
+        src_dict,
         embed_dim=args.encoder_embed_dim,
         convolutions=eval(args.encoder_layers),
         dropout=args.dropout,
-        padding_idx=padding_idx,
         max_positions=args.max_positions,
     )
     decoder = Decoder(
-        len(dataset.dst_dict),
+        dst_dict,
         embed_dim=args.decoder_embed_dim,
         convolutions=eval(args.decoder_layers),
         out_embed_dim=args.decoder_out_embed_dim,
         attention=eval(args.decoder_attention),
         dropout=args.dropout,
-        padding_idx=padding_idx,
         max_positions=args.max_positions,
     )
-    return FConvModel(encoder, decoder, padding_idx)
+    return FConvModel(encoder, decoder)
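Note: with the new signature, building an FConvModel needs no dataset object at all, only args plus the two dictionaries. An illustrative call under assumed hyperparameters (the values below are examples, not the repository defaults):

from argparse import Namespace

args = Namespace(
    encoder_embed_dim=256,
    decoder_embed_dim=256,
    decoder_out_embed_dim=256,
    encoder_layers='[(256, 3)] * 4',   # eval()'d into a list of (channels, kernel) specs
    decoder_layers='[(256, 3)] * 3',
    decoder_attention='True',
    dropout=0.1,
    max_positions=1024,
)
# model = build_model(args, src_dict, dst_dict)  # no dataset threaded through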
--- a/fairseq/sequence_generator.py
+++ b/fairseq/sequence_generator.py
@@ -16,7 +16,7 @@ from fairseq import utils
 class SequenceGenerator(object):
-    def __init__(self, models, dst_dict, beam_size=1, minlen=1, maxlen=200,
+    def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
                  stop_early=True, normalize_scores=True, len_penalty=1):
         """Generates translations of a given source sentence.
@@ -29,13 +29,14 @@ class SequenceGenerator(object):
             normalize_scores: Normalize scores by the length of the output.
         """
         self.models = models
-        self.dict = dst_dict
-        self.pad = dst_dict.pad()
-        self.eos = dst_dict.eos()
-        self.vocab_size = len(dst_dict)
+        self.pad = models[0].dst_dict.pad()
+        self.eos = models[0].dst_dict.eos()
+        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
+        assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
+        self.vocab_size = len(models[0].dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        self.maxlen = min(maxlen, *(m.decoder.max_positions() - self.pad - 2 for m in self.models))
+        self.maxlen = min(maxlen, *[m.decoder.max_positions() - self.pad - 2 for m in self.models])
         self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
         self.decoder_context = models[0].decoder.context_size()
         self.stop_early = stop_early
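Note: SequenceGenerator no longer takes dst_dict explicitly; it reads the special indices from the first model and asserts that every other ensemble member agrees. A self-contained sketch of those checks plus the maxlen cap (the Toy* classes are illustrative stand-ins, not fairseq's):

class ToyDict:
    def pad(self): return 1
    def eos(self): return 2
    def __len__(self): return 40000

class ToyDecoder:
    def max_positions(self): return 1024

class ToyModel:
    def __init__(self):
        self.dst_dict = ToyDict()
        self.decoder = ToyDecoder()

models = [ToyModel(), ToyModel()]
pad = models[0].dst_dict.pad()
eos = models[0].dst_dict.eos()
# every ensemble member must share the special symbols read from models[0]
assert all(m.dst_dict.pad() == pad for m in models[1:])
assert all(m.dst_dict.eos() == eos for m in models[1:])
# maxlen is capped by the shortest decoder, minus room for pad/eos bookkeeping
maxlen = min(200, *[m.decoder.max_positions() - pad - 2 for m in models])
print(maxlen)  # 200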
@@ -91,7 +92,7 @@ class SequenceGenerator(object):
         # the max beam size is the dictionary size - 1, since we never select pad
         beam_size = beam_size if beam_size is not None else self.beam_size
-        beam_size = min(beam_size, len(self.dict) - 1)
+        beam_size = min(beam_size, self.vocab_size - 1)
         encoder_outs = []
         for model in self.models:
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -14,7 +14,7 @@ import traceback
 from torch.autograd import Variable
 from torch.serialization import default_restore_location

-from fairseq import criterions, data, models
+from fairseq import criterions, models


 def parse_args_and_arch(parser):
@@ -24,13 +24,13 @@ def parse_args_and_arch(parser):
     return args


-def build_model(args, dataset):
+def build_model(args, src_dict, dst_dict):
     assert hasattr(models, args.model), 'Missing model type'
-    return getattr(models, args.model).build_model(args, dataset)
+    return getattr(models, args.model).build_model(args, src_dict, dst_dict)


-def build_criterion(args, dataset):
-    padding_idx = dataset.dst_dict.pad()
+def build_criterion(args, src_dict, dst_dict):
+    padding_idx = dst_dict.pad()
     if args.label_smoothing > 0:
         return criterions.LabelSmoothedCrossEntropyCriterion(args.label_smoothing, padding_idx)
     else:
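Note: build_model here is pure dispatch: it looks the architecture up by name on the models package and forwards the dictionaries unchanged. A runnable toy of the getattr pattern (the module and names below are fabricated for illustration):

import types

# fabricated stand-ins for the fairseq.models package and one architecture
fconv = types.ModuleType('fconv')
fconv.build_model = lambda args, src_dict, dst_dict: ('fconv model', src_dict, dst_dict)
models = types.SimpleNamespace(fconv=fconv)

def build_model(args, src_dict, dst_dict):
    assert hasattr(models, args.model), 'Missing model type'
    return getattr(models, args.model).build_model(args, src_dict, dst_dict)

args = types.SimpleNamespace(model='fconv')
print(build_model(args, 'SRC_DICT', 'DST_DICT'))  # ('fconv model', 'SRC_DICT', 'DST_DICT')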
@@ -117,7 +117,7 @@ def _upgrade_state_dict(state):
     return state


-def load_ensemble_for_inference(filenames, data_path, split):
+def load_ensemble_for_inference(filenames, src_dict, dst_dict):
     # load model architectures and weights
     states = []
     for filename in filenames:
@@ -126,19 +126,15 @@ def load_ensemble_for_inference(filenames, data_path, split):
         states.append(
             torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
         )
-    # load dataset
     args = states[0]['args']
-    dataset = data.load(data_path, [split], args.source_lang, args.target_lang)

-    # build models
+    # build ensemble
     ensemble = []
     for state in states:
-        model = build_model(args, dataset)
+        model = build_model(args, src_dict, dst_dict)
         model.load_state_dict(state['model'])
         ensemble.append(model)

-    return ensemble, dataset
+    return ensemble


 def prepare_sample(sample, volatile=False, cuda_device=None):
--- a/generate.py
+++ b/generate.py
@@ -11,7 +11,7 @@ import sys

 import torch
 from torch.autograd import Variable

-from fairseq import bleu, options, tokenizer, utils
+from fairseq import bleu, data, options, tokenizer, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.progress_bar import progress_bar
 from fairseq.sequence_generator import SequenceGenerator
@@ -37,9 +37,15 @@ def main():
         progress_bar.enabled = False
     use_cuda = torch.cuda.is_available() and not args.cpu

-    # Load model and dataset
+    # Load dataset
+    dataset = data.load_with_check(args.data, [args.gen_subset], args.source_lang, args.target_lang)
+    if args.source_lang is None or args.target_lang is None:
+        # record inferred languages in args
+        args.source_lang, args.target_lang = dataset.src, dataset.dst
+
+    # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models, dataset = utils.load_ensemble_for_inference(args.path, args.data, args.gen_subset)
+    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)

     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
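Note: generate.py now owns dataset loading and, when --source-lang/--target-lang were omitted, writes the languages inferred from the data directory back into args before building the ensemble. The fallback logic in isolation (toy objects, runnable as-is):

from argparse import Namespace

# toy dataset descriptor; fairseq's real object also carries the dictionaries
dataset = Namespace(src='de', dst='en')
args = Namespace(source_lang=None, target_lang=None)

# the new generate.py behavior: fall back to languages inferred from the data
if args.source_lang is None or args.target_lang is None:
    args.source_lang, args.target_lang = dataset.src, dataset.dst
print(args.source_lang, args.target_lang)  # de en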
@@ -50,13 +56,13 @@ def main():
     # ignore too long sentences
     args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))

-    # Optimize model for generation
+    # Optimize ensemble for generation
     for model in models:
         model.make_generation_fast_(not args.no_beamable_mm)

     # Initialize generator
     translator = SequenceGenerator(
-        models, dataset.dst_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
+        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
         normalize_scores=(not args.unnormalized), len_penalty=args.lenpen
     )
     if use_cuda:
--- a/train.py
+++ b/train.py
@@ -12,11 +12,10 @@ import os

 import torch
 import math

-from fairseq import bleu, data, options, utils
+from fairseq import data, options, utils
 from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 from fairseq.multiprocessing_trainer import MultiprocessingTrainer
 from fairseq.progress_bar import progress_bar
-from fairseq.sequence_generator import SequenceGenerator


 def main():
@@ -53,7 +52,7 @@ def main():
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
-    for split in dataset.splits:
+    for split in ['train', 'valid']:
         print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

     if not torch.cuda.is_available():
@@ -63,8 +62,8 @@ def main():
     print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))

     # Build model and criterion
-    model = utils.build_model(args, dataset)
-    criterion = utils.build_criterion(args, dataset)
+    model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
+    criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

     # Start multiprocessing
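Note: the training-side call sites change mechanically, and one payoff is testability: model and criterion construction can now be exercised with in-memory dictionaries, no binarized data directory required. A hedged sketch (ToyDict is illustrative; the commented calls mirror the hunk above):

class ToyDict:
    """Illustrative dictionary exposing just what the builders consume."""
    def __len__(self):
        return 10000
    def pad(self):
        return 1

src_dict, dst_dict = ToyDict(), ToyDict()
padding_idx = dst_dict.pad()  # what build_criterion derives internally
# model = utils.build_model(args, src_dict, dst_dict)
# criterion = utils.build_criterion(args, src_dict, dst_dict)
print(len(dst_dict), padding_idx)  # 10000 1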