Unverified Commit ed6b8f31 authored by Lilian Bordeau, committed by GitHub

Update to match renamed attributes in fairseq master (#5972)



* Update to match renamed attributes in fairseq master

Fairseq's RobertaModel no longer has the model.decoder and args.num_classes attributes as of 5/28/20.

* Quality
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent d9149f00
@@ -47,7 +47,7 @@ def convert_roberta_checkpoint_to_pytorch(
     """
     roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
     roberta.eval()  # disable dropout
-    roberta_sent_encoder = roberta.model.decoder.sentence_encoder
+    roberta_sent_encoder = roberta.model.encoder.sentence_encoder
     config = RobertaConfig(
         vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
         hidden_size=roberta.args.encoder_embed_dim,
@@ -59,7 +59,7 @@ def convert_roberta_checkpoint_to_pytorch(
         layer_norm_eps=1e-5,  # PyTorch default used in fairseq
     )
     if classification_head:
-        config.num_labels = roberta.args.num_classes
+        config.num_labels = roberta.model.classification_heads["mnli"].out_proj.weight.shape[0]
     print("Our BERT config:", config)
     model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config)
@@ -126,12 +126,12 @@ def convert_roberta_checkpoint_to_pytorch(
         model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias
     else:
         # LM Head
-        model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
-        model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
-        model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
-        model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
-        model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
-        model.lm_head.decoder.bias = roberta.model.decoder.lm_head.bias
+        model.lm_head.dense.weight = roberta.model.encoder.lm_head.dense.weight
+        model.lm_head.dense.bias = roberta.model.encoder.lm_head.dense.bias
+        model.lm_head.layer_norm.weight = roberta.model.encoder.lm_head.layer_norm.weight
+        model.lm_head.layer_norm.bias = roberta.model.encoder.lm_head.layer_norm.bias
+        model.lm_head.decoder.weight = roberta.model.encoder.lm_head.weight
+        model.lm_head.decoder.bias = roberta.model.encoder.lm_head.bias
     # Let's check that we get the same results.
     input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1
...
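For context, here is a minimal sketch (illustrative, not part of the commit) of the attribute access the patched script relies on, assuming a fairseq build from after the 5/28/20 rename and a fine-tuned MNLI checkpoint; the checkpoint path below is hypothetical.

# Sketch of the renamed fairseq attributes used by the updated conversion script.
# The checkpoint path is hypothetical and only works with a post-rename fairseq.
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel

roberta = FairseqRobertaModel.from_pretrained("/path/to/roberta.large.mnli")  # hypothetical path
roberta.eval()  # disable dropout, as in the conversion script

# The transformer stack now lives under model.encoder (formerly model.decoder).
roberta_sent_encoder = roberta.model.encoder.sentence_encoder
print("vocab size:", roberta_sent_encoder.embed_tokens.num_embeddings)

# args.num_classes is gone, so the label count is read off the output projection
# of the fine-tuned classification head instead (3 labels for MNLI).
num_labels = roberta.model.classification_heads["mnli"].out_proj.weight.shape[0]
print("num_labels:", num_labels)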