Unverified Commit d522afea authored by Younes Belkada, committed by GitHub

[`Gemma`] Supports converting directly in half-precision (#29529)

* Update convert_gemma_weights_to_hf.py

* Update src/transformers/models/gemma/convert_gemma_weights_to_hf.py

* fixup
parent d4796653
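
With this change the converter can write the checkpoint directly in half precision via a new --dtype flag (default "float32"). A hypothetical invocation follows; --dtype is taken from the diff below, while the other flag names are inferred from the args.* attributes in main() and may differ from the actual script:

    python src/transformers/models/gemma/convert_gemma_weights_to_hf.py \
        --input_checkpoint /path/to/gemma/checkpoint \
        --model_size 7B \
        --output_dir /path/to/converted_model \
        --dtype bfloat16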
@@ -65,7 +65,7 @@ CONFIG_MAPPING = {"2B": gemma_2b_config, "7B": gemma_7b_config}
LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"}
-def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False):
+def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False, dtype=torch.float32):
    num_attn_heads = config.num_attention_heads
    hidden_size = config.hidden_size
    num_kv_heads = config.num_key_value_heads
@@ -107,6 +107,8 @@ def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False):
        else:
            state_dict[k] = v
+    torch.set_default_dtype(dtype)
    print("Loading the checkpoint in a Gemma model.")
    with init_empty_weights():
        model = GemmaForCausalLM(config)
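
Why the torch.set_default_dtype(dtype) call works: PyTorch creates new floating-point tensors, including module parameters, in the current default dtype, so switching it before instantiating GemmaForCausalLM materializes the weights directly in the requested precision. A minimal standalone sketch of that behavior (assumes a recent PyTorch that accepts half-precision defaults; not part of the script):

    import torch

    torch.set_default_dtype(torch.float16)
    layer = torch.nn.Linear(4, 4)             # parameters come out as float16
    assert layer.weight.dtype == torch.float16

    torch.set_default_dtype(torch.float32)    # restore the usual default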
@@ -174,6 +176,11 @@ def main():
action="store_true",
default=False,
)
parser.add_argument(
"--dtype",
default="float32",
help="Target dtype of the converted model",
)
args = parser.parse_args()
if args.convert_tokenizer:
@@ -184,12 +191,14 @@ def main():
        write_tokenizer(spm_path, args.output_dir, args.push_to_hub)
    config = CONFIG_MAPPING[args.model_size]
+    dtype = getattr(torch, args.dtype)
    write_model(
        config=config,
        input_base_path=args.input_checkpoint,
        save_path=args.output_dir,
        safe_serialization=not args.pickle_serialization,
        push_to_hub=args.push_to_hub,
+        dtype=dtype,
    )
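
Note on dtype resolution: getattr(torch, args.dtype) simply looks the string up as an attribute of the torch module, so any dtype name torch exposes ("float32", "float16", "bfloat16", ...) resolves to the matching torch.dtype, while an unknown name raises AttributeError. A small illustration (assumed behavior, not from the script):

    import torch

    for name in ("float32", "float16", "bfloat16"):
        print(name, "->", getattr(torch, name))   # e.g. float16 -> torch.float16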