"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "f8cca8b95de7e5637bc52f1b55113b38a6ad0774"
Commit e589e28c authored by Phuong Nguyen's avatar Phuong Nguyen Committed by Kshitij Janardan Lakhani
Browse files

[JAX] Allow DP + FSDP and fixed sr_rng_state partitioning (#2418)



* allow dp + fsdp and fixed sr_rng_state partitioning
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* cleanup for lint test
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 7ab2c9c4
...@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
amax_spec = get_padded_spec(arg_infos[2]) amax_spec = get_padded_spec(arg_infos[2])
sr_rng_state_spec = get_padded_spec(arg_infos[3])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*x_spec), PartitionSpec(*x_spec),
...@@ -551,11 +552,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -551,11 +552,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
) )
arg_shardings = list(arg_i.sharding for arg_i in arg_infos) arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[3] = NamedSharding( if len(sr_rng_state_spec) > 1:
mesh, # sr_rng_state shape [n_devices, state_per_device]
PartitionSpec(tuple(x for x in x_spec if x is not None), None), sr_rng_state_spec = (*tuple(x for x in x_spec if x is not None), None)
desc="BaseDBiasQuantizePrimitive.sr_rng_state", arg_shardings[3] = NamedSharding(
) mesh,
PartitionSpec(*sr_rng_state_spec),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
...@@ -654,10 +658,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -654,10 +658,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",) dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
amax = (BATCHING + prefix + "_amax",) amax = (BATCHING + prefix + "_amax",)
scale = (BATCHING + prefix + "_scale",) scale = (BATCHING + prefix + "_scale",)
sr_rng_state = ( sr_rng_state = (BATCHING + prefix + "_sr_rng_state",)
BATCHING + prefix + "_sr_rng_state_partition_axis", if value_types[3].shape != [0]:
BATCHING + prefix + "sr_rng_state_data_axis", sr_rng_state = (
) BATCHING + prefix + "_sr_rng_state_devices",
prefix + "sr_rng_state_data",
)
post_rht_amax = (BATCHING + prefix + "_post_rht_amax",) post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2") rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2")
...@@ -849,7 +855,7 @@ def _quantize_dbias_impl( ...@@ -849,7 +855,7 @@ def _quantize_dbias_impl(
if force_1x_quantization: if force_1x_quantization:
q_layout = QuantizeLayout.ROWWISE q_layout = QuantizeLayout.ROWWISE
sr_rng_state = None sr_rng_state = jnp.empty((0,), jnp.uint32)
if quantizer.scaling_mode.is_nvfp4_scaling: if quantizer.scaling_mode.is_nvfp4_scaling:
# Only NVFP4 scaling modes support stochastic rounding # Only NVFP4 scaling modes support stochastic rounding
if quantizer.stochastic_rounding_rng_state is not None: if quantizer.stochastic_rounding_rng_state is not None:
...@@ -866,11 +872,7 @@ def _quantize_dbias_impl( ...@@ -866,11 +872,7 @@ def _quantize_dbias_impl(
x.data, x.data,
scale, scale,
amax, amax,
( sr_rng_state,
sr_rng_state
if sr_rng_state is not None
else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
),
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32), post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix, rht_matrix,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
...@@ -880,7 +882,7 @@ def _quantize_dbias_impl( ...@@ -880,7 +882,7 @@ def _quantize_dbias_impl(
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False, is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
is_outer=True, is_outer=True,
stochastic_rounding=sr_rng_state is not None, stochastic_rounding=sr_rng_state.size != 0,
use_rht=use_rht, use_rht=use_rht,
) )
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
......
...@@ -44,9 +44,6 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): ...@@ -44,9 +44,6 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
def _validate_mesh_resource_configuration(mesh_resource): def _validate_mesh_resource_configuration(mesh_resource):
"""Validate that the mesh resource configuration is consistent and conflict-free.""" """Validate that the mesh resource configuration is consistent and conflict-free."""
is_dp_enabled = (
mesh_resource.dp_resource is not None and get_mesh_axis_size(mesh_resource.dp_resource) > 1
)
is_tp_enabled = ( is_tp_enabled = (
mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1 mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1
) )
...@@ -54,16 +51,7 @@ def _validate_mesh_resource_configuration(mesh_resource): ...@@ -54,16 +51,7 @@ def _validate_mesh_resource_configuration(mesh_resource):
mesh_resource.tpsp_resource is not None mesh_resource.tpsp_resource is not None
and get_mesh_axis_size(mesh_resource.tpsp_resource) > 1 and get_mesh_axis_size(mesh_resource.tpsp_resource) > 1
) )
is_fsdp_enabled = (
mesh_resource.fsdp_resource is not None
and get_mesh_axis_size(mesh_resource.fsdp_resource) > 1
)
assert not (is_dp_enabled and is_fsdp_enabled), (
"Data parallelism and full-sharded data parallelism cannot be enabled at the same time."
f" Got dp_resource={mesh_resource.dp_resource} and"
f" fsdp_resource={mesh_resource.fsdp_resource}"
)
assert not (is_tp_enabled and is_tpsp_enabled), ( assert not (is_tp_enabled and is_tpsp_enabled), (
"Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time." "Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time."
f" Got tp_resource={mesh_resource.tp_resource} and" f" Got tp_resource={mesh_resource.tp_resource} and"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment