Commit 753935ef authored by Myle Ott

Merge internal changes

parent c7c567a7
@@ -34,6 +34,17 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(**sample['net_input'])
+        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        logging_output = {
+            'loss': utils.item(loss.data) if reduce else loss.data,
+            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
+            'ntokens': sample['ntokens'],
+            'sample_size': sample_size,
+        }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
         lprobs = lprobs.view(-1, lprobs.size(-1))
         target = model.get_targets(sample, net_output).view(-1, 1)
@@ -45,15 +56,8 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         smooth_loss = smooth_loss.sum()
         eps_i = self.eps / lprobs.size(-1)
         loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
-
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
-        logging_output = {
-            'loss': utils.item(loss.data) if reduce else loss.data,
-            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
-            'ntokens': sample['ntokens'],
-            'sample_size': sample_size,
-        }
-        return loss, sample_size, logging_output
+        return loss, nll_loss
 
     @staticmethod
     def aggregate_logging_outputs(logging_outputs):
...
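Note: the loss math is unchanged by this refactor; compute_loss now returns the (loss, nll_loss) pair that forward logs. Below is a minimal standalone sketch of the label-smoothing arithmetic from this hunk; the padding-token masking handled in the elided lines is omitted, and all names are illustrative.

    import torch

    def label_smoothed_nll_loss(lprobs, target, eps):
        # lprobs: (N, vocab) log-probabilities; target: (N, 1) gold indices.
        nll_loss = -lprobs.gather(dim=-1, index=target).sum()  # true-class NLL
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True).sum()  # uniform-prior term
        eps_i = eps / lprobs.size(-1)                          # smoothing mass per class
        return (1. - eps) * nll_loss + eps_i * smooth_loss, nll_loss

    lprobs = torch.log_softmax(torch.randn(8, 100), dim=-1)
    target = torch.randint(100, (8, 1))
    loss, nll = label_smoothed_nll_loss(lprobs, target, eps=0.1)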
@@ -100,6 +100,14 @@ class BaseFairseqModel(nn.Module):
         self.eval()
         self.train = train
 
+    def prepare_for_onnx_export_(self, **kwargs):
+        """Make model exportable via ONNX trace."""
+        def apply_prepare_for_onnx_export_(module):
+            if module != self and hasattr(module, 'prepare_for_onnx_export_'):
+                module.prepare_for_onnx_export_(**kwargs)
+
+        self.apply(apply_prepare_for_onnx_export_)
+
 
 class FairseqModel(BaseFairseqModel):
     """Base class for encoder-decoder models."""
...
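Note: prepare_for_onnx_export_ relies on nn.Module.apply, which visits the module and every submodule; the module != self guard keeps the root from re-invoking itself. A small sketch of the same dispatch pattern, where the Toy module is hypothetical:

    import torch.nn as nn

    class Toy(nn.Module):
        # Hypothetical submodule that opts in to ONNX-friendly code paths.
        def __init__(self):
            super().__init__()
            self.onnx_trace = False

        def prepare_for_onnx_export_(self, **kwargs):
            self.onnx_trace = True

    model = nn.Sequential(Toy(), nn.Linear(4, 4))

    def apply_prepare_for_onnx_export_(module):
        # Skip the root module so the call does not recurse into itself.
        if module is not model and hasattr(module, 'prepare_for_onnx_export_'):
            module.prepare_for_onnx_export_()

    model.apply(apply_prepare_for_onnx_export_)
    assert model[0].onnx_trace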
@@ -27,7 +27,7 @@ class ConvTBC(torch.nn.Module):
         self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
 
     def forward(self, input):
-        return input.contiguous().conv_tbc(self.weight, self.bias, self.padding[0])
+        return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0])
 
     def __repr__(self):
         s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
...
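Note: the new form calls the functional torch.conv_tbc directly rather than the removed Tensor method. TBC means time x batch x channel layout: input (T, B, C_in), weight (K, C_in, C_out), bias (C_out). A quick shape check with arbitrary sizes:

    import torch

    T, B, C_in, C_out, K = 10, 2, 8, 16, 3
    x = torch.randn(T, B, C_in)            # time x batch x channel layout
    weight = torch.randn(K, C_in, C_out)   # kernel_width x in_channels x out_channels
    bias = torch.zeros(C_out)

    y = torch.conv_tbc(x.contiguous(), weight, bias, 1)  # pad=1 keeps T fixed for K=3
    print(y.shape)  # torch.Size([10, 2, 16])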
@@ -161,17 +161,12 @@ class MultiheadAttention(nn.Module):
     def in_proj_v(self, value):
         return self._in_proj(value, start=2*self.embed_dim)
 
-    def _in_proj(self, input, start=None, end=None):
+    def _in_proj(self, input, start=0, end=None):
         weight = self.in_proj_weight
         bias = self.in_proj_bias
-        if end is not None:
-            weight = weight[:end, :]
-            if bias is not None:
-                bias = bias[:end]
-        if start is not None:
-            weight = weight[start:, :]
-            if bias is not None:
-                bias = bias[start:]
+        weight = weight[start:end, :]
+        if bias is not None:
+            bias = bias[start:end]
         return F.linear(input, weight, bias)
 
     def buffered_mask(self, tensor):
...
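Note: with start=0 and end=None as defaults, the single slice weight[start:end, :] subsumes all four branches of the old code, because x[0:None] is the whole tensor. A sketch of the slicing over the stacked q/k/v projection matrix, with illustrative sizes:

    import torch

    embed_dim = 4
    in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # stacked q, k, v projections

    # start=0, end=None reproduces the old "no slicing" behavior:
    assert torch.equal(in_proj_weight[0:None, :], in_proj_weight)

    q_w = in_proj_weight[:embed_dim, :]               # in_proj_q: end=embed_dim
    k_w = in_proj_weight[embed_dim:2 * embed_dim, :]  # in_proj_k: start and end set
    v_w = in_proj_weight[2 * embed_dim:, :]           # in_proj_v: start=2*embed_dim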
@@ -30,8 +30,12 @@ class SinusoidalPositionalEmbedding(nn.Module):
             embedding_dim,
             padding_idx,
         )
+        self.onnx_trace = False
         self.register_buffer('_float_tensor', torch.FloatTensor(1))
 
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
     @staticmethod
     def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
         """Build sinusoidal embeddings.
@@ -68,7 +72,14 @@ class SinusoidalPositionalEmbedding(nn.Module):
             # positions is the same for every token when decoding a single step
             return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
 
-        positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
+        positions = utils.make_positions(input, self.padding_idx, self.left_pad, self.onnx_trace)
+        if self.onnx_trace:
+            bsz = torch.onnx.operators.shape_as_tensor(input)[0]
+            seq_len = torch.onnx.operators.shape_as_tensor(input)[1]
+            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
+            embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1])))
+            embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape)
+            return embeddings
         return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
 
     def max_positions(self):
...
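Note: torch.onnx.operators.shape_as_tensor returns a tensor holding the input's shape, so the batch and sequence dimensions stay symbolic in the traced graph instead of being baked in as constants, and reshape_from_tensor_shape reshapes against such a shape tensor. A minimal round-trip sketch:

    import torch
    import torch.onnx.operators

    x = torch.randn(2, 5, 8)
    shape = torch.onnx.operators.shape_as_tensor(x)  # tensor([2, 5, 8]), trace-safe
    flat = x.view(-1, x.size(-1))
    new_shape = torch.cat((shape[0].view(1), shape[1].view(1), torch.LongTensor([-1])))
    y = torch.onnx.operators.reshape_from_tensor_shape(flat, new_shape)
    assert y.shape == x.shape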
@@ -46,7 +46,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
         extra_state = {}
     state_dict = {
         'args': args,
-        'model': convert_state_dict_type(model.state_dict()),
+        'model': model.state_dict() if model else {},
         'optimizer_history': optim_history + [
             {
                 'criterion_name': criterion.__class__.__name__,
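Note: the if model else {} guard lets a checkpoint be written without a model present. A minimal sketch of the guarded payload, assuming only the fields visible in this hunk; save_state_sketch and its arguments are illustrative, not the full fairseq helper:

    import torch

    def save_state_sketch(filename, args, model, optim_history):
        # Sketch: mirrors the guarded 'model' entry from the diff above.
        state_dict = {
            'args': args,
            'model': model.state_dict() if model else {},  # tolerate model=None
            'optimizer_history': optim_history,
        }
        torch.save(state_dict, filename)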
@@ -298,7 +298,7 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict
     return hypo_tokens, hypo_str, alignment
 
-def make_positions(tensor, padding_idx, left_pad):
+def make_positions(tensor, padding_idx, left_pad, onnx_trace=False):
     """Replace non-padding symbols with their position numbers.
 
     Position numbers begin at padding_idx+1.
@@ -306,6 +306,14 @@ def make_positions(tensor, padding_idx, left_pad):
     Padding symbols are ignored, but it is necessary to specify whether padding
     is added on the left side (left_pad=True) or right side (left_pad=False).
     """
+    if onnx_trace:
+        range_buf = torch._dim_arange(like=tensor, dim=1) + padding_idx + 1
+        mask = tensor.ne(padding_idx)
+        positions = range_buf.expand_as(tensor)
+        if left_pad:
+            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
+        return positions * mask.long() + padding_idx * (1 - mask.long())
+
     max_pos = padding_idx + 1 + tensor.size(1)
     if not hasattr(make_positions, 'range_buf'):
         make_positions.range_buf = tensor.new()
...
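Note: in the ONNX branch, torch._dim_arange is a trace-friendly stand-in for torch.arange over one dimension's length, and the final line maps padding slots back to padding_idx. A worked example of the same arithmetic for a right-padded batch (left_pad=False), using torch.arange for clarity; padding_idx=1 and the toy batch are illustrative:

    import torch

    padding_idx = 1
    tensor = torch.tensor([[5, 6, 7, 1],    # one pad on the right
                           [8, 9, 1, 1]])   # two pads on the right

    mask = tensor.ne(padding_idx)
    range_buf = torch.arange(tensor.size(1)) + padding_idx + 1  # [2, 3, 4, 5]
    positions = range_buf.expand_as(tensor)
    result = positions * mask.long() + padding_idx * (1 - mask.long())
    print(result)
    # tensor([[2, 3, 4, 1],
    #         [2, 3, 1, 1]])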