Commit 753935ef authored by Myle Ott

Merge internal changes

parent c7c567a7
@@ -34,6 +34,17 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(**sample['net_input'])
+        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        logging_output = {
+            'loss': utils.item(loss.data) if reduce else loss.data,
+            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
+            'ntokens': sample['ntokens'],
+            'sample_size': sample_size,
+        }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
         lprobs = lprobs.view(-1, lprobs.size(-1))
         target = model.get_targets(sample, net_output).view(-1, 1)
@@ -45,15 +56,8 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
             smooth_loss = smooth_loss.sum()
         eps_i = self.eps / lprobs.size(-1)
         loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
+        return loss, nll_loss
-
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
-        logging_output = {
-            'loss': utils.item(loss.data) if reduce else loss.data,
-            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
-            'ntokens': sample['ntokens'],
-            'sample_size': sample_size,
-        }
-        return loss, sample_size, logging_output

     @staticmethod
     def aggregate_logging_outputs(logging_outputs):
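Note on the two hunks above: the loss computation is split out of forward() into compute_loss() so subclasses can override just that part; the logging output itself is unchanged. For reference, a minimal self-contained sketch of the label-smoothed loss computed here, assuming lprobs is an (N, V) tensor of log-probabilities and target is (N, 1) gold indices (padding masking omitted for brevity):

    import torch

    def label_smoothed_nll_loss(lprobs, target, eps):
        # negative log-likelihood of the gold tokens
        nll_loss = -lprobs.gather(dim=-1, index=target).sum()
        # sum over the whole vocab = loss against the uniform distribution
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True).sum()
        eps_i = eps / lprobs.size(-1)
        # interpolate: (1 - eps) * NLL + eps * uniform smoothing term
        loss = (1. - eps) * nll_loss + eps_i * smooth_loss
        return loss, nll_loss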
@@ -100,6 +100,14 @@ class BaseFairseqModel(nn.Module):
         self.eval()
         self.train = train

+    def prepare_for_onnx_export_(self, **kwargs):
+        """Make model exportable via ONNX trace."""
+        def apply_prepare_for_onnx_export_(module):
+            if module != self and hasattr(module, 'prepare_for_onnx_export_'):
+                module.prepare_for_onnx_export_(**kwargs)
+        self.apply(apply_prepare_for_onnx_export_)
+

 class FairseqModel(BaseFairseqModel):
     """Base class for encoder-decoder models."""
@@ -27,7 +27,7 @@ class ConvTBC(torch.nn.Module):
         self.bias = torch.nn.Parameter(torch.Tensor(out_channels))

     def forward(self, input):
-        return input.contiguous().conv_tbc(self.weight, self.bias, self.padding[0])
+        return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0])

     def __repr__(self):
         s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
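conv_tbc operates on time x batch x channel (TBC) layout, and the new line calls the torch-level function rather than the Tensor method. A quick shape check, with sizes picked arbitrarily for illustration:

    import torch

    T, B, C_in, C_out, K = 10, 4, 8, 16, 3
    x = torch.randn(T, B, C_in)            # time x batch x channel
    weight = torch.randn(K, C_in, C_out)   # kernel width x in x out
    bias = torch.zeros(C_out)
    y = torch.conv_tbc(x.contiguous(), weight, bias, 1)  # pad=1 preserves T for K=3
    print(y.shape)  # torch.Size([10, 4, 16])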
@@ -161,17 +161,12 @@ class MultiheadAttention(nn.Module):
     def in_proj_v(self, value):
         return self._in_proj(value, start=2*self.embed_dim)

-    def _in_proj(self, input, start=None, end=None):
+    def _in_proj(self, input, start=0, end=None):
         weight = self.in_proj_weight
         bias = self.in_proj_bias
-        if end is not None:
-            weight = weight[:end, :]
-            if bias is not None:
-                bias = bias[:end]
-        if start is not None:
-            weight = weight[start:, :]
-            if bias is not None:
-                bias = bias[start:]
+        weight = weight[start:end, :]
+        if bias is not None:
+            bias = bias[start:end]
         return F.linear(input, weight, bias)

     def buffered_mask(self, tensor):
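The simplification works because the query, key, and value projections are stored stacked in a single in_proj_weight of shape (3*embed_dim, embed_dim): slicing rows [start:end] selects one projection, and the new defaults start=0, end=None cover the "project everything" case the old None/None branches handled, in one code path. Roughly, as a standalone sketch:

    import torch
    import torch.nn.functional as F

    embed_dim = 4
    in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # stacked q, k, v rows
    x = torch.randn(2, embed_dim)

    q = F.linear(x, in_proj_weight[0:embed_dim])              # start=0, end=embed_dim
    k = F.linear(x, in_proj_weight[embed_dim:2 * embed_dim])  # start=embed_dim, end=2*embed_dim
    v = F.linear(x, in_proj_weight[2 * embed_dim:])           # start=2*embed_dim, end=None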
@@ -30,8 +30,12 @@ class SinusoidalPositionalEmbedding(nn.Module):
             embedding_dim,
             padding_idx,
         )
+        self.onnx_trace = False
         self.register_buffer('_float_tensor', torch.FloatTensor(1))

+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
     @staticmethod
     def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
         """Build sinusoidal embeddings.
@@ -68,7 +72,14 @@ class SinusoidalPositionalEmbedding(nn.Module):
             # positions is the same for every token when decoding a single step
             return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)

-        positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
+        positions = utils.make_positions(input, self.padding_idx, self.left_pad, self.onnx_trace)
+        if self.onnx_trace:
+            bsz = torch.onnx.operators.shape_as_tensor(input)[0]
+            seq_len = torch.onnx.operators.shape_as_tensor(input)[1]
+            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
+            embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1])))
+            embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape)
+            return embeddings
         return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

     def max_positions(self):
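During tracing, sizes read via tensor.size() become Python ints and get burned into the exported graph as constants; torch.onnx.operators.shape_as_tensor and reshape_from_tensor_shape keep batch size and sequence length symbolic, so the exported model accepts other input shapes. A minimal illustration of the same pattern (a plain flatten, not the embedding lookup above):

    import torch
    import torch.onnx.operators

    def flatten_trailing_dims(x):
        # shape stays a tensor, so tracing does not hard-code x's sizes
        shape = torch.onnx.operators.shape_as_tensor(x)  # e.g. tensor([B, T, C])
        new_shape = torch.cat((shape[0].view(1), torch.LongTensor([-1])))
        return torch.onnx.operators.reshape_from_tensor_shape(x, new_shape)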
@@ -46,7 +46,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
         extra_state = {}
     state_dict = {
         'args': args,
-        'model': convert_state_dict_type(model.state_dict()),
+        'model': model.state_dict() if model else {},
         'optimizer_history': optim_history + [
             {
                 'criterion_name': criterion.__class__.__name__,
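The guard presumably lets save_state serialize a checkpoint when no model is given, storing an empty dict under 'model' instead of raising. A minimal sketch of that behavior, under that assumption (save_minimal_checkpoint is a made-up helper, not part of this diff):

    import torch

    def save_minimal_checkpoint(filename, args, model=None):
        state_dict = {
            'args': args,
            'model': model.state_dict() if model else {},  # same guard as above
        }
        torch.save(state_dict, filename)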
@@ -298,7 +298,7 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict
     return hypo_tokens, hypo_str, alignment


-def make_positions(tensor, padding_idx, left_pad):
+def make_positions(tensor, padding_idx, left_pad, onnx_trace=False):
     """Replace non-padding symbols with their position numbers.

     Position numbers begin at padding_idx+1.
@@ -306,6 +306,14 @@ def make_positions(tensor, padding_idx, left_pad):
     Padding symbols are ignored, but it is necessary to specify whether padding
     is added on the left side (left_pad=True) or right side (left_pad=False).
     """
+    if onnx_trace:
+        range_buf = torch._dim_arange(like=tensor, dim=1) + padding_idx + 1
+        mask = tensor.ne(padding_idx)
+        positions = range_buf.expand_as(tensor)
+        if left_pad:
+            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
+        return positions * mask.long() + padding_idx * (1 - mask.long())
+
     max_pos = padding_idx + 1 + tensor.size(1)
     if not hasattr(make_positions, 'range_buf'):
         make_positions.range_buf = tensor.new()
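The traced branch mirrors the buffered non-ONNX path below it: torch._dim_arange keeps the position range symbolic in the sequence dimension, non-pad tokens count up from padding_idx+1, and pads are pinned to padding_idx. Worked example of the intended output, assuming padding_idx=1 and right padding:

    import torch

    padding_idx = 1
    tensor = torch.tensor([[5, 6, 7, 1],   # 1 = pad
                           [5, 6, 1, 1]])
    mask = tensor.ne(padding_idx)
    positions = (torch.arange(tensor.size(1)) + padding_idx + 1).expand_as(tensor)
    out = positions * mask.long() + padding_idx * (1 - mask.long())
    # out: [[2, 3, 4, 1],
    #       [2, 3, 1, 1]]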