Unverified Commit 818c5318 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

[JAX] Fix the Failures on Partition of ActPrimitives (#848)



Remove act_enum from the del list ActLuPrimitive.partition
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 05eb6deb
...@@ -2661,7 +2661,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -2661,7 +2661,7 @@ class ActLuPrimitive(BasePrimitive):
""" """
act_lu partitioning act_lu partitioning
""" """
del result_infos, act_enum del result_infos
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
...@@ -2790,7 +2790,7 @@ class DActLuPrimitive(BasePrimitive): ...@@ -2790,7 +2790,7 @@ class DActLuPrimitive(BasePrimitive):
""" """
dact_lu partition dact_lu partition
""" """
del result_infos, act_enum del result_infos
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding out_shardings = dx_sharding
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment