Commit f30fc7d7 authored by Myle Ott, committed by Facebook Github Bot

Fix MultiheadAttention and torch hub

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/895

Reviewed By: akinh

Differential Revision: D18246479

Pulled By: myleott

fbshipit-source-id: a610f1e4943619d32a523601a572fb09cdc5638d
parent 856d8b82
@@ -101,7 +101,7 @@ class RobertaHubInterface(nn.Module):
         )
 
     def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
-        features = self.extract_features(tokens)
+        features = self.extract_features(tokens.to(device=self.device))
         logits = self.model.classification_heads[head](features)
         if return_logits:
             return logits
...
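This hub_interface change is what makes predict() work for models served through torch.hub: the hub model may be moved to the GPU while encode() builds the token tensor on the CPU. A hedged usage sketch (the model name and head name are illustrative, not part of this diff):

import torch

roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.cuda()   # parameters now live on the GPU
roberta.eval()

# encode() returns a LongTensor built on the CPU
tokens = roberta.encode('Roberta is a BPE model.', 'Roberta uses BPE.')

# Before this fix, predict() fed the CPU tokens straight into the GPU model;
# with tokens.to(device=self.device) the inputs follow the model's device.
label = roberta.predict('mnli', tokens).argmax().item()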
@@ -146,6 +146,8 @@ class RobertaModel(FairseqLanguageModel):
         return RobertaHubInterface(x['args'], x['task'], x['models'][0])
 
     def upgrade_state_dict_named(self, state_dict, name):
+        super().upgrade_state_dict_named(state_dict, name)
+
         prefix = name + '.' if name != '' else ''
         current_head_names = [] if not hasattr(self, 'classification_heads') else \
             self.classification_heads.keys()
...
@@ -187,7 +187,7 @@ class Wav2VecModel(BaseFairseqModel):
         return result
 
     def upgrade_state_dict_named(self, state_dict, name):
-        return state_dict
+        super().upgrade_state_dict_named(state_dict, name)
 
     def max_positions(self):
         """Maximum length supported by the model."""
...
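The two model-level hunks above enforce the same contract: an override of upgrade_state_dict_named should chain to the parent class, whose implementation walks the module tree and gives each submodule (such as MultiheadAttention below) a chance to rewrite legacy checkpoint keys. RobertaModel was missing the call, and Wav2VecModel returned state_dict without chaining, so submodule upgrades were silently skipped. A simplified sketch of the pattern, not the literal fairseq base-class code:

from fairseq.models import FairseqLanguageModel

class MyModel(FairseqLanguageModel):
    def upgrade_state_dict_named(self, state_dict, name):
        # Chain first so every submodule can upgrade its own keys in place.
        super().upgrade_state_dict_named(state_dict, name)
        # Model-specific fix-ups (e.g. pruning stale classification heads) go here.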
@@ -63,11 +63,6 @@ class MultiheadAttention(nn.Module):
         else:
             self.enable_torch_version = False
 
-    @property
-    def in_proj_weight(self):
-        # TODO: Remove this backward compatibility code (in_proj_weight)
-        return torch.cat((self.q_proj.weight, self.k_proj.weight, self.v_proj.weight))
-
     @property
     def in_proj_bias(self):
         # TODO: Remove this backward compatibility code (in_proj_bias)
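Dropping the in_proj_weight property (while keeping in_proj_bias) appears to pair with the init_bert_params change at the bottom of this diff: the property returned a torch.cat(...) copy of the separate q/k/v weights, so in-place initialization through it never reached the real parameters. A minimal sketch of that pitfall, using throwaway nn.Linear layers in place of the projection modules:

import torch
import torch.nn as nn

# Toy stand-ins for q_proj / k_proj / v_proj; the sizes are arbitrary.
q_proj, k_proj, v_proj = (nn.Linear(4, 4) for _ in range(3))

combined = torch.cat((q_proj.weight, k_proj.weight, v_proj.weight))  # what the property returned
combined.data.normal_(mean=0.0, std=0.02)  # initializes only the concatenated copy

print(torch.equal(combined[:4], q_proj.weight))  # False: the real q_proj.weight is untouched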
@@ -312,8 +307,6 @@ class MultiheadAttention(nn.Module):
         return attn_weights
 
     def upgrade_state_dict_named(self, state_dict, name):
-        # TODO: Remove this backward compatibility code (in_proj_weight)
-        # here, we convert in_proj_weight to individual q,k,v weights
         prefix = name + '.' if name != '' else ''
         items_to_add = {}
         keys_to_remove = []
@@ -341,5 +334,3 @@ class MultiheadAttention(nn.Module):
 
         for key, value in items_to_add.items():
             state_dict[key] = value
-
-        return state_dict
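The trailing return state_dict is removed because upgrade_state_dict_named mutates the checkpoint dictionary in place and the surrounding upgrade machinery does not rely on a return value. The body elided between these two hunks converts a legacy combined projection into separate keys, roughly along these lines (a hedged sketch, not the literal fairseq code; key names follow the q_proj/k_proj/v_proj layout used elsewhere in this diff):

def split_in_proj_weight(state_dict, prefix=''):
    key = prefix + 'in_proj_weight'
    if key in state_dict:
        weight = state_dict.pop(key)          # shape: (3 * embed_dim, embed_dim)
        dim = weight.size(0) // 3
        state_dict[prefix + 'q_proj.weight'] = weight[:dim]
        state_dict[prefix + 'k_proj.weight'] = weight[dim:2 * dim]
        state_dict[prefix + 'v_proj.weight'] = weight[2 * dim:]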
@@ -40,7 +40,9 @@ def init_bert_params(module):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     if isinstance(module, MultiheadAttention):
-        module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
+        module.q_proj.weight.data.normal_(mean=0.0, std=0.02)
+        module.k_proj.weight.data.normal_(mean=0.0, std=0.02)
+        module.v_proj.weight.data.normal_(mean=0.0, std=0.02)
 
 
 class TransformerSentenceEncoder(nn.Module):
...