"vscode:/vscode.git/clone" did not exist on "e602ab1b56889c8f999f07aeddb55d641fba1014"
Commit 9f3ccaa6 authored by Myle Ott

Fix weight norm dimension in decoder (fixes #73)

parent 99493a85
@@ -18,3 +18,6 @@ class FairseqDecoder(nn.Module):
     def max_positions(self):
         """Maximum input length supported by the decoder."""
         raise NotImplementedError
+
+    def upgrade_state_dict(self, state_dict):
+        return state_dict
@@ -18,3 +18,6 @@ class FairseqEncoder(nn.Module):
     def max_positions(self):
         """Maximum input length supported by the encoder."""
         raise NotImplementedError
+
+    def upgrade_state_dict(self, state_dict):
+        return state_dict
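The base-class hook defaults to a no-op, so a component only overrides it when its checkpoint layout has actually changed. A minimal sketch of the override pattern; the class and key names here are invented for illustration and are not part of fairseq:

    import torch.nn as nn

    class MyEncoder(nn.Module):  # stands in for a FairseqEncoder subclass
        def upgrade_state_dict(self, state_dict):
            # Rename a key from a hypothetical older checkpoint layout.
            old_key = 'encoder.legacy_embed.weight'  # invented name
            if old_key in state_dict:
                state_dict['encoder.embed_tokens.weight'] = state_dict.pop(old_key)
            return state_dict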
@@ -43,6 +43,11 @@ class FairseqModel(nn.Module):
         """Maximum output length supported by the decoder."""
         return self.decoder.max_positions()
 
+    def upgrade_state_dict(self, state_dict):
+        state_dict = self.encoder.upgrade_state_dict(state_dict)
+        state_dict = self.decoder.upgrade_state_dict(state_dict)
+        return state_dict
+
     def make_generation_fast_(self, **kwargs):
         """Optimize model for faster generation."""
         if self._is_generation_fast:
...
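With the model-level hook chaining encoder first and decoder second, any caller holding a flattened checkpoint can run it through the model before loading, which is what the utils changes further down do. A minimal usage sketch, assuming 'checkpoint.pt' is a hypothetical path and `model` an already-constructed FairseqModel:

    import torch

    # `model` is assumed to be a built FairseqModel instance.
    state = torch.load('checkpoint.pt', map_location='cpu')
    state['model'] = model.upgrade_state_dict(state['model'])  # encoder, then decoder
    model.load_state_dict(state['model'])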
@@ -154,6 +154,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
                  max_positions=1024, convolutions=((512, 3),) * 20,
                  attention=True, dropout=0.1):
         super().__init__()
+        self.register_buffer('version', torch.Tensor([2]))
         self.dictionary = dictionary
         self.dropout = dropout
@@ -265,6 +266,16 @@ class FConvDecoder(FairseqIncrementalDecoder):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
 
+    def upgrade_state_dict(self, state_dict):
+        if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
+            # old models use incorrect weight norm dimension
+            for i, conv in enumerate(self.convolutions):
+                # reconfigure weight norm
+                nn.utils.remove_weight_norm(conv)
+                self.convolutions[i] = nn.utils.weight_norm(conv, dim=0)
+            state_dict['decoder.version'] = torch.Tensor([1])
+        return state_dict
+
     def _split_encoder_out(self, encoder_out):
         """Split and transpose encoder outputs.
@@ -307,7 +318,7 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
     std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
     m.weight.data.normal_(mean=0, std=std)
     m.bias.data.zero_()
-    return nn.utils.weight_norm(m)
+    return nn.utils.weight_norm(m, dim=2)
 
 def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
...
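For comparison, nn.Conv1d keeps its weight as (out_channels, in_channels, kernel_size), so weight norm's default dim=0 already scales per output channel there; only the TBC-layout convolutions needed the explicit dim=2. A quick check:

    import torch.nn as nn

    c = nn.utils.weight_norm(nn.Conv1d(512, 512, 3))  # default dim=0
    print(c.weight_g.shape)  # torch.Size([512, 1, 1]): one scale per output channel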
@@ -94,6 +94,7 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=
         map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
     )
     state = _upgrade_state_dict(state)
+    state['model'] = model.upgrade_state_dict(state['model'])
 
     # load model parameters
     try:
@@ -168,6 +169,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
     ensemble = []
     for state in states:
         model = build_model(args, src_dict, dst_dict)
+        state['model'] = model.upgrade_state_dict(state['model'])
         model.load_state_dict(state['model'])
         ensemble.append(model)
     return ensemble, args
...
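Note the ordering in both load paths: upgrade_state_dict runs before load_state_dict. For a pre-version-2 checkpoint the hook also reconfigures the live convolutions, so the parameter shapes only agree after the upgrade. A sketch of the failure mode when the order is reversed, reusing the hypothetical TBCConv stand-in from above:

    fresh = nn.utils.weight_norm(TBCConv(512, 512, 3), dim=2)  # new model layout
    saved = nn.utils.weight_norm(TBCConv(512, 512, 3), dim=0)  # old checkpoint layout
    try:
        fresh.load_state_dict(saved.state_dict())
    except RuntimeError as err:
        print(err)  # size mismatch for weight_g: checkpoint (3, 1, 1) vs model (1, 1, 512)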