# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Sharding utilities for Transformer Engine in JAX.

This module provides utilities for managing tensor sharding in distributed training,
including support for various parallelism strategies like data parallelism (DP),
tensor parallelism (TP), pipeline parallelism (PP), and full-sharded data
parallelism (FSDP). It includes functions for sharding constraints, mesh management,
and collective operations.
11
12
13
"""
from contextlib import contextmanager
from dataclasses import dataclass
14
from typing import Callable, Optional
15
import warnings
16
17
import jax
import jax.numpy as jnp
18
19
from jax.interpreters import pxla
from jax.sharding import PartitionSpec, get_abstract_mesh
20
import numpy as np
21
22
23

# JAX's thread-local resource environment. The helpers in this module read the
# active physical mesh from it via ``_PXLA_THREAD_RESOURCES.env.physical_mesh``.
_PXLA_THREAD_RESOURCES = pxla.thread_resources

# ---------------------------------------------------------------------------
# TransformerEngine logical axis names.
# These are mapped to mesh axes by get_sharding_map_logic_axis_to_mesh_axis().
# ---------------------------------------------------------------------------

# Batch / sequence axes.
BATCH_AXES = "nvte_batch"
SEQLEN_AXES = "nvte_seqlen"
SEQLEN_TP_AXES = "nvte_seqlen_tp"
SEQLEN_CP_AXES = "nvte_seqlen_cp"

# Head / hidden axes.
HEAD_AXES = "nvte_head"
HIDDEN_AXES = "nvte_hidden"
HIDDEN_TP_AXES = "nvte_hidden_tp"
JOINED_AXES = "nvte_joined"

# ``W_``-prefixed axes (presumably weight/parameter axes — naming convention only).
W_NO_SHARD_AXES = "nvte_w_no_shard"
W_FSDP_AXES = "nvte_w_fsdp"
W_TP_AXES = "nvte_w_tp"
W_JOINED_AXES = "nvte_w_joined"
37

38

39
def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
40
    assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
41
42
43
    return mesh.shape[resource], resource


44
def _validate_mesh_resource_configuration(mesh_resource):
    """Reject mesh resource settings that activate mutually exclusive axes.

    A resource counts as active when it is set (not ``None``) and
    ``get_mesh_axis_size`` reports more than one device along it in the current
    physical mesh. DP and FSDP may not both be active, and neither may TP and TPSP.

    Args:
        mesh_resource: The ``MeshResource`` to check.

    Raises:
        AssertionError: If DP and FSDP, or TP and TPSP, are active together.
    """

    def _active(axis_name):
        # Short-circuit so an unset resource never queries the mesh.
        return axis_name is not None and get_mesh_axis_size(axis_name) > 1

    # Evaluate every flag before asserting (same order as before: dp, tp, tpsp, fsdp).
    dp_active = _active(mesh_resource.dp_resource)
    tp_active = _active(mesh_resource.tp_resource)
    tpsp_active = _active(mesh_resource.tpsp_resource)
    fsdp_active = _active(mesh_resource.fsdp_resource)

    assert not (dp_active and fsdp_active), (
        "Data parallelism and full-sharded data parallelism cannot be enabled at the same time."
        f" Got dp_resource={mesh_resource.dp_resource} and"
        f" fsdp_resource={mesh_resource.fsdp_resource}"
    )
    assert not (tp_active and tpsp_active), (
        "Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time."
        f" Got tp_resource={mesh_resource.tp_resource} and"
        f" tpsp_resource={mesh_resource.tpsp_resource}"
    )
71
72


73
74
75
76
77
78
79
80
def get_sharding_map_logic_axis_to_mesh_axis():
    """Build the lookup table from TE logical axis names to mesh axis names.

    The table is derived from the current global ``MeshResource``:

    * the batch axis uses the FSDP resource when FSDP is active, otherwise DP;
    * head, hidden-TP and weight-TP axes use the TPSP resource when TPSP is
      active, otherwise TP;
    * ``SEQLEN_TP_AXES`` always uses the TPSP resource and ``SEQLEN_CP_AXES``
      the CP resource;
    * unsharded axes map to ``None``.

    "Active" means the resource is set and its mesh axis has more than one device.

    Returns:
        A dict keyed by the TE logical axis constants.
    """
    resource = global_mesh_resource()

    def _active(axis_name):
        return axis_name is not None and get_mesh_axis_size(axis_name) > 1

    tpsp_active = _active(resource.tpsp_resource)
    fsdp_active = _active(resource.fsdp_resource)

    tensor_axis = resource.tpsp_resource if tpsp_active else resource.tp_resource
    batch_axis = resource.fsdp_resource if fsdp_active else resource.dp_resource

    return {
        BATCH_AXES: batch_axis,
        SEQLEN_AXES: None,
        SEQLEN_TP_AXES: resource.tpsp_resource,
        SEQLEN_CP_AXES: resource.cp_resource,
        HEAD_AXES: tensor_axis,
        HIDDEN_AXES: None,
        HIDDEN_TP_AXES: tensor_axis,
        JOINED_AXES: None,
        W_NO_SHARD_AXES: None,
        W_FSDP_AXES: resource.fsdp_resource,
        W_TP_AXES: tensor_axis,
        W_JOINED_AXES: None,
    }


99
def _generate_pspec(logical_axis_names):
    """Translate TransformerEngine logical axes (e.g. ``BATCH_AXES``) into a PartitionSpec.

    Flax logical axis names are not understood here. Any name missing from the
    TE table becomes ``None`` (unsharded) in the result.

    Args:
        logical_axis_names: Sequence of TE logical axis names, one per tensor dimension.

    Returns:
        A ``jax.sharding.PartitionSpec`` holding the mesh axis for each logical axis.
    """
    lookup = get_sharding_map_logic_axis_to_mesh_axis()
    return jax.sharding.PartitionSpec(*map(lookup.get, logical_axis_names))


def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
    """
