Unverified Commit 854a0d52 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

Handle PyTorch to Flax conversion of 1D convolutions (#15519)

parent 486260c6
......@@ -88,6 +88,12 @@ def rename_key_and_reshape_tensor(
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
return renamed_pt_tuple_key, pt_tensor
# conv1d layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 3 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
pt_tensor = pt_tensor.transpose(2, 1, 0)
return renamed_pt_tuple_key, pt_tensor
# linear layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment