"vscode:/vscode.git/clone" did not exist on "ba7e48455acd0f4d9580f95b4b29e437027ee2a1"
Commit 9f3ccaa6 authored by Myle Ott
Browse files

Fix weight norm dimension in decoder (fixes #73)

parent 99493a85
......@@ -18,3 +18,6 @@ class FairseqDecoder(nn.Module):
def max_positions(self):
    """Return the maximum input length the decoder supports.

    The base class provides no limit; concrete decoders must override.
    """
    raise NotImplementedError
def upgrade_state_dict(self, state_dict):
    """Upgrade an old checkpoint state dict to the current format.

    The base implementation performs no migration and hands the dict
    back unchanged; subclasses override to rewrite legacy entries.
    """
    return state_dict
......@@ -18,3 +18,6 @@ class FairseqEncoder(nn.Module):
def max_positions(self):
    """Return the maximum input length the encoder supports.

    The base class provides no limit; concrete encoders must override.
    """
    raise NotImplementedError
def upgrade_state_dict(self, state_dict):
    """Upgrade an old checkpoint state dict to the current format.

    No-op at the base level: the dict is returned untouched. Encoders
    with legacy formats override this to migrate entries in place.
    """
    return state_dict
......@@ -43,6 +43,11 @@ class FairseqModel(nn.Module):
"""Maximum output length supported by the decoder."""
return self.decoder.max_positions()
def upgrade_state_dict(self, state_dict):
    """Run the encoder's then the decoder's state-dict upgrades.

    Each component may rewrite legacy entries; the (possibly replaced)
    dict from one stage is fed into the next and the final result
    returned.
    """
    for component in (self.encoder, self.decoder):
        state_dict = component.upgrade_state_dict(state_dict)
    return state_dict
def make_generation_fast_(self, **kwargs):
"""Optimize model for faster generation."""
if self._is_generation_fast:
......
......@@ -154,6 +154,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
max_positions=1024, convolutions=((512, 3),) * 20,
attention=True, dropout=0.1):
super().__init__()
self.register_buffer('version', torch.Tensor([2]))
self.dictionary = dictionary
self.dropout = dropout
......@@ -265,6 +266,16 @@ class FConvDecoder(FairseqIncrementalDecoder):
"""Maximum output length supported by the decoder."""
return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
def upgrade_state_dict(self, state_dict):
    """Upgrade old checkpoints to the current (version 2) format.

    Decoders saved before version 2 applied weight norm over the wrong
    dimension of the convolution weights, so re-wrap each convolution
    with weight norm over dim=0 before the parameters are loaded.

    Args:
        state_dict: full model state dict being loaded; missing
            'decoder.version' is treated as version 1 (pre-fix).

    Returns:
        The (possibly mutated) state dict.
    """
    if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
        # old models use incorrect weight norm dimension
        for i, conv in enumerate(self.convolutions):
            # reconfigure weight norm over the output-channel dimension
            nn.utils.remove_weight_norm(conv)
            self.convolutions[i] = nn.utils.weight_norm(conv, dim=0)
        # Record version 2 so the upgrade is not re-applied on every
        # load: writing 1 here (as before) left the `< 2` check
        # permanently true.
        state_dict['decoder.version'] = torch.Tensor([2])
    return state_dict
def _split_encoder_out(self, encoder_out):
"""Split and transpose encoder outputs.
......@@ -307,7 +318,7 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
m.weight.data.normal_(mean=0, std=std)
m.bias.data.zero_()
return nn.utils.weight_norm(m)
return nn.utils.weight_norm(m, dim=2)
def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
......
......@@ -94,6 +94,7 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=
map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
)
state = _upgrade_state_dict(state)
state['model'] = model.upgrade_state_dict(state['model'])
# load model parameters
try:
......@@ -168,6 +169,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
ensemble = []
for state in states:
model = build_model(args, src_dict, dst_dict)
state['model'] = model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'])
ensemble.append(model)
return ensemble, args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment