"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c9d2e855ea20cfa1767ed26ca58db430a3bc170a"
Unverified Commit 391f2645 authored by Yih-Dar, committed by GitHub

Use main in conversion script (#25973)



* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 6f125aaa
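
The change wraps the body of the Falcon custom-code conversion script in an `if __name__ == "__main__":` guard, so that importing the module (for example from tests or tooling) no longer parses command-line arguments or deletes checkpoint files as an import side effect. A minimal sketch of the pattern, separate from the commit's own code:

from argparse import ArgumentParser

if __name__ == "__main__":
    # Runs only when executed directly (`python script.py --checkpoint_dir ...`);
    # `import`ing the module skips this block, so there is no argument parsing
    # and no file I/O at import time.
    parser = ArgumentParser()
    parser.add_argument("--checkpoint_dir", required=True)
    args = parser.parse_args()
    print(f"Would convert checkpoint at {args.checkpoint_dir}")

The full diff: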
@@ -9,65 +9,66 @@ library. After conversion, performance (especially for generation) should improv
 without needing trust_remote_code=True.
 """
-parser = ArgumentParser()
-parser.add_argument(
-    "--checkpoint_dir",
-    type=Path,
-    required=True,
-    help="Directory containing a custom code checkpoint to convert to a modern Falcon checkpoint.",
-)
-args = parser.parse_args()
-
-if not args.checkpoint_dir.is_dir():
-    raise ValueError("--checkpoint_dir argument should be a directory!")
-
-if (
-    not (args.checkpoint_dir / "configuration_RW.py").is_file()
-    or not (args.checkpoint_dir / "modelling_RW.py").is_file()
-):
-    raise ValueError(
-        "The model directory should contain configuration_RW.py and modelling_RW.py files! Are you sure this is a custom code checkpoint?"
-    )
-(args.checkpoint_dir / "configuration_RW.py").unlink()
-(args.checkpoint_dir / "modelling_RW.py").unlink()
-
-config = args.checkpoint_dir / "config.json"
-text = config.read_text()
-text = text.replace("RWForCausalLM", "FalconForCausalLM")
-text = text.replace("RefinedWebModel", "falcon")
-text = text.replace("RefinedWeb", "falcon")
-json_config = json.loads(text)
-del json_config["auto_map"]
-
-if "n_head" in json_config:
-    json_config["num_attention_heads"] = json_config.pop("n_head")
-if "n_layer" in json_config:
-    json_config["num_hidden_layers"] = json_config.pop("n_layer")
-if "n_head_kv" in json_config:
-    json_config["num_kv_heads"] = json_config.pop("n_head_kv")
-    json_config["new_decoder_architecture"] = True
-else:
-    json_config["new_decoder_architecture"] = False
-bos_token_id = json_config.get("bos_token_id", 1)
-eos_token_id = json_config.get("eos_token_id", 2)
-config.unlink()
-config.write_text(json.dumps(json_config, indent=2, sort_keys=True))
-
-tokenizer_config = args.checkpoint_dir / "tokenizer_config.json"
-if tokenizer_config.is_file():
-    text = tokenizer_config.read_text()
-    json_config = json.loads(text)
-    if json_config["tokenizer_class"] == "PreTrainedTokenizerFast":
-        json_config["model_input_names"] = ["input_ids", "attention_mask"]
-        tokenizer_config.unlink()
-        tokenizer_config.write_text(json.dumps(json_config, indent=2, sort_keys=True))
-
-generation_config_path = args.checkpoint_dir / "generation_config.json"
-generation_dict = {
-    "_from_model_config": True,
-    "bos_token_id": bos_token_id,
-    "eos_token_id": eos_token_id,
-    "transformers_version": "4.33.0.dev0",
-}
-generation_config_path.write_text(json.dumps(generation_dict, indent=2, sort_keys=True))
-
-print("Done! Please double-check that the new checkpoint works as expected.")
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--checkpoint_dir",
+        type=Path,
+        required=True,
+        help="Directory containing a custom code checkpoint to convert to a modern Falcon checkpoint.",
+    )
+    args = parser.parse_args()
+
+    if not args.checkpoint_dir.is_dir():
+        raise ValueError("--checkpoint_dir argument should be a directory!")
+
+    if (
+        not (args.checkpoint_dir / "configuration_RW.py").is_file()
+        or not (args.checkpoint_dir / "modelling_RW.py").is_file()
+    ):
+        raise ValueError(
+            "The model directory should contain configuration_RW.py and modelling_RW.py files! Are you sure this is a custom code checkpoint?"
+        )
+    (args.checkpoint_dir / "configuration_RW.py").unlink()
+    (args.checkpoint_dir / "modelling_RW.py").unlink()
+
+    config = args.checkpoint_dir / "config.json"
+    text = config.read_text()
+    text = text.replace("RWForCausalLM", "FalconForCausalLM")
+    text = text.replace("RefinedWebModel", "falcon")
+    text = text.replace("RefinedWeb", "falcon")
+    json_config = json.loads(text)
+    del json_config["auto_map"]
+
+    if "n_head" in json_config:
+        json_config["num_attention_heads"] = json_config.pop("n_head")
+    if "n_layer" in json_config:
+        json_config["num_hidden_layers"] = json_config.pop("n_layer")
+    if "n_head_kv" in json_config:
+        json_config["num_kv_heads"] = json_config.pop("n_head_kv")
+        json_config["new_decoder_architecture"] = True
+    else:
+        json_config["new_decoder_architecture"] = False
+    bos_token_id = json_config.get("bos_token_id", 1)
+    eos_token_id = json_config.get("eos_token_id", 2)
+    config.unlink()
+    config.write_text(json.dumps(json_config, indent=2, sort_keys=True))
+
+    tokenizer_config = args.checkpoint_dir / "tokenizer_config.json"
+    if tokenizer_config.is_file():
+        text = tokenizer_config.read_text()
+        json_config = json.loads(text)
+        if json_config["tokenizer_class"] == "PreTrainedTokenizerFast":
+            json_config["model_input_names"] = ["input_ids", "attention_mask"]
+            tokenizer_config.unlink()
+            tokenizer_config.write_text(json.dumps(json_config, indent=2, sort_keys=True))
+
+    generation_config_path = args.checkpoint_dir / "generation_config.json"
+    generation_dict = {
+        "_from_model_config": True,
+        "bos_token_id": bos_token_id,
+        "eos_token_id": eos_token_id,
+        "transformers_version": "4.33.0.dev0",
+    }
+    generation_config_path.write_text(json.dumps(generation_dict, indent=2, sort_keys=True))
+
+    print("Done! Please double-check that the new checkpoint works as expected.")
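
A quick end-to-end sanity check after running the conversion. The script filename below is an assumption (in the transformers repo this script lives alongside the Falcon model code); the checkpoint path and prompt are placeholders:

# Assumed invocation of the script shown in the diff above:
#   python convert_custom_code_checkpoint.py --checkpoint_dir ./my-falcon-checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint_dir = "./my-falcon-checkpoint"  # placeholder path

# After conversion, loading no longer requires trust_remote_code=True:
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)

inputs = tokenizer("Hello, Falcon!", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))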