Commit 4db6579a authored by Myle Ott's avatar Myle Ott
Browse files

Move normalization of model output (e.g., via LSM) into model definition

parent c21a6e29
...@@ -26,9 +26,9 @@ class CrossEntropyCriterion(FairseqCriterion): ...@@ -26,9 +26,9 @@ class CrossEntropyCriterion(FairseqCriterion):
3) logging outputs to display while training 3) logging outputs to display while training
""" """
net_output = model(**sample['net_input']) net_output = model(**sample['net_input'])
input = net_output.view(-1, net_output.size(-1)) lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = sample['target'].view(-1) target = sample['target'].view(-1)
loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx, loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce) reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = { logging_output = {
......
...@@ -62,9 +62,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -62,9 +62,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
3) logging outputs to display while training 3) logging outputs to display while training
""" """
net_output = model(**sample['net_input']) net_output = model(**sample['net_input'])
input = F.log_softmax(net_output.view(-1, net_output.size(-1)), dim=1) lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = sample['target'].view(-1) target = sample['target'].view(-1)
loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights, reduce) loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, self.weights, reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = { logging_output = {
'loss': loss.data[0] if reduce else loss.data, 'loss': loss.data[0] if reduce else loss.data,
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
# #
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
class FairseqDecoder(nn.Module): class FairseqDecoder(nn.Module):
...@@ -15,6 +16,15 @@ class FairseqDecoder(nn.Module): ...@@ -15,6 +16,15 @@ class FairseqDecoder(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def get_normalized_probs(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output.

    Flattens the output to ``(N, vocab)``, normalizes over the vocabulary
    (last) dimension, then restores the original tensor shape.
    """
    vocab_size = net_output.size(-1)
    flat_logits = net_output.view(-1, vocab_size)
    # Pick log-space or probability-space normalization as requested.
    normalize = F.log_softmax if log_probs else F.softmax
    return normalize(flat_logits, dim=1).view_as(net_output)
def max_positions(self): def max_positions(self):
"""Maximum input length supported by the decoder.""" """Maximum input length supported by the decoder."""
raise NotImplementedError raise NotImplementedError
......
...@@ -37,7 +37,7 @@ class FairseqIncrementalDecoder(FairseqDecoder): ...@@ -37,7 +37,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
with model.decoder.incremental_inference(): with model.decoder.incremental_inference():
for step in range(maxlen): for step in range(maxlen):
out, _ = model.decoder(tokens[:, :step], encoder_out) out, _ = model.decoder(tokens[:, :step], encoder_out)
probs = torch.nn.functional.log_softmax(out[:, -1, :]) probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
``` ```
""" """
class IncrementalInference(object): class IncrementalInference(object):
......
...@@ -35,6 +35,10 @@ class FairseqModel(nn.Module): ...@@ -35,6 +35,10 @@ class FairseqModel(nn.Module):
decoder_out, _ = self.decoder(input_tokens, encoder_out) decoder_out, _ = self.decoder(input_tokens, encoder_out)
return decoder_out.view(-1, decoder_out.size(-1)) return decoder_out.view(-1, decoder_out.size(-1))
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.decoder.get_normalized_probs(net_output, log_probs)
def max_encoder_positions(self): def max_encoder_positions(self):
"""Maximum input length supported by the encoder.""" """Maximum input length supported by the encoder."""
return self.encoder.max_positions() return self.encoder.max_positions()
......
...@@ -328,7 +328,7 @@ class SequenceGenerator(object): ...@@ -328,7 +328,7 @@ class SequenceGenerator(object):
avg_attn = None avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs): for model, encoder_out in zip(self.models, encoder_outs):
decoder_out, attn = model.decoder(tokens, encoder_out) decoder_out, attn = model.decoder(tokens, encoder_out)
probs = F.softmax(decoder_out[:, -1, :], dim=1).data probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
if avg_probs is None: if avg_probs is None:
avg_probs = probs avg_probs = probs
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment