Commit f30fc7d7 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix MultiheadAttention and torch hub

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/895

Reviewed By: akinh

Differential Revision: D18246479

Pulled By: myleott

fbshipit-source-id: a610f1e4943619d32a523601a572fb09cdc5638d
parent 856d8b82
......@@ -101,7 +101,7 @@ class RobertaHubInterface(nn.Module):
)
def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
features = self.extract_features(tokens)
features = self.extract_features(tokens.to(device=self.device))
logits = self.model.classification_heads[head](features)
if return_logits:
return logits
......
......@@ -146,6 +146,8 @@ class RobertaModel(FairseqLanguageModel):
return RobertaHubInterface(x['args'], x['task'], x['models'][0])
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
prefix = name + '.' if name != '' else ''
current_head_names = [] if not hasattr(self, 'classification_heads') else \
self.classification_heads.keys()
......
......@@ -187,7 +187,7 @@ class Wav2VecModel(BaseFairseqModel):
return result
def upgrade_state_dict_named(self, state_dict, name):
    """Upgrade legacy checkpoints so they load with current code.

    Delegates to the parent class so any shared upgrade logic (e.g.
    renaming legacy parameter keys) also runs for this model, then
    returns the (possibly modified in place) ``state_dict`` so existing
    callers that use the return value keep working.
    """
    # Bug fix: the early ``return state_dict`` previously sat *before*
    # the super() call, making the inherited upgrade logic unreachable.
    super().upgrade_state_dict_named(state_dict, name)
    return state_dict
def max_positions(self):
"""Maximum length supported by the model."""
......
......@@ -63,11 +63,6 @@ class MultiheadAttention(nn.Module):
else:
self.enable_torch_version = False
@property
def in_proj_weight(self):
    """Legacy accessor: q/k/v projection weights stacked into one tensor.

    TODO: Remove this backward compatibility code (in_proj_weight).
    """
    parts = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]
    return torch.cat(parts)
@property
def in_proj_bias(self):
# TODO: Remove this backward compatibility code (in_proj_bias)
......@@ -312,8 +307,6 @@ class MultiheadAttention(nn.Module):
return attn_weights
def upgrade_state_dict_named(self, state_dict, name):
# TODO: Remove this backward compatibility code (in_proj_weight)
# here, we convert in_proj_weight to individual q,k,v weights
prefix = name + '.' if name != '' else ''
items_to_add = {}
keys_to_remove = []
......@@ -341,5 +334,3 @@ class MultiheadAttention(nn.Module):
for key, value in items_to_add.items():
state_dict[key] = value
return state_dict
......@@ -40,7 +40,9 @@ def init_bert_params(module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
module.q_proj.weight.data.normal_(mean=0.0, std=0.02)
module.k_proj.weight.data.normal_(mean=0.0, std=0.02)
module.v_proj.weight.data.normal_(mean=0.0, std=0.02)
class TransformerSentenceEncoder(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment