"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a4713054873b0b695246a6d7fc5d5edc2c48052a"
Commit c0a5d29e authored by Myle Ott, committed by Facebook Github Bot

Fix torch.hub for MNLI

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/1006

Differential Revision: D16753078

Pulled By: myleott

fbshipit-source-id: 970055632edffcce4e75931ed93b42a249120a4a
parent 83249196
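
For context, this is the torch.hub workflow the fix targets. A minimal sketch based on the fairseq RoBERTa documentation; the 'roberta.large.mnli' hub entry, the encode()/predict() calls, and the label mapping come from that documentation rather than from this diff.

import torch

# Load the MNLI-finetuned RoBERTa model through torch.hub.
# With this fix, from_pretrained() passes load_checkpoint_heads=True, so the
# 'mnli' classification head stored in the checkpoint is re-registered on load
# instead of being deleted.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()

# Encode a premise/hypothesis pair and classify it with the 'mnli' head.
tokens = roberta.encode(
    'Roberta is a heavily optimized version of BERT.',
    'Roberta is not very optimized.',
)
prediction = roberta.predict('mnli', tokens).argmax().item()  # 0: contradiction, per the fairseq docs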
@@ -76,6 +76,8 @@ class RobertaModel(FairseqLanguageModel):
                             help='dropout probability in the masked_lm pooler layers')
         parser.add_argument('--max-positions', type=int,
                             help='number of positional embeddings to learn')
+        parser.add_argument('--load-checkpoint-heads', action='store_true',
+                            help='(re-)register and load heads when loading checkpoints')
 
     @classmethod
     def build_model(cls, args, task):
@@ -92,7 +94,7 @@ class RobertaModel(FairseqLanguageModel):
     def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs):
         assert classification_head_name is None or features_only, \
-            "If passing classification_head_name argument, features_only must be set to True"
+            'If passing classification_head_name argument, features_only must be set to True'
 
         x, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs)
@@ -102,6 +104,16 @@ class RobertaModel(FairseqLanguageModel):
     def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
         """Register a classification head."""
+        if name in self.classification_heads:
+            prev_num_classes = self.classification_heads[name].out_proj.out_features
+            prev_inner_dim = self.classification_heads[name].dense.out_features
+            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
+                print(
+                    'WARNING: re-registering head "{}" with num_classes {} (prev: {}) '
+                    'and inner_dim {} (prev: {})'.format(
+                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
+                    )
+                )
         self.classification_heads[name] = RobertaClassificationHead(
             self.args.encoder_embed_dim,
             inner_dim or self.args.encoder_embed_dim,
@@ -123,6 +135,7 @@ class RobertaModel(FairseqLanguageModel):
             data_name_or_path,
             archive_map=cls.hub_models(),
             bpe='gpt2',
+            load_checkpoint_heads=True,
             **kwargs,
         )
         return RobertaHubInterface(x['args'], x['task'], x['models'][0])
@@ -132,30 +145,35 @@ class RobertaModel(FairseqLanguageModel):
         current_head_names = [] if not hasattr(self, 'classification_heads') else \
             self.classification_heads.keys()
 
+        # Handle new classification heads present in the state dict.
         keys_to_delete = []
-        # Delete any heads present in state_dict, that are not in current constructed model.
         for k in state_dict.keys():
             if not k.startswith(prefix + 'classification_heads.'):
                 continue
 
             head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
-            num_classes = state_dict[
-                prefix + 'classification_heads.' + head_name + '.out_proj.weight'
-            ].size(0)
-            inner_dim = state_dict[
-                prefix + 'classification_heads.' + head_name + '.dense.weight'
-            ].size(0)
+            num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
+            inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
 
-            if head_name not in current_head_names:
-                print("WARNING: deleting classification head ({}) from checkpoint not present in current model: {}".format(head_name, k))
-                keys_to_delete.append(k)
-            elif (
-                num_classes != self.classification_heads[head_name].out_proj.out_features
-                or inner_dim != self.classification_heads[head_name].dense.out_features
-            ):
-                print("WARNING: deleting classification head ({}) from checkpoint with different dimensions than current model: {}".format(head_name, k))
-                keys_to_delete.append(k)
+            if getattr(self.args, 'load_checkpoint_heads', False):
+                if head_name not in current_head_names:
+                    self.register_classification_head(head_name, num_classes, inner_dim)
+            else:
+                if head_name not in current_head_names:
+                    print(
+                        'WARNING: deleting classification head ({}) from checkpoint '
+                        'not present in current model: {}'.format(head_name, k)
+                    )
+                    keys_to_delete.append(k)
+                elif (
+                    num_classes != self.classification_heads[head_name].out_proj.out_features
+                    or inner_dim != self.classification_heads[head_name].dense.out_features
+                ):
+                    print(
+                        'WARNING: deleting classification head ({}) from checkpoint '
+                        'with different dimensions than current model: {}'.format(head_name, k)
+                    )
+                    keys_to_delete.append(k)
+
         for k in keys_to_delete:
             del state_dict[k]
...
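
The same head-loading path can also be exercised directly through RobertaModel.from_pretrained(), which now forwards load_checkpoint_heads=True. A minimal sketch, assuming a locally finetuned checkpoint; the directory, checkpoint filename, and data path below are hypothetical.

from fairseq.models.roberta import RobertaModel

# Because load_checkpoint_heads=True is forwarded to hub_utils.from_pretrained(),
# any heads found under 'classification_heads.*' in the checkpoint's state dict
# are re-registered on the model in upgrade_state_dict_named() rather than deleted.
roberta = RobertaModel.from_pretrained(
    'checkpoints/',                        # hypothetical checkpoint directory
    checkpoint_file='checkpoint_best.pt',  # hypothetical checkpoint filename
    data_name_or_path='MNLI-bin',          # hypothetical binarized data directory
)
roberta.eval()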