Unverified commit 2dc2d79a, authored by Patrick von Platen and committed by GitHub
Browse files

correct conversion (#11394)

parent b48cf712
@@ -86,7 +86,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
             pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key

         # Correctly rename weight parameters
-        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
+        if pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
             pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
         if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
             pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment