Unverified Commit d770886f authored by Phuong Nguyen, committed by GitHub

[JAX] Add `tpsp_resource` in the `MeshResource` map (#2113)



* clean up sharding
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added tpsp_resource
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* update tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* rework test for MeshResource
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* add mesh_resource into fp8_autocast in test_helper.py
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent d972e76d
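For context, a minimal sketch of the user-facing change this PR makes: `MeshResource` is now constructed with keyword arguments, and sequence-sharded tensor parallelism gets its own `tpsp_resource` field instead of overloading `tp_resource`. The axis names (`"data"`, `"model"`) and the 2x4 device reshape below are illustrative assumptions, not taken from this diff:

```python
import jax
import numpy as np
import transformer_engine.jax as te

# Hypothetical 2x4 mesh: "data" for batch sharding, "model" for tensor(+sequence) sharding.
devices = np.asarray(jax.devices()[:8]).reshape(2, 4)
mesh = jax.sharding.Mesh(devices, ("data", "model"))

with mesh, te.fp8_autocast(
    enabled=False,  # FP8 disabled; only the mesh-resource plumbing is exercised here
    mesh_resource=te.MeshResource(
        dp_resource="data",     # data parallelism
        tpsp_resource="model",  # tensor *sequence* parallelism (the new field)
    ),
):
    pass  # build and run the model under this sharding configuration
```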
@@ -267,7 +267,10 @@ def train_and_evaluate(args):
     ) as mesh, te.fp8_autocast(
         enabled=args.use_fp8,
         fp8_recipe=fp8_recipe,
-        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
+        mesh_resource=te.MeshResource(
+            dp_resource=DEVICE_DP_AXIS,
+            tpsp_resource=DEVICE_TP_AXIS,
+        ),
     ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
......
@@ -264,7 +264,7 @@ def train_and_evaluate(args):
     ) as mesh, te.fp8_autocast(
         enabled=args.use_fp8,
         fp8_recipe=fp8_recipe,
-        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
+        mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS),
     ):
         rng = jax.random.PRNGKey(args.seed)
......
@@ -382,7 +382,10 @@ def train_and_evaluate(args):
     ) as mesh, te.fp8_autocast(
         enabled=args.use_fp8,
         fp8_recipe=fp8_recipe,
-        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
+        mesh_resource=te.MeshResource(
+            dp_resource=DEVICE_DP_AXIS,
+            tpsp_resource=DEVICE_TP_AXIS,
+        ),
     ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
......
@@ -22,7 +22,7 @@ def generate_configs():
         pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
     )
     configs.append(
-        pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2")
+        pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2")
     )
     if is_devices_enough(4):
@@ -30,8 +30,8 @@ def generate_configs():
             pytest.param(
                 4,
                 (2, 2),
-                ("dp", "tp"),
-                MeshResource(dp_resource="dp", tp_resource="tp"),
+                ("dp", "tpsp"),
+                MeshResource(dp_resource="dp", tpsp_resource="tpsp"),
                 id=f"n4_dp2_tp2",
             )
         )
@@ -43,8 +43,8 @@ def generate_context_parallel_configs_for_attn():
     """Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only"""
     configsL1 = []
     configsL2 = []
-    mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
-    axes = ("dp", "cp", "tp")
+    mr = MeshResource(dp_resource="dp", cp_resource="cp", tpsp_resource="tpsp")
+    axes = ("dp", "cp", "tpsp")
     DP_sizes = (1, 2)
     CP_sizes = (1, 2, 4, 8)
     TP_sizes = (1, 2)
......
@@ -45,8 +45,8 @@ class TestDistributedSelfAttn:
         _, seqlen, heads, _ = shape
         is_dp_enabled = mesh_resource.dp_resource is not None
         tp_size = 1
-        if mesh_resource.tp_resource is not None:
-            idx = mesh_axes.index(mesh_resource.tp_resource)
+        if mesh_resource.tpsp_resource is not None:
+            idx = mesh_axes.index(mesh_resource.tpsp_resource)
             tp_size = mesh_shape[idx]
         all_reduce_loss_bytes = 4  # 1 * FP32
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import unittest

import jax
import numpy as np
from utils import pytest_parametrize_wrapper, is_devices_enough

from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax import fp8_autocast


def generate_mesh_configs():
    configs = []
    if is_devices_enough(2):
        configs.append(
            [2, (1, 2), ("dp", "tpsp"), MeshResource(dp_resource="dp", tpsp_resource="tpsp")]
        )
    if is_devices_enough(4):
        configs.append(
            [4, (2, 2), ("fsdp", "tp"), MeshResource(tp_resource="tp", fsdp_resource="fsdp")]
        )
    return configs


class TestMeshResource(unittest.TestCase):
    def test_fp8_autocast_with_mesh_resource(self):
        for mesh_config in generate_mesh_configs():
            device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = jax.sharding.Mesh(devices, mesh_axes)
            with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
                self.assertEqual(mesh_resource, global_mesh_resource())
@@ -62,16 +62,16 @@ BIAS_2_AXES = (W_NO_SHARD_AXES,)
 INTERMEDIATE = 64


-# Only test with FSDP and TP as DP is not used
-def generate_fsdp_and_tp_configs():
+# Only test with FSDP and TPSP as DP is not used
+def generate_fsdp_and_tpsp_configs():
     configs = []
     if is_devices_enough(2):
         configs.append(
-            [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
+            [2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")]
         )
     if is_devices_enough(4):
         configs.append(
-            [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
+            [4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")]
         )
     return configs
@@ -186,12 +186,12 @@ class TestDistributedLayernormMLP:
         with mesh, fp8_autocast(
             enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
         ):
-            k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
-            k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
+            k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
+            k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))
             k1_ = jax.device_put(k1, k1_sharding)
             k2_ = jax.device_put(k2, k2_sharding)
             if use_bias:
-                b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
+                b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tpsp"))
                 b1_ = jax.device_put(b1, b1_sharding)
             else:
                 b1_sharding = b1_ = None
@@ -247,7 +247,7 @@ class TestDistributedLayernormMLP:
     )
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest_parametrize_wrapper("dtype", DTYPES)
@@ -276,7 +276,7 @@ class TestDistributedLayernormMLP:
     )
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest_parametrize_wrapper("dtype", DTYPES)
@@ -408,7 +408,7 @@ class TestDistributedLayernormMLP:
         assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)

     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
-    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("use_bias", [True, False])
@@ -429,7 +429,7 @@ class TestDistributedLayernormMLP:
     )
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest_parametrize_wrapper("use_bias", [True, False])
     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@@ -452,7 +452,7 @@ class TestDistributedLayernormMLP:
     )
     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
-    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("use_bias", [True, False])
@@ -473,7 +473,7 @@ class TestDistributedLayernormMLP:
     )
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest_parametrize_wrapper("use_bias", [True, False])
     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
......
@@ -41,11 +41,11 @@ class TestDistributedSoftmax:
         if not bad_sharding:
             x_pspec = PartitionSpec(
-                mesh_resource.dp_resource, mesh_resource.tp_resource, None, None
+                mesh_resource.dp_resource, mesh_resource.tpsp_resource, None, None
             )
         else:
             x_pspec = PartitionSpec(
-                mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
+                mesh_resource.dp_resource, None, None, mesh_resource.tpsp_resource
             )

         if broadcast_batch_mask:
......
@@ -397,7 +397,7 @@ class FusedAttnRunner:
         self.mesh = Mesh(self.devices, self.mesh_axes)
         self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
         self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
-        self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)
+        self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)

         key = jax.random.PRNGKey(0)
         q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
@@ -630,7 +630,7 @@ class FusedAttnRunner:
         self.qkvo_psec = PartitionSpec(
             self.mesh_resource.dp_resource,
             self.mesh_resource.cp_resource,
-            self.mesh_resource.tp_resource,
+            self.mesh_resource.tpsp_resource,
             None,
         )
         self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)
@@ -658,7 +658,7 @@ class FusedAttnRunner:
         if self.bias_shape == BiasShape._1HSS:
             self.bias_pspec = PartitionSpec(
-                None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None
+                None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None
             )
         elif self.bias_shape == BiasShape._B1SS:
             self.bias_pspec = PartitionSpec(
......
@@ -71,20 +71,20 @@ class TestFP8Functions(unittest.TestCase):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()

-        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
+        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):
             self._check_default_state()

         self._check_default_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
-        with fp8_autocast(enabled=True, fp8_recipe=ds):
+        with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

         self._check_default_state()

         ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
-        with fp8_autocast(enabled=True, fp8_recipe=ds):
+        with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)
@@ -95,20 +95,22 @@ class TestFP8Functions(unittest.TestCase):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()

-        with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
+        with fp8_autocast(
+            enabled=False, fp8_recipe=Float8CurrentScaling(), mesh_resource=MeshResource()
+        ):
             self._check_default_state()

         self._check_default_state()

         cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
-        with fp8_autocast(enabled=True, fp8_recipe=cs):
+        with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_current_scaling(cs)

         self._check_default_state()

         cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
-        with fp8_autocast(enabled=True, fp8_recipe=cs):
+        with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_current_scaling(cs)
@@ -119,46 +121,23 @@ class TestFP8Functions(unittest.TestCase):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()

-        with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
+        with fp8_autocast(
+            enabled=False, fp8_recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()
+        ):
             self._check_default_state()

         self._check_default_state()

         bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
-        with fp8_autocast(enabled=True, fp8_recipe=bs):
+        with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_mxfp8_scaling(bs)

         self._check_default_state()

         bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
-        with fp8_autocast(enabled=True, fp8_recipe=bs):
+        with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_mxfp8_scaling(bs)

         self._check_default_state()

-    @unittest.skipIf(not is_fp8_supported, reason=reason)
-    def test_fp8_autocast_with_sharding_resource(self):
-        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
-        self._check_default_state()
-
-        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
-        mesh_s = (
-            (MeshResource(None, None)),
-            (MeshResource("dp", None)),
-            (MeshResource(None, "tp")),
-            (MeshResource("dp", "tp")),
-        )
-        # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
-        mesh_shape = (1, 1)
-        devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
-        with jax.sharding.Mesh(devices, ("dp", "tp")):
-            for sr in mesh_s:
-                with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
-                    self.assertTrue(QuantizeConfig.is_fp8_enabled())
-                    self._compare_delay_scaling(get_delayed_scaling(), ds)
-                    self.assertEqual(sr, global_mesh_resource())
-        self._check_default_state()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest

from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import global_shard_guard, MeshResource

LOGICAL_RULES = [
    [(("a1", None), ("a2", "ma2")), False],
    [(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True],
    [(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False],
    [(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True],
    [(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True],
]

MeshS = [
    MeshResource(),
    MeshResource("data", None),
    MeshResource(None, "model"),
    MeshResource("data", "model"),
]


class TestShardingSideAPI:
    @pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES)
    @pytest.mark.parametrize("sr", MeshS)
    def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
        with global_shard_guard(sr):
            try:
                target_te_rules = extend_logical_axis_rules(tuple())
                extended_rules = extend_logical_axis_rules(base_rules)
                assert extended_rules == (*base_rules, *target_te_rules)
                assert not need_assert
            except AssertionError as ae:
                assert need_assert, f"{ae.args}"
@@ -38,19 +38,10 @@ from .quantize import fp8_autocast, update_collections, get_delayed_scaling
 from .quantize import NVTE_FP8_COLLECTION_NAME
 from .sharding import MeshResource
-from .sharding import MajorShardingType, ShardingResource, ShardingType
-
-from ..common.utils import deprecate_wrapper
-from ..common.utils import DeprecatedEnum
-
-MajorShardingType = DeprecatedEnum(
-    MajorShardingType, "MajorShardingType is deprecating in the near feature."
-)
-ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecating in the near feature.")
-ShardingResource = deprecate_wrapper(
-    ShardingResource,
-    "ShardingResource is renamed to MeshResource, and will be removed in the near feature.",
-)

 __all__ = [
     "NVTE_FP8_COLLECTION_NAME",
@@ -58,9 +49,6 @@ __all__ = [
     "update_collections",
     "get_delayed_scaling",
     "MeshResource",
-    "MajorShardingType",
-    "ShardingResource",
-    "ShardingType",
     "flax",
     "quantize",
 ]
@@ -453,6 +453,19 @@ class GemmPrimitive(BasePrimitive):
     ):
         lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
         gsr = global_mesh_resource()

+        # Warn if tensor sequence parallelism appears to be configured via tp_resource
+        # instead of tpsp_resource
+        if gsr.tp_resource is not None:
+            for i in range(len(lhs_specs) - 1):
+                if lhs_specs[i] == gsr.tp_resource and lhs_specs[i + 1] == gsr.tp_resource:
+                    warnings.warn(
+                        "Tensor sequence parallelism is detected as"
+                        f" tp_resource='{gsr.tp_resource}' appears twice consecutively in"
+                        f" lhs_specs: {lhs_specs}. Please set MeshResource.tpsp_resource for"
+                        " tensor sequence parallelism to avoid potential issues."
+                    )
+
         lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
         lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims)
         lhs_non_cdims, rhs_non_cdims = map(
@@ -492,7 +505,7 @@ class GemmPrimitive(BasePrimitive):
         # Non-contracting dims of RHS always needs to be gathered along the FSDP axis
         rhs_non_cspecs = tuple(
-            None if spec is not None and spec == global_mesh_resource().fsdp_resource else spec
+            None if spec is not None and spec == gsr.fsdp_resource else spec
             for spec in rhs_non_cspecs
         )
......
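To make the new `GemmPrimitive` heuristic concrete: it scans the LHS partition specs and warns when `tp_resource` shows up on two consecutive dimensions, since that pattern means the sequence and hidden dimensions are both sharded by the same axis, i.e. tensor sequence parallelism configured through the wrong field. A standalone sketch of that check; the spec tuple below is an assumed layout, not taken from this diff:

```python
# Assumed (batch, seq, hidden, head) layout with tp_resource="tp" sharding
# both the sequence and hidden dimensions -- the pattern the warning targets.
lhs_specs = ("dp", "tp", "tp", None)
tp_resource = "tp"

detected = any(
    lhs_specs[i] == tp_resource and lhs_specs[i + 1] == tp_resource
    for i in range(len(lhs_specs) - 1)
)
print(detected)  # True -> the GEMM sharding rule would emit the new warning
```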
@@ -9,10 +9,8 @@ tensor parallelism (TP), pipeline parallelism (PP), and full-sharded data
 parallelism (FSDP). It includes functions for sharding constraints, mesh management,
 and collective operations.
 """
-import os
 from contextlib import contextmanager
 from dataclasses import dataclass
-from enum import Enum
 from typing import Callable, Optional
 import warnings

 import jax
@@ -43,44 +41,46 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
     return mesh.shape[resource], resource


-def get_sharding_map_logic_axis_to_mesh_axis():
-    """
-    Generate a dict to map logical axes to mesh axes.
-    """
+def _validate_mesh_resource_configuration():
+    """Validate that the mesh resource configuration is consistent and conflict-free."""
     gsr = global_mesh_resource()

-    IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", False)))
-    batch_resources = (
-        [gsr.fsdp_resource, gsr.dp_resource]
-        if IS_FSDP_OUTER
-        else [gsr.dp_resource, gsr.fsdp_resource]
-    )
-
-    batch_dim_rule = []
-    for resource in batch_resources:
-        if resource is not None and resource not in batch_dim_rule:
-            batch_dim_rule.append(resource)
-
-    if len(batch_dim_rule) <= 0:
-        batch_dim_rule = None
-    elif len(batch_dim_rule) == 1:
-        batch_dim_rule = batch_dim_rule[0]
-    else:
-        batch_dim_rule = tuple(batch_dim_rule)
+    is_dp_enabled = gsr.dp_resource is not None and get_mesh_axis_size(gsr.dp_resource) > 1
+    is_tp_enabled = gsr.tp_resource is not None and get_mesh_axis_size(gsr.tp_resource) > 1
+    is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1
+    is_fsdp_enabled = gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1
+
+    assert not (is_dp_enabled and is_fsdp_enabled), (
+        "Data parallelism and full-sharded data parallelism cannot be enabled at the same time."
+        f" Got dp_resource={gsr.dp_resource} and fsdp_resource={gsr.fsdp_resource}"
+    )
+    assert not (is_tp_enabled and is_tpsp_enabled), (
+        "Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time."
+        f" Got tp_resource={gsr.tp_resource} and tpsp_resource={gsr.tpsp_resource}"
+    )
+
+
+def get_sharding_map_logic_axis_to_mesh_axis():
+    """
+    Generate a dict to map logical axes to mesh axes.
+    """
+    gsr = global_mesh_resource()
+    is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1
+    is_fsdp_enabled = gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1

     te_logical_axis_to_mesh_axis = {
-        BATCH_AXES: batch_dim_rule,
+        BATCH_AXES: gsr.fsdp_resource if is_fsdp_enabled else gsr.dp_resource,
         SEQLEN_AXES: None,
-        SEQLEN_TP_AXES: gsr.tp_resource,
+        SEQLEN_TP_AXES: gsr.tpsp_resource,
         SEQLEN_CP_AXES: gsr.cp_resource,
-        HEAD_AXES: gsr.tp_resource,
+        HEAD_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource,
         HIDDEN_AXES: None,
-        HIDDEN_TP_AXES: gsr.tp_resource,
+        HIDDEN_TP_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource,
         JOINED_AXES: None,
         W_NO_SHARD_AXES: None,
         W_FSDP_AXES: gsr.fsdp_resource,
-        W_TP_AXES: gsr.tp_resource,
+        W_TP_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource,
         W_JOINED_AXES: None,
     }
     return te_logical_axis_to_mesh_axis
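The net effect of the rewritten mapping is a per-axis fallback: logical TP axes resolve to `tpsp_resource` when it is active, otherwise to `tp_resource`, and the batch axis prefers `fsdp_resource` over `dp_resource` (replacing the old `NVTE_OUTER_BATCH_FSDP_DIM` tuple logic). A toy sketch of that resolution, with placeholder values standing in for the real mesh-size checks:

```python
# Placeholder resources; in the real code each *_enabled flag also requires
# the corresponding mesh axis size to be > 1 (via get_mesh_axis_size).
tp_resource, tpsp_resource = None, "tpsp"
dp_resource, fsdp_resource = "dp", None

is_tpsp_enabled = tpsp_resource is not None
is_fsdp_enabled = fsdp_resource is not None

print(tpsp_resource if is_tpsp_enabled else tp_resource)  # "tpsp" -> HEAD/HIDDEN_TP/W_TP axes
print(fsdp_resource if is_fsdp_enabled else dp_resource)  # "dp"   -> BATCH axes
```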
@@ -274,6 +274,7 @@ class MeshResource:
     Attributes:
         dp_resource: Axis name for data parallelism (batch sharding), default is None
         tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None
+        tpsp_resource: Axis name for tensor sequence parallelism (hidden and sequence sharding), default is None
         fsdp_resource: Axis name for full-sharded data parallelism, default is None
         pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
         cp_resource: Axis name for context parallelism (sequence sharding), default is None
@@ -281,6 +282,7 @@ class MeshResource:
     dp_resource: str = None
     tp_resource: str = None
+    tpsp_resource: str = None
     fsdp_resource: str = None
     pp_resource: str = None
     cp_resource: str = None
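With the extra field, the two flavors of tensor parallelism are spelled out explicitly at construction time; a small sketch (the axis names are placeholders):

```python
from transformer_engine.jax.sharding import MeshResource

# Plain tensor parallelism: only hidden dimensions are sharded.
plain_tp = MeshResource(dp_resource="dp", tp_resource="tp")

# Tensor sequence parallelism: hidden *and* sequence dimensions share one axis.
tp_with_sp = MeshResource(dp_resource="dp", tpsp_resource="tpsp")
```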
@@ -303,6 +305,7 @@ def global_shard_guard(resource: MeshResource):
     old_resources = _GLOBAL_MESH_RESOURCE
     try:
         _GLOBAL_MESH_RESOURCE = resource
+        _validate_mesh_resource_configuration()
         yield
     finally:
         _GLOBAL_MESH_RESOURCE = old_resources
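Since `global_shard_guard` (and therefore `fp8_autocast`, which sets the global resource through it) now validates the configuration on entry, enabling both mutually exclusive options fails immediately rather than producing silently wrong sharding. A sketch assuming four available devices; the axis names are placeholders:

```python
import jax
import numpy as np
from transformer_engine.jax.sharding import MeshResource, global_shard_guard

# 2x2 mesh so both axes have size > 1 and both resources count as "enabled".
devices = np.asarray(jax.devices()[:4]).reshape(2, 2)
with jax.sharding.Mesh(devices, ("tp", "tpsp")):
    try:
        with global_shard_guard(MeshResource(tp_resource="tp", tpsp_resource="tpsp")):
            pass
    except AssertionError as err:
        print(err)  # tp_resource and tpsp_resource cannot both be enabled
```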
@@ -351,52 +354,3 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
         if axis != global_mesh_resource().pp_resource:
             x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
     return x
-
-
-# Deprecating Items ---------------------------------------------------------------
-ShardingResource = MeshResource
-global_shard_resource = global_mesh_resource
-
-
-class MajorShardingType(Enum):
-    """Enumeration of major sharding types for distributed training.
-
-    This enum defines the basic sharding patterns available for distributed
-    training. Note that this class is deprecated and will be removed in the future.
-
-    Values:
-        SINGLE: Single process training
-        DP: Data parallel training
-        TP: Standard tensor parallel training
-        DPTP: Data and standard tensor parallel training
-    """
-
-    SINGLE = 0
-    DP = 1
-    TP = 2
-    DPTP = 3
-
-
-class ShardingType(Enum):
-    """Enumeration of detailed sharding types for distributed training.
-
-    This enum defines specific sharding patterns for distributed training,
-    including combinations of data parallelism and different tensor parallelism
-    strategies. Note that this class is deprecated and will be removed in the future.
-
-    Values:
-        SINGLE: No sharding
-        DP: Sharding along data parallelism
-        TP_COL: Sharding along column-split tensor parallelism
-        TP_ROW: Sharding along row-split tensor parallelism
-        DP_TP_COL: Sharding along data and column-split tensor parallelism
-        DP_TP_ROW: Sharding along data and row-split tensor parallelism
-    """
-
-    SINGLE = (MajorShardingType.SINGLE, "single")
-    DP = (MajorShardingType.DP, "dp")
-    TP_COL = (MajorShardingType.TP, "tp_col")
-    TP_ROW = (MajorShardingType.TP, "tp_row")
-    DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
-    DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")