118
119
120
121
    A wrapper function to jax.lax.with_sharding_constraint
        1. Does nothing if mesh is empty.
        2. If all mesh axes are manual axes, replaces pspec with all Nones.
        3. Otherwise, strips only the manual axes.
122
123
124
125
126
127
128
    """
    if pspec is None:
        return x

    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    if mesh.empty:
        return x
129
130
131
132
133
134
135
136

    # We want to exclude the axes that already used by shard_map and shard_map
    # only sets those in the abstract_mesh, not the physical one
    manual_axis_names = get_abstract_mesh().manual_axes
    cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec)

    cleaned_pspec = PartitionSpec(*cleaned_axis_names)
    return jax.lax.with_sharding_constraint(x, cleaned_pspec)
137
138


139
140
141
def with_sharding_constraint_by_logical_axes(
    x: jnp.array, logical_axis_names: Optional[tuple | list]
):
142
    """
143
144
145
    A wrapper function to flax.linen.with_logical_constraint.

    DEPRECATED USE CASE: If no Flax logical axis rules are available, this function falls back to jax.lax.with_sharding_constraint using a hardcoded logical axis rule table from TE rules, such as BATCH_AXES. This functionality will be removed in the future.
146
147
148
149
150
151
152
153
154
155
156

    If logical_axis_names = None, this means no sharding constraint is applied.

    If logical_axis_names = (None, None, ...), this means a sharding constraint is applied and the tensor is replicated across all devices.

    Args:
        x: Input tensor to apply sharding constraint
        logical_axis_names: Logical axis names to apply sharding constraint
    Returns:
        Tensor with sharding constraint applied, or the original tensor if no logical axes are provided.

157
    """
158
    if not logical_axis_names:
159
160
        return x

161
162
163
164
165
166
167
    try:
        # Check if Flax logical axis rules are available, if so use them
        import flax

        flax_rules = flax.linen.get_logical_axis_rules()
        if len(flax_rules) > 0:
            return flax.linen.with_logical_constraint(
168
                x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.AXIS_IS_UNSHARDED
169
170
171
172
173
174
175
176
177
178
179
180
181
182
            )
    except ImportError:
        pass

    warnings.warn(
        "TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and"
        " will be removed in a future version. Please use Flax logical axes with a"
        " flax.linen.logical_axis_rules context and optionally use"
        " transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your"
        " rules.",
        DeprecationWarning,
    )

    # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
183
    assert len(x.shape) == len(logical_axis_names)
184
    pspec = _generate_pspec(logical_axis_names)
185
186
187
    return with_sharding_constraint(x, pspec)


188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
def get_all_mesh_axes():
    """Return the axis names of the physical mesh in the current thread's resource env."""
    return _PXLA_THREAD_RESOURCES.env.physical_mesh.axis_names


