Unverified Commit 0ec63afe authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

fix bug in pegasus converter (#7094)

parent b76cb1c3
...@@ -47,8 +47,8 @@ PATTERNS = [ ...@@ -47,8 +47,8 @@ PATTERNS = [
def rename_state_dict_key(k):
    """Translate a TF Pegasus checkpoint key into its Hugging Face name.

    Applies every (pegasus_name, hf_name) substitution from the module-level
    PATTERNS table, in order, and returns the rewritten key.
    """
    renamed = k
    for tf_fragment, hf_fragment in PATTERNS:
        renamed = renamed.replace(tf_fragment, hf_fragment)
    return renamed
...@@ -57,13 +57,12 @@ def rename_state_dict_key(k): ...@@ -57,13 +57,12 @@ def rename_state_dict_key(k):
# TODO(SS): one constant # TODO(SS): one constant
def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration: def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
cfg_kwargs = DEFAULTS.copy() cfg_kwargs = DEFAULTS.copy()
cfg_kwargs.update(cfg_updates) cfg_kwargs.update(cfg_updates)
cfg = PegasusConfig(**cfg_kwargs)
cfg = PegasusConfig(**cfg_updates) torch_model = PegasusForConditionalGeneration(cfg)
bart = PegasusForConditionalGeneration(cfg) sd = torch_model.model.state_dict()
sd = bart.model.state_dict()
mapping = {} mapping = {}
for k, v in tf_weights.items(): for k, v in tf_weights.items():
new_k = rename_state_dict_key(k) new_k = rename_state_dict_key(k)
...@@ -80,13 +79,13 @@ def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForCo ...@@ -80,13 +79,13 @@ def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForCo
mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"] mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping} empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
mapping.update(**empty_biases) mapping.update(**empty_biases)
missing, extra = bart.model.load_state_dict(mapping, strict=False) missing, extra = torch_model.model.load_state_dict(mapping, strict=False)
unexpected_missing = [ unexpected_missing = [
k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"] k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
] ]
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}" assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
assert extra == [], f"no matches found for the following tf keys {extra}" assert extra == [], f"no matches found for the following tf keys {extra}"
return bart return torch_model
def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict: def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
...@@ -115,7 +114,7 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str): ...@@ -115,7 +114,7 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str):
cfg_updates = task_specific_params[f"summarization_{dataset}"] cfg_updates = task_specific_params[f"summarization_{dataset}"]
if dataset == "large": if dataset == "large":
cfg_updates["task_specific_params"] = task_specific_params cfg_updates["task_specific_params"] = task_specific_params
torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates) torch_model = convert_pegasus(tf_weights, cfg_updates)
torch_model.save_pretrained(save_dir) torch_model.save_pretrained(save_dir)
sd = torch_model.state_dict() sd = torch_model.state_dict()
sd.pop("model.decoder.embed_positions.weight") sd.pop("model.decoder.embed_positions.weight")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment