Unverified Commit 0ec63afe authored by Sam Shleifer, committed by GitHub

fix bug in pegasus converter (#7094)

parent b76cb1c3
@@ -47,8 +47,8 @@ PATTERNS = [
 def rename_state_dict_key(k):
-    for pegasus_name, bart_name in PATTERNS:
-        k = k.replace(pegasus_name, bart_name)
+    for pegasus_name, hf_name in PATTERNS:
+        k = k.replace(pegasus_name, hf_name)
     return k
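For context: rename_state_dict_key maps TF Pegasus parameter names onto the Hugging Face naming scheme by applying each (pegasus_name, hf_name) substitution in order. A minimal sketch of the idea; the PATTERNS entries below are illustrative stand-ins, not the file's actual list:

    # Illustrative PATTERNS only; the real list is defined earlier in this file.
    PATTERNS = [
        ("memory_attention", "encoder_attn"),  # specific rule must run before the generic one
        ("attention", "self_attn"),
    ]

    def rename_state_dict_key(k):
        for pegasus_name, hf_name in PATTERNS:
            k = k.replace(pegasus_name, hf_name)
        return k

    assert rename_state_dict_key("decoder.memory_attention.out") == "decoder.encoder_attn.out"

Ordering matters here: if the generic ("attention", "self_attn") rule ran first, "memory_attention" would already have been rewritten to "memory_self_attn" and the cross-attention rule would never match.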
@@ -57,13 +57,12 @@ def rename_state_dict_key(k):
 # TODO(SS): one constant
-def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
+def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
     cfg_kwargs = DEFAULTS.copy()
     cfg_kwargs.update(cfg_updates)
-    cfg = PegasusConfig(**cfg_updates)
-    bart = PegasusForConditionalGeneration(cfg)
-    sd = bart.model.state_dict()
+    cfg = PegasusConfig(**cfg_kwargs)
+    torch_model = PegasusForConditionalGeneration(cfg)
+    sd = torch_model.model.state_dict()
     mapping = {}
     for k, v in tf_weights.items():
         new_k = rename_state_dict_key(k)
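This hunk contains the actual bug fix: the old code built the merged cfg_kwargs dict but then constructed the config from cfg_updates alone, silently dropping every entry in DEFAULTS that the task-specific updates did not override. The bart -> torch_model rename is cosmetic (the object was always a PegasusForConditionalGeneration). A minimal sketch of the failure mode, with made-up values:

    # Values below are invented for illustration; the real DEFAULTS dict is defined in this file.
    DEFAULTS = {"vocab_size": 96103, "max_position_embeddings": 512}
    cfg_updates = {"max_position_embeddings": 1024}  # task-specific override

    cfg_kwargs = DEFAULTS.copy()
    cfg_kwargs.update(cfg_updates)

    # Before: PegasusConfig(**cfg_updates) never saw vocab_size, so it silently
    # fell back to the PegasusConfig class default.
    # After:  PegasusConfig(**cfg_kwargs) receives the defaults plus the overrides.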
@@ -80,13 +79,13 @@ def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForCo
     mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
     empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
     mapping.update(**empty_biases)
-    missing, extra = bart.model.load_state_dict(mapping, strict=False)
+    missing, extra = torch_model.model.load_state_dict(mapping, strict=False)
     unexpected_missing = [
         k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
     ]
     assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
     assert extra == [], f"no matches found for the following tf keys {extra}"
-    return bart
+    return torch_model


 def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
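Not changed here, but worth noting for the assertions above: load_state_dict(..., strict=False) returns the names of unmatched keys, and the converter whitelists the two embed_positions weights because Pegasus uses fixed sinusoidal position embeddings that are rebuilt from the config rather than loaded from the checkpoint (the final hunk pops the decoder copy before saving for the same reason). A self-contained sketch of the strict=False pattern:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    # Supply only the weight; strict=False reports "bias" as missing instead of raising.
    missing, extra = model.load_state_dict({"weight": torch.ones(4, 4)}, strict=False)
    assert missing == ["bias"] and extra == []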
@@ -115,7 +114,7 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str):
     cfg_updates = task_specific_params[f"summarization_{dataset}"]
     if dataset == "large":
         cfg_updates["task_specific_params"] = task_specific_params
-    torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
+    torch_model = convert_pegasus(tf_weights, cfg_updates)
     torch_model.save_pretrained(save_dir)
     sd = torch_model.state_dict()
     sd.pop("model.decoder.embed_positions.weight")
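A possible end-to-end invocation after this change; the checkpoint path matches the default shown in get_tf_weights_as_numpy above, and the save directory is made up:

    # Hypothetical usage; adjust the paths to your checkpoint and output location.
    convert_pegasus_ckpt_to_pytorch("./ckpt/aeslc/model.ckpt-32000", save_dir="./pegasus-aeslc")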