def get_padded_spec(spec, ndim):
    """Right-pad a partition spec with ``None`` so it has ``ndim`` entries.

    Args:
        spec: Partition spec to pad, or ``None`` for a fully unsharded spec.
        ndim: Desired number of entries (the rank of the partitioned argument).

    Returns:
        ``spec`` followed by enough ``None`` entries to reach length ``ndim``.

    Raises:
        AssertionError: If ``spec`` already has more than ``ndim`` entries.
    """
    if spec is not None:
        assert len(spec) <= ndim
        missing = ndim - len(spec)
        return spec + (None,) * missing
    return (None,) * ndim


206
207
208
def lax_paral_op(
    x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs
):
    """Invoke a ``jax.lax`` collective (e.g. ``psum``) along one mesh axis.

    Args:
        x: Value to reduce/communicate.
        ops: Collective called as ``ops(x, axis_name, **kwargs)``.
        mesh_resource: Mesh axis name, or ``None`` to skip the collective.
        mesh: Mesh that must contain ``mesh_resource``.
        **kwargs: Forwarded to ``ops``.

    Returns:
        ``x`` untouched when ``mesh_resource`` is ``None``, otherwise ``ops``'s result.
    """
    if mesh_resource is None:
        return x
    _, axis_name = _get_mesh_info(mesh_resource, mesh)
    return ops(x, axis_name, **kwargs)


def num_of_devices():
    """Return how many devices JAX sees on the default backend (across all processes)."""
    return jax.device_count()


225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def get_mesh_axis_size(axis, mesh=None):
    """Return how many devices ``mesh`` has along ``axis``.

    Args:
        axis: Mesh axis name, or ``None`` (treated as a size-1 axis).
        mesh: Mesh to inspect; defaults to the physical mesh of the current
            thread's resource environment.

    Returns:
        The axis extent, or ``1`` when ``axis`` is ``None``.

    Raises:
        AssertionError: If ``axis`` is not an axis of ``mesh``.
    """
    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh if mesh is None else mesh

    if axis is None:
        return 1

    sizes = mesh.shape
    assert axis in sizes, f"{axis} is not a axis of the given mesh {mesh.shape}"
    return sizes[axis]


def get_mesh_axis_rank(axis: str, mesh=None):
    """Return this device's index along mesh axis ``axis``.

    The index comes from ``jax.lax.axis_index`` and is therefore a traced value.
    When ``mesh`` is ``None`` the rank is reported as the constant ``0``.
    """
    if mesh is not None:
        _, axis_name = _get_mesh_info(axis, mesh)
        return jax.lax.axis_index(axis_name)
    return 0


252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
def get_mesh_axis_rank_host(axis, mesh) -> int:
    """
    Same as get_mesh_axis_rank(), but return a host value instead of a
    traced device value.

    The coordinate is read from the position, in ``mesh.devices``, of the first
    device local to this process (``jax.local_devices()[0]``).

    Args:
        axis: Mesh axis name.
        mesh: Mesh whose ``devices`` array contains this process's devices.

    Returns:
        This process's index along ``axis`` as a Python ``int``.

    Raises:
        ValueError: If ``axis`` is not in ``mesh.axis_names`` or the local device is
            not part of ``mesh.devices``.
    """
    if axis not in mesh.axis_names:
        raise ValueError(f"Axis {axis} not found in mesh axis names: {mesh.axis_names}")

    axis_index = mesh.axis_names.index(axis)

    # jax.devices() lists the devices of *all* processes, so indexing it with
    # jax.process_index() does not in general yield a device owned by this process.
    # jax.local_devices() contains only this process's devices.
    local_device = jax.local_devices()[0]

    # Coordinates of every element of mesh.devices equal to local_device.
    coords = np.argwhere(mesh.devices == local_device)
    if coords.size == 0:
        raise ValueError(f"Local device {local_device} not found in mesh.devices.")

    # First match's coordinate along the requested axis.
    return int(coords[0][axis_index])


