Unverified Commit e17fab14 authored by jberchtold-nvidia, committed by GitHub

[JAX] Fix partitioning issues in LayerNorm and LayerNormMLP layers (#1743)



* Enforce that the input sharding of the norm primitive does not shard the hidden dim
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix a partitioning issue in the dact primitive that caused NaNs, and add better shape checks before calling the TE API
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Move the dact shape assertion from C++ to Python
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent a9656283
@@ -501,6 +501,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
         ir_hidden_size = dz_aval.shape[-1]
         gi_hidden_size = act_len * x_aval.shape[-1]
         assert act_len * ir_hidden_size == gi_hidden_size
+        assert (
+            x_aval.shape[:-2] == dz_aval.shape[:-1]
+        ), "dz and x should have the same leading dimensions"
         out_shape = x_aval.shape
         out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
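For context, the new assertion encodes the expected relationship between the activation input x and the incoming gradient dz: x carries an extra activation axis at position -2 (length 1 for a plain activation, 2 for a gated one), so its leading dimensions must match dz's. A minimal sketch with hypothetical shapes (batch/seq/hidden are illustrative, not taken from the TE sources):

    # Hypothetical shapes illustrating the dz/x check added above.
    import jax.numpy as jnp

    batch, seq, hidden = 4, 128, 1024
    act_len = 2  # gated activation: gate and value stacked on axis -2

    x = jnp.zeros((batch, seq, act_len, hidden))  # activation input
    dz = jnp.zeros((batch, seq, hidden))          # incoming gradient

    # The assertion added in this commit: leading dims of x and dz must agree.
    assert x.shape[:-2] == dz.shape[:-1]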
@@ -821,8 +824,12 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
             mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
         )
-        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
+        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
+        # Ensure dz and x are partitioned the same way.
+        arg_shardings[0] = NamedSharding(
+            mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]), desc="DActLuDBiasQuantizePrimitive.dz"
+        )
+        arg_shardings = tuple(arg_shardings)
         out_shardings = (
             out_sharding,
             colwise_out_sharding,
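To see what the new arg_shardings[0] constraint does, here is a rough sketch (not the TE code itself): dz's sharding is rebuilt from x's spec by dropping the activation axis and keeping the hidden axis, so dz and x end up partitioned the same way. The mesh axes and specs below are hypothetical, and the desc= argument in the diff is a TE-side detail that plain JAX NamedSharding does not take:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Hypothetical 2-axis mesh; any device layout works for the illustration.
    devices = np.array(jax.devices()).reshape(-1, 1)
    mesh = Mesh(devices, ("dp", "tp"))

    # Suppose x (batch, seq, act, hidden) is sharded only on the batch axis.
    x_spec = ("dp", None, None, None)

    # dz has no activation axis, so its spec reuses x's leading dims plus
    # x's last (hidden) entry - the same construction as in the diff above.
    dz_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
    print(dz_sharding.spec)  # PartitionSpec('dp', None, None)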
@@ -502,7 +502,16 @@ class NormFwdPrimitive(BasePrimitive):
         )
         amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax")
-        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
+        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
+        # Enforce no sharding of hidden dim for x, gamma and beta
+        arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.x")
+        arg_shardings[2] = NamedSharding(
+            mesh, PartitionSpec(*g_spec[:-1], None), desc="NormFwdPrimitive.gamma"
+        )
+        arg_shardings[3] = NamedSharding(
+            mesh, PartitionSpec(*b_spec[:-1], None), desc="NormFwdPrimitive.beta"
+        )
+        arg_shardings = tuple(arg_shardings)
         out_shardings = (
             out_sharding,
             colwise_out_sharding,
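The NormFwdPrimitive change follows the same pattern: whatever sharding the caller passes in, the last (hidden) dimension of x, gamma and beta is forced to be replicated, since the norm reduces over that axis. A rough illustration with made-up specs (g_spec and b_spec stand in for the specs derived in the partition rule; x in the actual diff reuses out_spec, whose hidden entry is already cleared):

    from jax.sharding import PartitionSpec

    # Made-up incoming specs: hidden axis (incorrectly) sharded on "tp".
    x_spec = ("dp", None, "tp")   # x: (batch, seq, hidden)
    g_spec = ("tp",)              # gamma: (hidden,)
    b_spec = ("tp",)              # beta:  (hidden,)

    # Clearing the last entry replicates the hidden dim, mirroring
    # arg_shardings[2] / arg_shardings[3] in the diff above.
    gamma_spec = PartitionSpec(*g_spec[:-1], None)  # -> PartitionSpec(None)
    beta_spec = PartitionSpec(*b_spec[:-1], None)   # -> PartitionSpec(None)
    x_forced = PartitionSpec(*x_spec[:-1], None)    # -> PartitionSpec('dp', None, None)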
@@ -245,8 +245,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
   // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
   // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
   auto act_len = act_input_dims[act_input_dims.size() - 2];
-  NVTE_CHECK(act_input_dims.back() == input_dims.back(),
-             "Shape mismatch between activation input and gradient input");
+  NVTE_CHECK(act_len == 1 || act_len == 2,
+             "The value of the activation dimension (axis=-2) must be 1 for non-gated or 2 for "
+             "gated activation, got ",
+             act_len);
   auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
   auto n = input_dims.back();
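The C++ side now validates only act_len (axis -2 of the activation input), since the dz/x shape agreement is asserted in Python. The m/n bookkeeping described in the comments above can be reproduced as follows, again with hypothetical shapes:

    from functools import reduce
    import operator

    x_shape = (4, 128, 2, 1024)  # act input: (..., act_len, hidden)
    dz_shape = (4, 128, 1024)    # gradient:  (..., hidden)

    act_len = x_shape[-2]                    # must be 1 (non-gated) or 2 (gated)
    m = reduce(operator.mul, x_shape[:-2])   # flattened batch size, 4 * 128
    n = dz_shape[-1]                         # hidden size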
@@ -257,8 +260,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
   auto dbias_shape = std::vector<size_t>{n * act_len};
   std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
-  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
-  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
+  auto input_tensor =
+      TensorWrapper(input, input_shape, convert_ffi_datatype_to_te_dtype(input_buf.element_type()));
+  auto act_input_tensor = TensorWrapper(
+      act_input, act_input_shape, convert_ffi_datatype_to_te_dtype(act_input_buf.element_type()));
   auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
   output_tensor.set_rowwise_data(output, out_dtype, output_shape);