"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "4879a234c4bd3f2bbc99d9b09c44bd99fc337679"
Unverified commit d3795d6c authored by Myle Ott, committed by GitHub

Merge internal changes (#136)

Changes:
- 7d19e36: Add `--sampling` flag to generate.py to sample instead of doing beam search
- c777340: Add `scripts/average_checkpoints.py` to average multiple checkpoints into a combined model
- 3ea882c: Add `--max-update` option to train.py to stop training after a given number of updates
- small bugfixes for distributed training, LSTM, and the inverse square root LR scheduler (see the hedged sketch below)
parent 48836525
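
As a quick illustration of the scheduler fix listed above, here is a minimal, self-contained sketch. It is not part of this commit: the standalone function, its default values, and the definitions of the warmup slope and decay factor are assumptions chosen so the schedule is continuous at the end of warmup; only the two update formulas come from the diff below. The loop bound stands in for the new `--max-update` stopping criterion.

```python
# Illustrative sketch only -- mirrors the two formulas in the diff below, but the
# default values and this standalone function are assumptions, not fairseq code.

def inverse_sqrt_lr(num_updates, warmup_init_lr=1e-7, peak_lr=5e-4, warmup_updates=4000):
    """Learning rate after `num_updates`, following the fixed step_update() logic."""
    lr_step = (peak_lr - warmup_init_lr) / warmup_updates   # assumed linear warmup slope
    decay_factor = peak_lr * warmup_updates ** 0.5          # assumed; matches lr at end of warmup
    if num_updates < warmup_updates:
        # fixed line in the diff: lr = warmup_init_lr + num_updates * lr_step
        return warmup_init_lr + num_updates * lr_step
    # after warmup: lr = decay_factor * num_updates ** -0.5
    return decay_factor * num_updates ** -0.5


if __name__ == '__main__':
    max_update = 8000  # analogous to the new --max-update stopping criterion
    for num_updates in range(0, max_update + 1, 2000):
        print(num_updates, inverse_sqrt_lr(num_updates))
```
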
@@ -8,9 +8,11 @@
 import math
 import torch.nn.functional as F
-from . import FairseqCriterion, register_criterion
 from fairseq import utils
+from . import FairseqCriterion, register_criterion
 @register_criterion('cross_entropy')
 class CrossEntropyCriterion(FairseqCriterion):
@@ -28,7 +30,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         net_output = model(**sample['net_input'])
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
         lprobs = lprobs.view(-1, lprobs.size(-1))
-        target = sample['target'].view(-1)
+        target = model.get_targets(sample, net_output).view(-1)
         loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
                           reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
@@ -47,6 +49,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         agg_output = {
             'loss': loss_sum / sample_size / math.log(2),
+            'sample_size': sample_size,
         }
         if sample_size != ntokens:
             agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
...
@@ -6,8 +6,6 @@
 # can be found in the PATENTS file in the same directory.
 import math
-import torch
-import torch.nn.functional as F
 from fairseq import utils
@@ -37,7 +35,8 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         """
         net_output = model(**sample['net_input'])
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
-        target = sample['target'].unsqueeze(-1)
+        lprobs = lprobs.view(-1, lprobs.size(-1))
+        target = model.get_targets(sample, net_output).view(-1, 1)
         non_pad_mask = target.ne(self.padding_idx)
         nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
         smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
@@ -64,4 +63,5 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         return {
             'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
             'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2),
+            'sample_size': sample_size,
         }
@@ -198,7 +198,7 @@ class LanguagePairDataset(torch.utils.data.Dataset):
     def __getitem__(self, i):
         # subtract 1 for 0-based indexing
         source = self.src[i].long() - 1
-        res = { 'id': i, 'source': source }
+        res = {'id': i, 'source': source}
         if self.dst:
             res['target'] = self.dst[i].long() - 1
@@ -283,9 +283,9 @@ def _valid_size(src_size, dst_size, max_positions):
         max_src_positions, max_dst_positions = max_positions, max_positions
     else:
         max_src_positions, max_dst_positions = max_positions
-    if src_size < 2 or src_size > max_src_positions:
+    if src_size < 1 or src_size > max_src_positions:
         return False
-    if dst_size is not None and (dst_size < 2 or dst_size > max_dst_positions):
+    if dst_size is not None and (dst_size < 1 or dst_size > max_dst_positions):
         return False
     return True
...
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 import math
+import os
 import torch
@@ -23,6 +24,9 @@ class Dictionary(object):
         self.unk_index = self.add_symbol(unk)
         self.nspecial = len(self.symbols)
+    def __eq__(self, other):
+        return self.indices == other.indices
     def __getitem__(self, idx):
         if idx < len(self.symbols):
             return self.symbols[idx]
@@ -97,8 +101,8 @@ class Dictionary(object):
         """Helper to get index of unk symbol"""
         return self.unk_index
-    @staticmethod
-    def load(f):
+    @classmethod
+    def load(cls, f):
         """Loads the dictionary from a text file with the format:
         ```
@@ -111,14 +115,14 @@ class Dictionary(object):
         if isinstance(f, str):
             try:
                 with open(f, 'r', encoding='utf-8') as fd:
-                    return Dictionary.load(fd)
+                    return cls.load(fd)
             except FileNotFoundError as fnfe:
                 raise fnfe
             except Exception:
                 raise Exception("Incorrect encoding detected in {}, please "
                                 "rebuild the dataset".format(f))
-        d = Dictionary()
+        d = cls()
         for line in f.readlines():
             idx = line.rfind(' ')
             word = line[:idx]
@@ -131,6 +135,7 @@ class Dictionary(object):
     def save(self, f, threshold=3, nwords=-1):
         """Stores dictionary into a text file"""
         if isinstance(f, str):
+            os.makedirs(os.path.dirname(f), exist_ok=True)
             with open(f, 'w', encoding='utf-8') as fd:
                 return self.save(fd, threshold, nwords)
         cnt = 0
...
@@ -10,6 +10,12 @@ import pickle
 import torch.distributed
+from fairseq import utils
+def is_master(args):
+    return args.distributed_rank == 0
 def distributed_init(args):
     if args.distributed_world_size == 1:
@@ -27,7 +33,7 @@ def distributed_init(args):
         world_size=args.distributed_world_size)
     args.distributed_rank = torch.distributed.get_rank()
-    if args.distributed_rank != 0:
+    if not is_master(args):
         suppress_output()
     return args.distributed_rank
@@ -104,7 +110,7 @@ def all_gather_list(data, max_size=4096):
     world_size = torch.distributed.get_world_size()
     if not hasattr(all_gather_list, '_in_buffer') or \
             max_size != all_gather_list._in_buffer.size():
-        all_gather_list._in_buffer = torch.ByteTensor(max_size)
+        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
         all_gather_list._out_buffers = [
             torch.cuda.ByteTensor(max_size)
             for i in range(world_size)
@@ -113,18 +119,21 @@ def all_gather_list(data, max_size=4096):
     out_buffers = all_gather_list._out_buffers
     enc = pickle.dumps(data)
-    if len(enc) >= max_size:
-        raise ValueError('encoded data exceeds max_size: {}'.format(len(enc)))
-    in_buffer[0] = len(enc)
-    in_buffer[1:len(enc)+1] = torch.ByteTensor(list(enc))
+    enc_size = len(enc)
+    if enc_size + 2 > max_size:
+        raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
+    assert max_size < 255*256
+    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
+    in_buffer[1] = enc_size % 255
+    in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
     torch.distributed.all_gather(out_buffers, in_buffer.cuda())
     result = []
     for i in range(world_size):
         out_buffer = out_buffers[i]
-        size = out_buffer[0]
+        size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
         result.append(
-            pickle.loads(bytes(out_buffer[1:size+1].tolist()))
+            pickle.loads(bytes(out_buffer[2:size+2].tolist()))
         )
     return result
@@ -107,10 +107,12 @@ class IndexedRawTextDataset(IndexedDataset):
     """Takes a text file as input and binarizes it in memory at instantiation.
     Original lines are also kept in memory"""
-    def __init__(self, path, dictionary):
+    def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
         self.tokens_list = []
         self.lines = []
         self.sizes = []
+        self.append_eos = append_eos
+        self.reverse_order = reverse_order
         self.read_data(path, dictionary)
         self.size = len(self.tokens_list)
@@ -118,8 +120,10 @@ class IndexedRawTextDataset(IndexedDataset):
         with open(path, 'r') as f:
             for line in f:
                 self.lines.append(line.strip('\n'))
-                # +1 for Lua compatibility
-                tokens = Tokenizer.tokenize(line, dictionary, add_if_not_exist=False) + 1
+                tokens = Tokenizer.tokenize(
+                    line, dictionary, add_if_not_exist=False,
+                    append_eos=self.append_eos, reverse_order=self.reverse_order,
+                ) + 1  # +1 for Lua compatibility
                 self.tokens_list.append(tokens)
                 self.sizes.append(len(tokens))
         self.sizes = np.array(self.sizes)
...
@@ -21,12 +21,11 @@ class FairseqDecoder(nn.Module):
     def get_normalized_probs(self, net_output, log_probs):
         """Get normalized probabilities (or log probs) from a net's output."""
-        vocab = net_output.size(-1)
-        net_output1 = net_output.view(-1, vocab)
+        logits = net_output[0]
         if log_probs:
-            return F.log_softmax(net_output1, dim=1).view_as(net_output)
+            return F.log_softmax(logits, dim=-1)
         else:
-            return F.softmax(net_output1, dim=1).view_as(net_output)
+            return F.softmax(logits, dim=-1)
     def max_positions(self):
         """Maximum input length supported by the decoder."""
...
@@ -41,13 +41,17 @@ class FairseqModel(nn.Module):
     def forward(self, src_tokens, src_lengths, prev_output_tokens):
         encoder_out = self.encoder(src_tokens, src_lengths)
-        decoder_out, _ = self.decoder(prev_output_tokens, encoder_out)
+        decoder_out = self.decoder(prev_output_tokens, encoder_out)
         return decoder_out
     def get_normalized_probs(self, net_output, log_probs):
         """Get normalized probabilities (or log probs) from a net's output."""
         return self.decoder.get_normalized_probs(net_output, log_probs)
+    def get_targets(self, sample, net_output):
+        """Get targets from either the sample or the net's output."""
+        return sample['target']
     def max_encoder_positions(self):
         """Maximum input length supported by the encoder."""
         return self.encoder.max_positions()
...
@@ -48,6 +48,11 @@ class FConvModel(FairseqModel):
     @classmethod
     def build_model(cls, args, src_dict, dst_dict):
         """Build a new model instance."""
+        if not hasattr(args, 'max_source_positions'):
+            args.max_source_positions = args.max_positions
+            args.max_target_positions = args.max_positions
+        if not hasattr(args, 'share_input_output_embed'):
+            args.share_input_output_embed = False
         encoder = FConvEncoder(
             src_dict,
             embed_dim=args.encoder_embed_dim,
...
@@ -65,7 +65,7 @@ class LSTMModel(FairseqModel):
             embed_dim=args.decoder_embed_dim,
             out_embed_dim=args.decoder_out_embed_dim,
             num_layers=args.decoder_layers,
-            attention=bool(args.decoder_attention),
+            attention=bool(eval(args.decoder_attention)),
             dropout_in=args.decoder_dropout_in,
             dropout_out=args.decoder_dropout_out,
         )
@@ -178,7 +178,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else embed_dim, embed_dim)
             for layer in range(num_layers)
         ])
-        self.attention = AttentionLayer(encoder_embed_dim, embed_dim)
+        self.attention = AttentionLayer(encoder_embed_dim, embed_dim) if attention else None
         if embed_dim != out_embed_dim:
             self.additional_fc = Linear(embed_dim, out_embed_dim)
         self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
@@ -229,7 +229,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
                 prev_cells[i] = cell
             # apply attention using the last layer's hidden state
-            out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
+            if self.attention is not None:
+                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
+            else:
+                out = hidden
             out = F.dropout(out, p=self.dropout_out, training=self.training)
             # input feeding
...
@@ -8,8 +8,6 @@
 import torch
 from torch.nn.modules.utils import _single
-from fairseq import utils
 class ConvTBC(torch.nn.Module):
     """1D convolution over an input of shape (time x batch x channel)
...
@@ -5,10 +5,11 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-import torch
 from torch.autograd import Variable
 import torch.nn as nn
+from fairseq import utils
 class LearnedPositionalEmbedding(nn.Embedding):
     """This module learns positional embeddings up to a fixed maximum size.
@@ -25,27 +26,11 @@ class LearnedPositionalEmbedding(nn.Embedding):
         """Input is expected to be of size [bsz x seqlen]."""
         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
-            positions = Variable(
-                input.data.new(1, 1).fill_(self.padding_idx + input.size(1)))
+            positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
         else:
-            positions = Variable(self.make_positions(input.data))
-        return super().forward(positions)
+            positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
+        return super().forward(Variable(positions))
     def max_positions(self):
         """Maximum number of supported positions."""
         return self.num_embeddings - self.padding_idx - 1
-    def make_positions(self, input):
-        """Replace non-padding symbols with their position numbers."""
-        if not hasattr(self, 'range_buf'):
-            self.range_buf = input.new()
-        seqlen = input.size(1)
-        if self.range_buf.numel() < seqlen:
-            # offset positions by the padding index
-            torch.arange(self.padding_idx + 1, self.padding_idx + 1 + seqlen,
-                         out=self.range_buf)
-        mask = input.ne(self.padding_idx)
-        positions = self.range_buf[:seqlen].expand_as(input)
-        if self.left_pad:
-            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
-        return input.clone().masked_scatter_(mask, positions[mask])
@@ -16,6 +16,7 @@ OPTIMIZER_CLASS_NAMES = set()
 def build_optimizer(args, params):
+    params = filter(lambda p: p.requires_grad, params)
     return OPTIMIZER_REGISTRY[args.optimizer](args, params)
...
@@ -23,6 +23,8 @@ class FairseqAdam(FairseqOptimizer):
         """Add optimizer-specific arguments to the parser."""
         parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
                             help='betas for Adam optimizer')
+        parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
+                            help='epsilon for Adam optimizer')
     @property
     def optimizer_config(self):
@@ -35,6 +37,7 @@ class FairseqAdam(FairseqOptimizer):
         return {
             'lr': self.args.lr[0],
             'betas': eval(self.args.adam_betas),
+            'eps': self.args.adam_eps,
             'weight_decay': self.args.weight_decay,
         }
...
@@ -69,7 +69,7 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
     def step_update(self, num_updates):
         """Update the learning rate after each update."""
         if num_updates < self.args.warmup_updates:
-            self.lr += self.lr_step
+            self.lr = self.args.warmup_init_lr + num_updates*self.lr_step
         else:
             self.lr = self.decay_factor * num_updates**-0.5
         self.optimizer.set_lr(self.lr)
...
@@ -32,11 +32,12 @@ def get_generation_parser():
     return parser
-def parse_args_and_arch(parser, _args=None):
+def parse_args_and_arch(parser, input_args=None):
     # The parser doesn't know about model/criterion/optimizer-specific args, so
     # we parse twice. First we parse the model/criterion/optimizer, then we
     # parse a second time after adding the *-specific arguments.
-    args, _ = parser.parse_known_args(_args)
+    # If input_args is given, we will parse those args instead of sys.argv.
+    args, _ = parser.parse_known_args(input_args)
     # Add model-specific args to parser.
     model_specific_group = parser.add_argument_group(
@@ -53,7 +54,7 @@ def parse_args_and_arch(parser, _args=None):
     LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
     # Parse a second time.
-    args = parser.parse_args(_args)
+    args = parser.parse_args(input_args)
     # Post-process args.
     args.lr = list(map(float, args.lr.split(',')))
@@ -140,6 +141,8 @@ def add_optimization_args(parser):
     group = parser.add_argument_group('Optimization')
     group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
                        help='force stop training at specified epoch')
+    group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
+                       help='force stop training at specified update')
     group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
                        help='clip threshold of gradients')
     group.add_argument('--sentence-avg', action='store_true',
@@ -188,12 +191,14 @@ def add_checkpoint_args(parser):
                        help='don\'t save models and checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
                        help='only store last and best checkpoints')
+    group.add_argument('--validate-interval', type=int, default=1, metavar='N',
+                       help='validate every N epochs')
     return group
 def add_generation_args(parser):
     group = parser.add_argument_group('Generation')
-    group.add_argument('--path', metavar='FILE', required=True, action='append',
+    group.add_argument('--path', metavar='FILE', action='append',
                        help='path(s) to model file(s)')
     group.add_argument('--beam', default=5, type=int, metavar='N',
                        help='beam size')
@@ -228,6 +233,8 @@ def add_generation_args(parser):
                        help='just score the reference translation')
     group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                        help=('initialize generation by target prefix of given length'))
+    group.add_argument('--sampling', action='store_true',
+                       help='sample hypotheses instead of using beam search')
     return group
...
@@ -107,9 +107,9 @@ class json_progress_bar(progress_bar):
             yield obj
             if self.stats is not None and i > 0 and \
                     self.log_interval is not None and i % self.log_interval == 0:
-                update = self.epoch + float(i / size) if self.epoch is not None else None
+                update = self.epoch - 1 + float(i / size) if self.epoch is not None else None
                 stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
-                print('sweep_log: ' + json.dumps(stats), flush=True)
+                print(json.dumps(stats), flush=True)
     def log(self, stats):
         """Log intermediate stats according to log_interval."""
...
@@ -15,7 +15,7 @@ from fairseq.models import FairseqIncrementalDecoder
 class SequenceGenerator(object):
     def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
                  stop_early=True, normalize_scores=True, len_penalty=1,
-                 unk_penalty=0, retain_dropout=False):
+                 unk_penalty=0, retain_dropout=False, sampling=False):
         """Generates translations of a given source sentence.
         Args:
@@ -36,7 +36,7 @@ class SequenceGenerator(object):
         self.vocab_size = len(models[0].dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        max_decoder_len = min([m.max_decoder_positions() for m in self.models])
+        max_decoder_len = min(m.max_decoder_positions() for m in self.models)
         max_decoder_len -= 1  # we define maxlen not including the EOS marker
         self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
         self.stop_early = stop_early
@@ -44,6 +44,7 @@ class SequenceGenerator(object):
         self.len_penalty = len_penalty
         self.unk_penalty = unk_penalty
         self.retain_dropout = retain_dropout
+        self.sampling = sampling
     def cuda(self):
         for model in self.models:
@@ -78,7 +79,7 @@ class SequenceGenerator(object):
                 prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
             )
             if timer is not None:
-                timer.stop(sum([len(h[0]['tokens']) for h in hypos]))
+                timer.stop(sum(len(h[0]['tokens']) for h in hypos))
             for i, id in enumerate(s['id'].data):
                 src = input['src_tokens'].data[i, :]
                 # remove padding from ref
@@ -255,9 +256,10 @@ class SequenceGenerator(object):
                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
                scores = scores.type_as(probs)
                scores_buf = scores_buf.type_as(probs)
-            else:
+            elif not self.sampling:
                # make probs contain cumulative scores for each hypothesis
                probs.add_(scores[:, step-1].view(-1, 1))
            probs[:, self.pad] = -math.inf  # never select pad
            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty
@@ -278,6 +280,31 @@ class SequenceGenerator(object):
                ).expand(-1, cand_size)
                cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
                cand_beams.resize_as_(cand_indices).fill_(0)
+            elif self.sampling:
+                assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
+                exp_probs = probs.exp_().view(-1, self.vocab_size)
+                if step == 0:
+                    # we exclude the first two vocab items, one of which is pad
+                    torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
+                    cand_indices.add_(2)
+                else:
+                    torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
+                    cand_indices.add_(2)
+                torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
+                cand_scores.log_()
+                cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
+                cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
+                if step == 0:
+                    cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
+                else:
+                    cand_beams = torch.arange(0, beam_size).repeat(bsz, 2).type_as(cand_indices)
+                    # make scores cumulative
+                    cand_scores.add_(
+                        torch.gather(
+                            scores[:, step-1].view(bsz, beam_size), dim=1,
+                            index=cand_beams,
+                        )
+                    )
            else:
                # take the best 2 x beam_size predictions. We'll choose the first
                # beam_size of these which don't predict eos to continue with.
@@ -411,8 +438,13 @@ class SequenceGenerator(object):
            avg_attn = None
            for model, encoder_out in zip(self.models, encoder_outs):
                with utils.maybe_no_grad():
-                    decoder_out, attn = model.decoder(tokens, encoder_out, incremental_states[model])
-                    probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
+                    if incremental_states[model] is not None:
+                        decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
+                    else:
+                        decoder_out = list(model.decoder(tokens, encoder_out))
+                    decoder_out[0] = decoder_out[0][:, -1, :]
+                    attn = decoder_out[1]
+                    probs = model.get_normalized_probs(decoder_out, log_probs=False).data
                if avg_probs is None:
                    avg_probs = probs
                else:
...
@@ -63,10 +63,11 @@ class SequenceScorer(object):
                     net_input['src_tokens'],
                     net_input['src_lengths'],
                 )
-                decoder_out, attn = model.decoder(
+                decoder_out = model.decoder(
                     net_input['prev_output_tokens'],
                     encoder_out,
                 )
+                attn = decoder_out[1]
                 probs = model.get_normalized_probs(decoder_out, log_probs=False).data
                 if avg_probs is None:
                     avg_probs = probs
...
@@ -58,10 +58,14 @@ class Tokenizer:
         return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)}
     @staticmethod
-    def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True, consumer=None):
+    def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
+                 consumer=None, append_eos=True, reverse_order=False):
         words = tokenize(line)
+        if reverse_order:
+            words = list(reversed(words))
         nwords = len(words)
-        ids = torch.IntTensor(nwords + 1)
+        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
         for i, word in enumerate(words):
             if add_if_not_exist:
                 idx = dict.add_symbol(word)
@@ -70,5 +74,6 @@ class Tokenizer:
             if consumer is not None:
                 consumer(word, idx)
             ids[i] = idx
-        ids[nwords] = dict.eos_index
+        if append_eos:
+            ids[nwords] = dict.eos_index
         return ids
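
Side note on the all_gather_list() change above: the new two-byte header stores the pickled payload length as a base-255 pair, which is why the code asserts `max_size < 255*256`. Below is a small, self-contained round-trip check of that scheme; it is an illustration with hypothetical helper names, not part of this commit or of fairseq.

```python
# Round-trip check for the two-byte length header used by all_gather_list() above.
# Standalone illustration; encode_size/decode_size are hypothetical helpers.

def encode_size(enc_size):
    # header bytes, as written into in_buffer[0] and in_buffer[1]
    assert enc_size < 255 * 256, 'encoding only works for sizes below ~65k'
    return enc_size // 255, enc_size % 255

def decode_size(hi, lo):
    # mirrors the read side: size = (255 * out_buffer[0]) + out_buffer[1]
    return 255 * hi + lo

for n in (0, 254, 255, 1000, 65279):
    assert decode_size(*encode_size(n)) == n
print('length header round-trips correctly')
```
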