277
@dataclass
class MeshResource:
    """A data container for managing mesh resources in distributed training.

    This class defines the mapping between logical axes and physical mesh axes
    for different types of parallelism in distributed training. Each attribute is
    the name of a mesh axis, or ``None`` when that kind of parallelism has no mesh
    axis assigned.

    Attributes:
        dp_resource: Axis name for data parallelism (batch sharding), default is None
        tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None
        tpsp_resource: Axis name for tensor sequence parallelism (hidden and sequence
            sharding), default is None
        fsdp_resource: Axis name for full-sharded data parallelism, default is None
        pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
        cp_resource: Axis name for context parallelism (sequence sharding), default is None
    """

    # Annotated Optional[str]: every field defaults to None ("not assigned").
    dp_resource: Optional[str] = None
    tp_resource: Optional[str] = None
    tpsp_resource: Optional[str] = None
    fsdp_resource: Optional[str] = None
    pp_resource: Optional[str] = None
    cp_resource: Optional[str] = None
299
300


301
# Process-wide MeshResource; written by global_shard_guard, read by global_mesh_resource.
_GLOBAL_MESH_RESOURCE = None


@contextmanager
def global_shard_guard(resource: MeshResource):
    """Install ``resource`` as the global mesh resource for the duration of a ``with`` block.

    Whatever was installed before (possibly ``None``) is restored on exit, even when
    the body raises, so guards can be nested.

    Args:
        resource: MeshResource instance defining the sharding configuration
    """
    global _GLOBAL_MESH_RESOURCE
    saved = _GLOBAL_MESH_RESOURCE
    _GLOBAL_MESH_RESOURCE = resource
    try:
        yield
    finally:
        _GLOBAL_MESH_RESOURCE = saved
321

322

323
def global_mesh_resource() -> MeshResource:
    """Return the MeshResource installed by the innermost ``global_shard_guard``.

    The configuration is validated on every call.

    Returns:
        The current MeshResource instance

    Raises:
        AssertionError: If no guard is active, or the configuration enables
            mutually exclusive parallelism axes.
    """
    current = _GLOBAL_MESH_RESOURCE
    assert current is not None, (
        "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard"
        " context. If you are not using multiple GPUs, you can use an empty MeshResource by"
        " wrapping your program in 'with global_shard_guard(MeshResource()):'"
    )
    _validate_mesh_resource_configuration(current)
    return current
336

337

338
def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
    """Sum ``x`` across the DP axis and then across the FSDP axis.

    Either step is skipped when the corresponding resource is ``None``.

    Args:
        x: Input tensor to reduce
        mesh: JAX mesh for distributed computation

    Returns:
        Reduced tensor
    """
    resources = global_mesh_resource()
    partial_sum = lax_paral_op(x, jax.lax.psum, resources.dp_resource, mesh)
    return lax_paral_op(partial_sum, jax.lax.psum, resources.fsdp_resource, mesh)
350
351


352
def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
    """Take the max of ``x`` over every mesh axis except the pipeline-parallel one.

    The axes iterated are those of the physical mesh in the current thread's
    resource environment (``get_all_mesh_axes``); ``mesh`` is used to resolve each
    axis for the collective.
    NOTE(review): confirm callers always pass the same mesh that is active in the
    context, otherwise axis lookup against ``mesh`` may fail.

    Args:
        x: Input tensor to reduce
        mesh: JAX mesh for distributed computation

    Returns:
        Reduced tensor
    """
    mesh_axes = get_all_mesh_axes()
    if not mesh_axes:
        return x

    pp_axis = global_mesh_resource().pp_resource
    for axis in mesh_axes:
        if axis == pp_axis:
            continue
        x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
    return x