Commit f69206c8 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

fix adaptive softmax indexing

parent af38ed48
......@@ -57,11 +57,19 @@ class BaseFairseqModel(nn.Module):
def upgrade_state_dict(self, state_dict):
    """Upgrade old state dicts to work with newer code.

    Walks the module tree depth-first. For each child module, prefers the
    ``upgrade_state_dict_named`` hook — which receives the child's full
    dotted name within this model, so it can locate its own entries in
    ``state_dict`` — and falls back to the legacy, unnamed
    ``upgrade_state_dict`` hook.

    Args:
        state_dict: the (possibly old-format) state dict to upgrade
            in place; must not be None.
    """
    assert state_dict is not None

    def do_upgrade(m, prefix):
        # Build the dotted-name prefix for m's children
        # (empty prefix for the root so names don't start with '.').
        if len(prefix) > 0:
            prefix += '.'
        for n, c in m.named_children():
            name = prefix + n
            if hasattr(c, 'upgrade_state_dict_named'):
                c.upgrade_state_dict_named(state_dict, name)
            elif hasattr(c, 'upgrade_state_dict'):
                c.upgrade_state_dict(state_dict)
            # Recurse so grandchildren are upgraded too, regardless of
            # whether this child defined an upgrade hook.
            do_upgrade(c, name)

    do_upgrade(self, '')
def make_generation_fast_(self, **kwargs):
"""Optimize model for faster generation."""
......
......@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn.functional as F
from torch import nn
......@@ -51,6 +52,16 @@ class AdaptiveSoftmax(nn.Module):
self.apply(init_weights)
self.register_buffer('version', torch.LongTensor([1]))
# versions prior to 1 had a bug that offset indices on the head by 1
self.buggy_offset = 0
def upgrade_state_dict_named(self, state_dict, name):
    """Upgrade checkpoints saved before the ``version`` buffer existed.

    Checkpoints lacking a ``<name>.version`` entry were produced by the
    pre-version-1 implementation whose head indices were off by one
    (see the ``buggy_offset`` uses in ``adapt_target``/``get_log_prob``).
    Record that fact so index computations can compensate, and stamp the
    state dict as version 1 so it loads cleanly against the registered
    ``version`` buffer.

    Args:
        state_dict: state dict being loaded; modified in place.
        name: this module's full dotted name inside the model.
    """
    version_name = name + '.version'
    if version_name not in state_dict:
        # Old checkpoint: enable the off-by-one compensation.
        self.buggy_offset = 1
        state_dict[version_name] = torch.LongTensor([1])
def adapt_target(self, target):
"""
In order to be efficient, the AdaptiveSoftMax does not compute the
......@@ -65,7 +76,7 @@ class AdaptiveSoftmax(nn.Module):
for i in range(len(self.cutoff) - 1):
mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
new_target[0][mask] = self.cutoff[0] + i - 1
new_target[0][mask] = self.cutoff[0] + i - self.buggy_offset
if mask.any():
target_idxs.append(mask.nonzero().squeeze(1))
......@@ -118,7 +129,7 @@ class AdaptiveSoftmax(nn.Module):
head_sz = self.cutoff[0] + len(self.tail)
log_probs[:, :head_sz] = self.lsm(head_y)
tail_priors = log_probs[:, self.cutoff[0] - 1: head_sz - 1].clone()
tail_priors = log_probs[:, self.cutoff[0] - self.buggy_offset: head_sz - self.buggy_offset].clone()
for i in range(len(self.tail)):
start = self.cutoff[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment