Unverified commit 391f2645, authored by Yih-Dar and committed by GitHub
Browse files

Use main in conversion script (#25973)



* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 6f125aaa
...@@ -9,52 +9,53 @@ library. After conversion, performance (especially for generation) should improv ...@@ -9,52 +9,53 @@ library. After conversion, performance (especially for generation) should improv
without needing trust_remote_code=True. without needing trust_remote_code=True.
""" """
parser = ArgumentParser() if __name__ == "__main__":
parser.add_argument( parser = ArgumentParser()
parser.add_argument(
"--checkpoint_dir", "--checkpoint_dir",
type=Path, type=Path,
required=True, required=True,
help="Directory containing a custom code checkpoint to convert to a modern Falcon checkpoint.", help="Directory containing a custom code checkpoint to convert to a modern Falcon checkpoint.",
) )
args = parser.parse_args() args = parser.parse_args()
if not args.checkpoint_dir.is_dir(): if not args.checkpoint_dir.is_dir():
raise ValueError("--checkpoint_dir argument should be a directory!") raise ValueError("--checkpoint_dir argument should be a directory!")
if ( if (
not (args.checkpoint_dir / "configuration_RW.py").is_file() not (args.checkpoint_dir / "configuration_RW.py").is_file()
or not (args.checkpoint_dir / "modelling_RW.py").is_file() or not (args.checkpoint_dir / "modelling_RW.py").is_file()
): ):
raise ValueError( raise ValueError(
"The model directory should contain configuration_RW.py and modelling_RW.py files! Are you sure this is a custom code checkpoint?" "The model directory should contain configuration_RW.py and modelling_RW.py files! Are you sure this is a custom code checkpoint?"
) )
(args.checkpoint_dir / "configuration_RW.py").unlink() (args.checkpoint_dir / "configuration_RW.py").unlink()
(args.checkpoint_dir / "modelling_RW.py").unlink() (args.checkpoint_dir / "modelling_RW.py").unlink()
config = args.checkpoint_dir / "config.json" config = args.checkpoint_dir / "config.json"
text = config.read_text() text = config.read_text()
text = text.replace("RWForCausalLM", "FalconForCausalLM") text = text.replace("RWForCausalLM", "FalconForCausalLM")
text = text.replace("RefinedWebModel", "falcon") text = text.replace("RefinedWebModel", "falcon")
text = text.replace("RefinedWeb", "falcon") text = text.replace("RefinedWeb", "falcon")
json_config = json.loads(text) json_config = json.loads(text)
del json_config["auto_map"] del json_config["auto_map"]
if "n_head" in json_config: if "n_head" in json_config:
json_config["num_attention_heads"] = json_config.pop("n_head") json_config["num_attention_heads"] = json_config.pop("n_head")
if "n_layer" in json_config: if "n_layer" in json_config:
json_config["num_hidden_layers"] = json_config.pop("n_layer") json_config["num_hidden_layers"] = json_config.pop("n_layer")
if "n_head_kv" in json_config: if "n_head_kv" in json_config:
json_config["num_kv_heads"] = json_config.pop("n_head_kv") json_config["num_kv_heads"] = json_config.pop("n_head_kv")
json_config["new_decoder_architecture"] = True json_config["new_decoder_architecture"] = True
else: else:
json_config["new_decoder_architecture"] = False json_config["new_decoder_architecture"] = False
bos_token_id = json_config.get("bos_token_id", 1) bos_token_id = json_config.get("bos_token_id", 1)
eos_token_id = json_config.get("eos_token_id", 2) eos_token_id = json_config.get("eos_token_id", 2)
config.unlink() config.unlink()
config.write_text(json.dumps(json_config, indent=2, sort_keys=True)) config.write_text(json.dumps(json_config, indent=2, sort_keys=True))
tokenizer_config = args.checkpoint_dir / "tokenizer_config.json" tokenizer_config = args.checkpoint_dir / "tokenizer_config.json"
if tokenizer_config.is_file(): if tokenizer_config.is_file():
text = tokenizer_config.read_text() text = tokenizer_config.read_text()
json_config = json.loads(text) json_config = json.loads(text)
if json_config["tokenizer_class"] == "PreTrainedTokenizerFast": if json_config["tokenizer_class"] == "PreTrainedTokenizerFast":
...@@ -62,12 +63,12 @@ if tokenizer_config.is_file(): ...@@ -62,12 +63,12 @@ if tokenizer_config.is_file():
tokenizer_config.unlink() tokenizer_config.unlink()
tokenizer_config.write_text(json.dumps(json_config, indent=2, sort_keys=True)) tokenizer_config.write_text(json.dumps(json_config, indent=2, sort_keys=True))
generation_config_path = args.checkpoint_dir / "generation_config.json" generation_config_path = args.checkpoint_dir / "generation_config.json"
generation_dict = { generation_dict = {
"_from_model_config": True, "_from_model_config": True,
"bos_token_id": bos_token_id, "bos_token_id": bos_token_id,
"eos_token_id": eos_token_id, "eos_token_id": eos_token_id,
"transformers_version": "4.33.0.dev0", "transformers_version": "4.33.0.dev0",
} }
generation_config_path.write_text(json.dumps(generation_dict, indent=2, sort_keys=True)) generation_config_path.write_text(json.dumps(generation_dict, indent=2, sort_keys=True))
print("Done! Please double-check that the new checkpoint works as expected.") print("Done! Please double-check that the new checkpoint works as expected.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment