Unverified Commit 818c5318 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

[JAX] Fix the Failures on Partition of ActPrimitives (#848)



Remove act_enum from the del list ActLuPrimitive.partition
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 05eb6deb
......@@ -2661,7 +2661,7 @@ class ActLuPrimitive(BasePrimitive):
"""
act_lu partitioning
"""
del result_infos, act_enum
del result_infos
x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
......@@ -2790,7 +2790,7 @@ class DActLuPrimitive(BasePrimitive):
"""
dact_lu partition
"""
del result_infos, act_enum
del result_infos
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment