"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "4efbac6d3593ed35fd5b6ccb3958bd96b2c9b4da"
Unverified commit 87e4d6c3, authored by Phuong Nguyen and committed by GitHub
Browse files

[JAX] Fixes for the issue with ActLuPrimitive in PAXML (#837)



* fixes for ActLuPrimitive in PAXML

* changed indices for arg_infos in sharding func in dbias_cast_transpose primitive

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 2bdeb6f5
...@@ -2655,16 +2655,20 @@ class ActLuPrimitive(BasePrimitive): ...@@ -2655,16 +2655,20 @@ class ActLuPrimitive(BasePrimitive):
""" """
act_lu partitioning act_lu partitioning
""" """
del result_infos, act_enum del result_infos
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
impl = ActLuPrimitive.impl
return mesh, impl, out_sharding, arg_shardings def sharded_impl(x):
return ActLuPrimitive.impl(x, act_enum=act_enum)
return mesh, sharded_impl, out_sharding, arg_shardings
register_primitive(ActLuPrimitive) register_primitive(ActLuPrimitive)
def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray: def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
""" """
act_lu wrapper act_lu wrapper
...@@ -2779,12 +2783,15 @@ class DActLuPrimitive(BasePrimitive): ...@@ -2779,12 +2783,15 @@ class DActLuPrimitive(BasePrimitive):
""" """
dact_lu partition dact_lu partition
""" """
del result_infos, act_enum del result_infos
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding out_shardings = dx_sharding
impl = DActLuPrimitive.impl
return mesh, impl, out_shardings, arg_shardings def sharded_impl(dz, x):
return DActLuPrimitive.impl(dz, x, act_enum=act_enum)
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DActLuPrimitive) register_primitive(DActLuPrimitive)
...@@ -4304,20 +4311,20 @@ class DBiasCastTransposePrimitive(BasePrimitive): ...@@ -4304,20 +4311,20 @@ class DBiasCastTransposePrimitive(BasePrimitive):
def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
arg_infos, result_infos): arg_infos, result_infos):
del out_dtype, result_infos del out_dtype, result_infos
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_shaprding = NamedSharding( dbias_shaprding = NamedSharding(
mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding)
@staticmethod @staticmethod
def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
result_infos): result_infos):
del result_infos del result_infos
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[0])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
...@@ -4325,7 +4332,7 @@ class DBiasCastTransposePrimitive(BasePrimitive): ...@@ -4325,7 +4332,7 @@ class DBiasCastTransposePrimitive(BasePrimitive):
dbias_shaprding = NamedSharding( dbias_shaprding = NamedSharding(
mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding, out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding,
amax_sharding) amax_sharding)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment