Commit 4db6579a authored by Myle Ott's avatar Myle Ott
Browse files

Move normalization of model output (e.g., via LSM) into model definition

parent c21a6e29
......@@ -26,10 +26,10 @@ class CrossEntropyCriterion(FairseqCriterion):
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
input = net_output.view(-1, net_output.size(-1))
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = sample['target'].view(-1)
loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': loss.data[0] if reduce else loss.data,
......
......@@ -62,9 +62,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
input = F.log_softmax(net_output.view(-1, net_output.size(-1)), dim=1)
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = sample['target'].view(-1)
loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights, reduce)
loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, self.weights, reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': loss.data[0] if reduce else loss.data,
......
......@@ -7,6 +7,7 @@
#
import torch.nn as nn
import torch.nn.functional as F
class FairseqDecoder(nn.Module):
......@@ -15,6 +16,15 @@ class FairseqDecoder(nn.Module):
def __init__(self):
    """Set up the base decoder as a standard ``nn.Module``."""
    super().__init__()
def get_normalized_probs(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output.

    Normalization is taken over the last (vocabulary) dimension of
    `net_output`; the input's shape is preserved in the result. When
    `log_probs` is True the result is log-probabilities, otherwise
    plain probabilities.
    """
    vocab_size = net_output.size(-1)
    flat = net_output.view(-1, vocab_size)
    # Pick the normalizer once, then restore the caller's shape.
    normalize = F.log_softmax if log_probs else F.softmax
    return normalize(flat, dim=1).view_as(net_output)
def max_positions(self):
    """Maximum input length supported by the decoder.

    Abstract: concrete decoder subclasses must override this.
    """
    raise NotImplementedError
......
......@@ -37,7 +37,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
with model.decoder.incremental_inference():
for step in range(maxlen):
out, _ = model.decoder(tokens[:, :step], encoder_out)
probs = torch.nn.functional.log_softmax(out[:, -1, :])
probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
```
"""
class IncrementalInference(object):
......
......@@ -35,6 +35,10 @@ class FairseqModel(nn.Module):
decoder_out, _ = self.decoder(input_tokens, encoder_out)
return decoder_out.view(-1, decoder_out.size(-1))
def get_normalized_probs(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output.

    The decoder owns the output projection, so normalization is
    delegated to it unchanged.
    """
    decoder = self.decoder
    return decoder.get_normalized_probs(net_output, log_probs)
def max_encoder_positions(self):
    """Maximum input length supported by the encoder.

    Pure delegation: the encoder reports its own limit.
    """
    encoder = self.encoder
    return encoder.max_positions()
......
......@@ -328,7 +328,7 @@ class SequenceGenerator(object):
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
decoder_out, attn = model.decoder(tokens, encoder_out)
probs = F.softmax(decoder_out[:, -1, :], dim=1).data
probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
if avg_probs is None:
avg_probs = probs
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment