Unverified Commit 343057e1 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

Fix bart conversion script (#9923)

* fix conversion script

* typo

* import nn
parent 0e3be1ac
......@@ -22,6 +22,7 @@ from pathlib import Path
import fairseq
import torch
from packaging import version
from torch import nn
from transformers import (
BartConfig,
......@@ -30,7 +31,6 @@ from transformers import (
BartModel,
BartTokenizer,
)
from transformers.models.bart.modeling_bart import _make_linear_from_emb
from transformers.utils import logging
......@@ -78,6 +78,13 @@ def load_xsum_checkpoint(checkpoint_path):
return hub_interface
def make_linear_from_emb(emb):
    """
    Build a bias-free linear layer that reuses an embedding's weight tensor.

    Args:
        emb: A module whose ``weight`` is a 2-D tensor of shape
            ``(vocab_size, emb_size)``.

    Returns:
        An ``nn.Linear`` without bias, constructed with ``in_features=emb_size``
        and ``out_features=vocab_size``, whose ``weight.data`` is set to
        ``emb.weight.data``.
    """
    vocab_size, emb_size = emb.weight.shape
    # nn.Linear stores its weight as (out_features, in_features), which is the
    # embedding matrix's (vocab_size, emb_size) layout. Passing (emb_size,
    # vocab_size) keeps in_features/out_features consistent with the weight
    # that is swapped in on the next line.
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer
@torch.no_grad()
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
"""
......@@ -119,7 +126,7 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
model.model.load_state_dict(state_dict)
if hasattr(model, "lm_head"):
model.lm_head = _make_linear_from_emb(model.model.shared)
model.lm_head = make_linear_from_emb(model.model.shared)
new_model_outputs = model.model(tokens)[0]
# Check results
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment