Commit 6a21b232 authored by Myle Ott, committed by Facebook GitHub Bot

Backward compatibility + updated links for pretrained language models

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/624

Differential Revision: D15595746

Pulled By: myleott

fbshipit-source-id: b79e489de9ff37ee7cbf939092a6e5ec0dbebbf5
parent 8c03ff2d
@@ -4,8 +4,8 @@
 Description | Parameters | Dataset | Model and Test set(s)
 ---|---:|---|---
-Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
-Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)
+Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
+Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2)
 ## Example usage
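For reference, the corrected links point at `.tar.bz2` archives rather than bare `.bz2` files. Below is a minimal sketch of fetching and unpacking the WikiText-103 model using only the Python standard library; the local file and directory names are illustrative and not part of the commit.

```python
import tarfile
import urllib.request

# Download one of the archives linked above; the URL comes from the updated
# README row, while the local names are placeholders for this sketch.
url = 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2'
archive = 'adaptive_lm_wiki103.tar.bz2'
urllib.request.urlretrieve(url, archive)

# The fixed '.tar.bz2' extension matters: the payload is a tar archive
# compressed with bzip2, so it is opened in 'r:bz2' mode.
with tarfile.open(archive, 'r:bz2') as tar:
    tar.extractall('adaptive_lm_wiki103')
```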
@@ -473,11 +473,13 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                     if k in state_dict:
                         state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k]
                         del state_dict[k]
-        if utils.item(state_dict.get('{}.version'.format(name), torch.Tensor([1]))[0]) < 2:
+        version_key = '{}.version'.format(name)
+        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
             # earlier checkpoints did not normalize after the stack of layers
             self.layer_norm = None
             self.normalize = False
-            state_dict['{}.version'.format(name)] = torch.Tensor([1])
+            state_dict[version_key] = torch.Tensor([1])
 
         return state_dict
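The hunk above only hoists the repeated `'{}.version'.format(name)` expression into `version_key`; behaviour is unchanged. As a standalone illustration of the underlying pattern (a sketch, not fairseq's actual code), an upgrade hook can read a version entry from the checkpoint's state dict, treat a missing entry as an old version, and tell the caller how to adjust the module:

```python
import torch

def upgrade_decoder_state(state_dict, name='decoder'):
    # Hypothetical helper mirroring the version-gating idea above: checkpoints
    # written before the post-stack layer norm existed carry no '<name>.version'
    # entry (or a value < 2), so the upgrade pins the version and reports that
    # the final layer norm must be skipped for this checkpoint.
    version_key = '{}.version'.format(name)
    version = state_dict.get(version_key, torch.Tensor([1]))[0].item()
    apply_final_layer_norm = version >= 2
    if not apply_final_layer_norm:
        state_dict[version_key] = torch.Tensor([1])
    return state_dict, apply_final_layer_norm

# An old checkpoint without a version entry is treated as version 1.
old_checkpoint = {'decoder.embed_tokens.weight': torch.zeros(8, 4)}
_, apply_norm = upgrade_decoder_state(old_checkpoint)
print(apply_norm)  # False
```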
@@ -98,10 +98,6 @@ class TransformerLanguageModel(FairseqLanguageModel):
         # make sure all arguments are present in older models
         base_lm_architecture(args)
 
-        if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False:
-            # backward compatibility
-            args.tie_adaptive_proj = True
-
         if getattr(args, 'max_target_positions', None) is None:
             args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)
@@ -135,6 +131,14 @@ class TransformerLanguageModel(FairseqLanguageModel):
 @register_model_architecture('transformer_lm', 'transformer_lm')
 def base_lm_architecture(args):
+    # backward compatibility for older model checkpoints
+    if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False:
+        args.tie_adaptive_proj = True
+    if hasattr(args, 'decoder_final_norm'):
+        args.no_decoder_final_norm = not args.decoder_final_norm
+    if not hasattr(args, 'no_decoder_final_norm'):
+        args.no_decoder_final_norm = True  # old models always set this to True
+
     args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
     args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
     args.decoder_layers = getattr(args, 'decoder_layers', 6)
@@ -154,6 +158,7 @@ def base_lm_architecture(args):
     # Model training is not stable without this
     args.decoder_normalize_before = True
+    args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', False)
 
     args.adaptive_input = getattr(args, 'adaptive_input', False)
     args.adaptive_input_factor = getattr(args, 'adaptive_input_factor', 4)
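Taken together, the transformer_lm hunks move the checkpoint shims into `base_lm_architecture`, where every older argument namespace passes through before missing options receive defaults. A small self-contained sketch of how the remapping behaves, using an `argparse.Namespace` stand-in rather than a real fairseq checkpoint:

```python
from argparse import Namespace

def apply_backward_compat(args):
    # Mirror of the shim added above: map options recorded by older checkpoints
    # ('no_tie_adaptive_proj', 'decoder_final_norm') onto the current flags.
    if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False:
        args.tie_adaptive_proj = True
    if hasattr(args, 'decoder_final_norm'):
        args.no_decoder_final_norm = not args.decoder_final_norm
    if not hasattr(args, 'no_decoder_final_norm'):
        args.no_decoder_final_norm = True  # old models always set this to True
    # Default applied after the shims, as in the last hunk above.
    args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', False)
    return args

# A namespace resembling an older checkpoint's stored arguments.
old_args = Namespace(no_tie_adaptive_proj=False, decoder_final_norm=True)
apply_backward_compat(old_args)
print(old_args.tie_adaptive_proj)      # True
print(old_args.no_decoder_final_norm)  # False (derived from decoder_final_norm)
```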