Commit f69206c8 authored by Alexei Baevski, committed by Myle Ott

fix adaptive softmax indexing

parent af38ed48
@@ -57,11 +57,19 @@ class BaseFairseqModel(nn.Module):
     def upgrade_state_dict(self, state_dict):
         assert state_dict is not None
 
-        def do_upgrade(m):
-            if m != self and hasattr(m, 'upgrade_state_dict'):
-                m.upgrade_state_dict(state_dict)
+        def do_upgrade(m, prefix):
+            if len(prefix) > 0:
+                prefix += '.'
+
+            for n, c in m.named_children():
+                name = prefix + n
+                if hasattr(c, 'upgrade_state_dict_named'):
+                    c.upgrade_state_dict_named(state_dict, name)
+                elif hasattr(c, 'upgrade_state_dict'):
+                    c.upgrade_state_dict(state_dict)
+                do_upgrade(c, name)
 
-        self.apply(do_upgrade)
+        do_upgrade(self, '')
 
     def make_generation_fast_(self, **kwargs):
         """Optimize model for faster generation."""
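Note (not part of the commit): the rewritten `do_upgrade` walks `named_children()` recursively and threads a dotted prefix down the module tree, so a submodule that defines `upgrade_state_dict_named` learns which `state_dict` keys belong to it. A minimal, standalone sketch of that traversal, with made-up modules and a `visit` callback standing in for a real fairseq model:

```python
# Standalone sketch of the traversal above; the modules and the `visit`
# callback are illustrative, not part of fairseq.
from torch import nn

def do_upgrade(m, prefix, visit):
    # same recursion shape as the new BaseFairseqModel.upgrade_state_dict
    if len(prefix) > 0:
        prefix += '.'
    for n, c in m.named_children():
        name = prefix + n
        visit(c, name)             # stand-in for c.upgrade_state_dict_named(state_dict, name)
        do_upgrade(c, name, visit)

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
do_upgrade(model, '', lambda c, name: print(name, type(c).__name__))
# prints:
# 0 Linear
# 1 Sequential
# 1.0 Linear
```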
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 
+import torch
 import torch.nn.functional as F
 
 from torch import nn
@@ -51,6 +52,16 @@ class AdaptiveSoftmax(nn.Module):
 
         self.apply(init_weights)
 
+        self.register_buffer('version', torch.LongTensor([1]))
+        # versions prior to 1 had a bug that offset indices on the head by 1
+        self.buggy_offset = 0
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        version_name = name + '.version'
+        if version_name not in state_dict:
+            self.buggy_offset = 1
+            state_dict[version_name] = torch.LongTensor([1])
+
     def adapt_target(self, target):
         """
         In order to be efficient, the AdaptiveSoftMax does not compute the
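As a rough illustration (the checkpoint keys below are hypothetical), the new `upgrade_state_dict_named` uses the absence of a `<name>.version` entry to recognize checkpoints saved before this fix: for those it switches `buggy_offset` to 1, preserving the indexing the model was trained with, and fills in the missing key so the registered `version` buffer can still be loaded; checkpoints saved from here on already carry the buffer and keep `buggy_offset = 0`.

```python
# Hypothetical checkpoint keys, for illustration only.
import torch

def needs_legacy_offset(state_dict, name):
    # mirrors the check in upgrade_state_dict_named: checkpoints that
    # predate the 'version' buffer simply lack this key
    return (name + '.version') not in state_dict

old_ckpt = {'decoder.adaptive_softmax.head.weight': torch.zeros(2, 2)}
new_ckpt = dict(old_ckpt, **{'decoder.adaptive_softmax.version': torch.LongTensor([1])})

print(needs_legacy_offset(old_ckpt, 'decoder.adaptive_softmax'))  # True  -> buggy_offset = 1
print(needs_legacy_offset(new_ckpt, 'decoder.adaptive_softmax'))  # False -> buggy_offset stays 0
```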
@@ -65,7 +76,7 @@ class AdaptiveSoftmax(nn.Module):
         for i in range(len(self.cutoff) - 1):
             mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
-            new_target[0][mask] = self.cutoff[0] + i - 1
+            new_target[0][mask] = self.cutoff[0] + i - self.buggy_offset
 
             if mask.any():
                 target_idxs.append(mask.nonzero().squeeze(1))
@@ -118,7 +129,7 @@ class AdaptiveSoftmax(nn.Module):
         head_sz = self.cutoff[0] + len(self.tail)
         log_probs[:, :head_sz] = self.lsm(head_y)
-        tail_priors = log_probs[:, self.cutoff[0] - 1: head_sz - 1].clone()
+        tail_priors = log_probs[:, self.cutoff[0] - self.buggy_offset: head_sz - self.buggy_offset].clone()
 
         for i in range(len(self.tail)):
             start = self.cutoff[i]
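As I read the surrounding code, the off-by-one being fixed is in where the tail-cluster priors sit in the head output: the head has `cutoff[0] + len(tail)` columns, with the last `len(tail)` holding the cluster priors, so the slice should start at `cutoff[0]` rather than `cutoff[0] - 1`. `buggy_offset` reproduces the old slice only for checkpoints trained before the fix. A toy example with made-up cutoffs:

```python
# Illustrative only: how buggy_offset shifts the tail-prior slice.
import torch

cutoff = [4, 8, 12]                       # made-up cutoffs: head words 0..3, two tail clusters
head_sz = cutoff[0] + (len(cutoff) - 1)   # 6 head columns: 4 words + 2 cluster priors
log_probs = torch.arange(head_sz, dtype=torch.float).unsqueeze(0)

for buggy_offset in (1, 0):               # 1 = legacy behaviour, 0 = fixed
    tail_priors = log_probs[:, cutoff[0] - buggy_offset: head_sz - buggy_offset]
    print(buggy_offset, tail_priors)
# 1 tensor([[3., 4.]])   <- overlaps the last head-word column
# 0 tensor([[4., 5.]])   <- exactly the two cluster-prior columns
```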