Unverified Commit 66ae3030 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Allow DP + FSDP and fixed sr_rng_state partitioning (#2418)



* allow dp + fsdp and fixed sr_rng_state partitioning
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* cleanup for lint test
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent f612b749
......@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x_spec = get_padded_spec(arg_infos[0])
amax_spec = get_padded_spec(arg_infos[2])
sr_rng_state_spec = get_padded_spec(arg_infos[3])
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec),
......@@ -551,9 +552,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
if len(sr_rng_state_spec) > 1:
# sr_rng_state shape [n_devices, state_per_device]
sr_rng_state_spec = (*tuple(x for x in x_spec if x is not None), None)
arg_shardings[3] = NamedSharding(
mesh,
PartitionSpec(tuple(x for x in x_spec if x is not None), None),
PartitionSpec(*sr_rng_state_spec),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
arg_shardings = tuple(arg_shardings)
......@@ -654,9 +658,11 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
amax = (BATCHING + prefix + "_amax",)
scale = (BATCHING + prefix + "_scale",)
sr_rng_state = (BATCHING + prefix + "_sr_rng_state",)
if value_types[3].shape != [0]:
sr_rng_state = (
BATCHING + prefix + "_sr_rng_state_partition_axis",
BATCHING + prefix + "sr_rng_state_data_axis",
BATCHING + prefix + "_sr_rng_state_devices",
prefix + "sr_rng_state_data",
)
post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
......@@ -849,7 +855,7 @@ def _quantize_dbias_impl(
if force_1x_quantization:
q_layout = QuantizeLayout.ROWWISE
sr_rng_state = None
sr_rng_state = jnp.empty((0,), jnp.uint32)
if quantizer.scaling_mode.is_nvfp4_scaling:
# Only NVFP4 scaling modes support stochastic rounding
if quantizer.stochastic_rounding_rng_state is not None:
......@@ -866,11 +872,7 @@ def _quantize_dbias_impl(
x.data,
scale,
amax,
(
sr_rng_state
if sr_rng_state is not None
else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
),
sr_rng_state,
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix,
out_dtype=quantizer.q_dtype,
......@@ -880,7 +882,7 @@ def _quantize_dbias_impl(
scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
is_outer=True,
stochastic_rounding=sr_rng_state is not None,
stochastic_rounding=sr_rng_state.size != 0,
use_rht=use_rht,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
......
......@@ -44,9 +44,6 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
def _validate_mesh_resource_configuration(mesh_resource):
"""Validate that the mesh resource configuration is consistent and conflict-free."""
is_dp_enabled = (
mesh_resource.dp_resource is not None and get_mesh_axis_size(mesh_resource.dp_resource) > 1
)
is_tp_enabled = (
mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1
)
......@@ -54,16 +51,7 @@ def _validate_mesh_resource_configuration(mesh_resource):
mesh_resource.tpsp_resource is not None
and get_mesh_axis_size(mesh_resource.tpsp_resource) > 1
)
is_fsdp_enabled = (
mesh_resource.fsdp_resource is not None
and get_mesh_axis_size(mesh_resource.fsdp_resource) > 1
)
assert not (is_dp_enabled and is_fsdp_enabled), (
"Data parallelism and full-sharded data parallelism cannot be enabled at the same time."
f" Got dp_resource={mesh_resource.dp_resource} and"
f" fsdp_resource={mesh_resource.fsdp_resource}"
)
assert not (is_tp_enabled and is_tpsp_enabled), (
"Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time."
f" Got tp_resource={mesh_resource.tp_resource} and"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment