sharding.py 14.6 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
5
6
7
8
9
10
"""Sharding utilities for Transformer Engine in JAX.

This module provides utilities for managing tensor sharding in distributed training,
including support for various parallelism strategies like data parallelism (DP),
tensor parallelism (TP), tensor sequence parallelism (TPSP), context parallelism (CP),
pipeline parallelism (PP), and fully sharded data parallelism (FSDP). It includes
functions for sharding constraints, mesh management, and collective operations.
"""
from contextlib import contextmanager
from dataclasses import dataclass
14
from typing import Callable, Optional
15
import warnings
Phuong Nguyen's avatar
Phuong Nguyen committed
16

17
18
import jax
import jax.numpy as jnp
19
20
from jax.interpreters import pxla
from jax.sharding import PartitionSpec, get_abstract_mesh
21
import numpy as np
22
23
24

# Thread-resource environment; its `env.physical_mesh` holds the mesh set via `with mesh:`.
_PXLA_THREAD_RESOURCES = pxla.thread_resources

# Axis Names
# TransformerEngine logical axis names. They are translated to mesh axes by
# get_sharding_map_logic_axis_to_mesh_axis().
BATCH_AXES = "nvte_batch"
SEQLEN_AXES = "nvte_seqlen"
SEQLEN_TP_AXES = "nvte_seqlen_tp"
SEQLEN_CP_AXES = "nvte_seqlen_cp"
HEAD_AXES = "nvte_head"
HIDDEN_AXES = "nvte_hidden"
HIDDEN_TP_AXES = "nvte_hidden_tp"
JOINED_AXES = "nvte_joined"
W_NO_SHARD_AXES = "nvte_w_no_shard"
W_FSDP_AXES = "nvte_w_fsdp"
W_TP_AXES = "nvte_w_tp"
W_JOINED_AXES = "nvte_w_joined"
38

39

40
41
42
43
44
45
46
47
48
def _get_mesh():
    """Return the mesh currently in effect.

    A non-empty physical mesh installed with ``with mesh:`` takes precedence;
    otherwise the abstract mesh (e.g. one installed with ``jax.set_mesh(mesh)``)
    is returned.
    """
    physical_mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    has_physical_mesh = physical_mesh is not None and not physical_mesh.empty
    return physical_mesh if has_physical_mesh else jax.sharding.get_abstract_mesh()


49
def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
50
    assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
51
52
53
    return mesh.shape[resource], resource


54
def _validate_mesh_resource_configuration(mesh_resource):
55
    """Validate that the mesh resource configuration is consistent and conflict-free."""
56
57
58
59
60
61
62
    is_tp_enabled = (
        mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1
    )
    is_tpsp_enabled = (
        mesh_resource.tpsp_resource is not None
        and get_mesh_axis_size(mesh_resource.tpsp_resource) > 1
    )
63

64
65
    assert not (is_tp_enabled and is_tpsp_enabled), (
        "Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time."
66
67
        f" Got tp_resource={mesh_resource.tp_resource} and"
        f" tpsp_resource={mesh_resource.tpsp_resource}"
68
    )
69
70


71
72
73
74
def is_mesh_available() -> bool:
    """
    Return True if a non-empty mesh is in effect: either a physical mesh set
    with ``with mesh:`` or, failing that, the abstract mesh.
    """
    current = _get_mesh()
    if current is None:
        return False
    return not current.empty


79
80
81
82
def get_sharding_map_logic_axis_to_mesh_axis():
    """
    Build the dict that maps TransformerEngine logical axes to mesh axes.

    An empty dict is returned, without requiring a MeshResource context, when
    no mesh is in effect or when every mesh axis is a manual axis of the
    abstract mesh. Otherwise the mapping is derived from the global MeshResource:

    * BATCH_AXES maps to the FSDP axis when FSDP is enabled (axis size > 1),
      else to the DP axis;
    * HEAD_AXES, HIDDEN_TP_AXES and W_TP_AXES map to the TPSP axis when TPSP is
      enabled (axis size > 1), else to the TP axis;
    * the remaining axes map to their dedicated resource or to None.
    """
    mesh = _get_mesh()
    if mesh is None or mesh.empty:
        # No mesh: nothing to map, and no MeshResource context is required.
        return {}

    if sorted(get_abstract_mesh().manual_axes) == sorted(mesh.axis_names):
        # Every mesh axis is already manual: nothing left to map either.
        return {}

    gsr = global_mesh_resource()
    tpsp_on = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1
    fsdp_on = gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1

    batch_axis = gsr.fsdp_resource if fsdp_on else gsr.dp_resource
    tensor_axis = gsr.tpsp_resource if tpsp_on else gsr.tp_resource

    return {
        BATCH_AXES: batch_axis,
        SEQLEN_AXES: None,
        SEQLEN_TP_AXES: gsr.tpsp_resource,
        SEQLEN_CP_AXES: gsr.cp_resource,
        HEAD_AXES: tensor_axis,
        HIDDEN_AXES: None,
        HIDDEN_TP_AXES: tensor_axis,
        JOINED_AXES: None,
        W_NO_SHARD_AXES: None,
        W_FSDP_AXES: gsr.fsdp_resource,
        W_TP_AXES: tensor_axis,
        W_JOINED_AXES: None,
    }


115
def _generate_pspec(logical_axis_names):
    """
    Translate TransformerEngine logical axes (e.g. BATCH_AXES) into a JAX PartitionSpec.
    Flax logical axes are not supported by this helper.

    Args:
        logical_axis_names: Sequence of TransformerEngine logical axis names.
    Returns:
        A jax.sharding.PartitionSpec holding, per position, the mesh axis mapped
        from the logical axis; names without a mapping become None.
    """
    axis_rules = get_sharding_map_logic_axis_to_mesh_axis()
    return jax.sharding.PartitionSpec(*(axis_rules.get(axis) for axis in logical_axis_names))


def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
    """
    A wrapper around jax.lax.with_sharding_constraint that tolerates missing
    meshes and shard_map manual axes.

    Behavior:
        1. Returns ``x`` unchanged if ``pspec`` is None or the current mesh is empty.
        2. Drops every mesh axis that is manual in the abstract mesh from ``pspec``;
           entries that are tuples of mesh axes are filtered element-wise, and an
           entry left with no axes becomes None.
        3. Returns ``x`` unchanged if only None entries remain; otherwise applies
           the constraint with the filtered spec.

    NOTE(review): because of (3), an all-None ``pspec`` applies no constraint at
    all, so replication is not enforced through this function.
    """
    if pspec is None:
        return x

    mesh = _get_mesh()
    if mesh.empty:
        return x

    # shard_map records the axes it maps over only on the abstract mesh, not on
    # the physical one, so the manual axes are read from there.
    manual = get_abstract_mesh().manual_axes

    def _strip_manual(entry):
        # A single spec entry may name several mesh axes as a tuple.
        if not isinstance(entry, tuple):
            return None if entry in manual else entry
        kept = tuple(axis for axis in entry if axis not in manual)
        return kept if kept else None

    stripped = tuple(_strip_manual(entry) for entry in pspec)

    if stripped == (None,) * len(stripped):
        return x

    return jax.lax.with_sharding_constraint(x, PartitionSpec(*stripped))
168
169


170
171
172
def with_sharding_constraint_by_logical_axes(
    x: jnp.array, logical_axis_names: Optional[tuple | list]
):
173
    """
174
175
176
    A wrapper function to flax.linen.with_logical_constraint.

    DEPRECATED USE CASE: If no Flax logical axis rules are available, this function falls back to jax.lax.with_sharding_constraint using a hardcoded logical axis rule table from TE rules, such as BATCH_AXES. This functionality will be removed in the future.
177
178
179
180
181
182
183
184
185
186
187

    If logical_axis_names = None, this means no sharding constraint is applied.

    If logical_axis_names = (None, None, ...), this means a sharding constraint is applied and the tensor is replicated across all devices.

    Args:
        x: Input tensor to apply sharding constraint
        logical_axis_names: Logical axis names to apply sharding constraint
    Returns:
        Tensor with sharding constraint applied, or the original tensor if no logical axes are provided.

188
    """
189
    if not logical_axis_names:
190
191
        return x

192
193
194
195
196
197
198
    try:
        # Check if Flax logical axis rules are available, if so use them
        import flax

        flax_rules = flax.linen.get_logical_axis_rules()
        if len(flax_rules) > 0:
            return flax.linen.with_logical_constraint(
199
                x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.AXIS_IS_UNSHARDED
200
201
202
203
204
205
206
207
208
209
210
211
212
213
            )
    except ImportError:
        pass

    warnings.warn(
        "TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and"
        " will be removed in a future version. Please use Flax logical axes with a"
        " flax.linen.logical_axis_rules context and optionally use"
        " transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your"
        " rules.",
        DeprecationWarning,
    )

    # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
214
    assert len(x.shape) == len(logical_axis_names)
215
    pspec = _generate_pspec(logical_axis_names)
216
217
218
    return with_sharding_constraint(x, pspec)


219
220
221
222
def get_all_mesh_axes():
    """
    Return the axis names of the mesh currently in effect.
    """
    return _get_mesh().axis_names


def get_padded_spec(spec, ndim):
    """
    Right-pad a partition spec with None entries up to ``ndim`` entries.

    Args:
        spec: Tuple-like spec (supporting ``len`` and ``+`` with a tuple), or None.
        ndim: Target number of entries.
    Returns:
        ``(None,) * ndim`` when ``spec`` is None; otherwise ``spec`` followed by
        ``ndim - len(spec)`` None entries.
    Raises:
        AssertionError: if ``spec`` already has more than ``ndim`` entries.
    """
    if spec is not None:
        assert len(spec) <= ndim
        missing = ndim - len(spec)
        return spec + (None,) * missing
    return (None,) * ndim


237
238
239
def lax_paral_op(
    x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs
):
    """
    Invoke a lax.p* collective (e.g. psum, pmax) on ``x`` along ``mesh_resource``.

    Returns ``x`` unchanged when ``mesh_resource`` is None. Otherwise
    ``mesh_resource`` must be an axis of ``mesh`` (asserted), and the result of
    ``ops(x, axis_name, **kwargs)`` is returned.
    """
    if mesh_resource is None:
        return x
    _, axis_name = _get_mesh_info(mesh_resource, mesh)
    return ops(x, axis_name, **kwargs)


def num_of_devices():
    """
    Return the number of devices JAX detects, i.e. ``len(jax.devices())``.
    """
    detected = jax.devices()
    return len(detected)


256
257
258
259
260
261
262
def get_num_devices_in_mesh(mesh=None):
    """
    Get the number of devices in the given mesh as a Python ``int``.

    Args:
        mesh: Mesh to inspect. If None, the current mesh (a physical mesh set via
            ``with mesh:``, otherwise the abstract mesh) is used.
    Returns:
        The product of all axis sizes, or 1 for an empty mesh.
    """
    if mesh is None:
        mesh = _get_mesh()
    if mesh.empty:
        return 1
    # np.prod yields a NumPy scalar; convert so both branches return a plain int.
    return int(np.prod(list(mesh.shape.values())))


269
270
271
272
273
274
275
def get_mesh_axis_size(axis, mesh=None):
    """
    Return the size of mesh axis ``axis``.

    Args:
        axis: Mesh axis name, or None (treated as an axis of size 1).
        mesh: Mesh to inspect; the current mesh is used when None.
    Raises:
        AssertionError: if ``axis`` is not an axis of the mesh.
    """
    if axis is None:
        return 1
    target_mesh = _get_mesh() if mesh is None else mesh
    assert axis in target_mesh.shape, f"{axis} is not a axis of the given mesh {target_mesh.shape}"
    return target_mesh.shape[axis]


def get_mesh_axis_rank(axis: str, mesh=None):
    """
    Return the local rank along mesh axis ``axis`` as a traced value
    (``jax.lax.axis_index``). Returns 0 when ``mesh`` is None.
    ``axis`` must be an axis of ``mesh`` (asserted).
    """
    if mesh is not None:
        _, axis_name = _get_mesh_info(axis, mesh)
        return jax.lax.axis_index(axis_name)
    return 0


296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
def get_mesh_axis_rank_host(axis, mesh) -> int:
    """
    Same as get_mesh_axis_rank(), but return a host value instead of a
    traced device value.

    The rank is the coordinate, along ``axis``, of the first device that is
    local to this process (in ``jax.local_devices()`` order) and also appears
    in ``mesh.devices``.

    Args:
        axis: Name of the mesh axis.
        mesh: Mesh providing ``axis_names`` and a ``devices`` array.
    Returns:
        The rank as a Python ``int``.
    Raises:
        ValueError: if ``axis`` is not a mesh axis, or if none of this process's
            local devices is part of ``mesh``.
    """
    if axis not in mesh.axis_names:
        raise ValueError(f"Axis {axis} not found in mesh axis names: {mesh.axis_names}")

    axis_index = mesh.axis_names.index(axis)
    devices = np.asarray(mesh.devices)

    # Only devices owned by this process are valid here. Indexing the global
    # jax.devices() list by jax.process_index() does not yield a local device
    # in multi-process runs, so search the local devices explicitly.
    local_devices = jax.local_devices()
    for local_device in local_devices:
        coords = np.argwhere(devices == local_device)
        if coords.size > 0:
            # Coordinates of the first match in the mesh's device grid.
            return int(coords[0][axis_index])

    raise ValueError(
        f"None of the local devices {local_devices} of process {jax.process_index()}"
        " were found in mesh.devices."
    )


321
@dataclass
class MeshResource:
    """A data container for managing mesh resources in distributed training.

    This class defines the mapping between logical axes and physical mesh axes
    for different types of parallelism in distributed training. Each field is
    the name of a mesh axis, or None when that kind of parallelism is unused.

    Attributes:
        dp_resource: Axis name for data parallelism (batch sharding), default is None
        tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None
        tpsp_resource: Axis name for tensor sequence parallelism (hidden and sequence sharding),
            default is None
        fsdp_resource: Axis name for fully sharded data parallelism, default is None
        pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
        cp_resource: Axis name for context parallelism (sequence sharding), default is None
    """

    # Annotated Optional[str]: every field defaults to None ("not used").
    dp_resource: Optional[str] = None
    tp_resource: Optional[str] = None
    tpsp_resource: Optional[str] = None
    fsdp_resource: Optional[str] = None
    pp_resource: Optional[str] = None
    cp_resource: Optional[str] = None
343
344


345
# MeshResource installed by the innermost active global_shard_guard, or None.
_GLOBAL_MESH_RESOURCE = None


@contextmanager
def global_shard_guard(resource: MeshResource):
    """Temporarily install ``resource`` as the global sharding configuration.

    Whatever was installed before (possibly None) is restored on exit, also when
    the body raises, so guards can be nested.

    Args:
        resource: MeshResource instance defining the sharding configuration
    """
    global _GLOBAL_MESH_RESOURCE
    saved = _GLOBAL_MESH_RESOURCE
    _GLOBAL_MESH_RESOURCE = resource
    try:
        yield
    finally:
        _GLOBAL_MESH_RESOURCE = saved
365

366

367
def global_mesh_resource() -> MeshResource:
    """Return the globally installed MeshResource after validating it.

    Validation rejects configurations that enable both tensor parallelism and
    tensor sequence parallelism.

    Raises:
        AssertionError: if no MeshResource is installed (see global_shard_guard)
            or the installed configuration is invalid.
    """
    current = _GLOBAL_MESH_RESOURCE
    assert current is not None, (
        "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard"
        " context. If you are not using multiple GPUs, you can use an empty MeshResource by"
        " wrapping your program in 'with global_shard_guard(MeshResource()):'"
    )
    _validate_mesh_resource_configuration(current)
    return current
380

381

382
def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
    """All-reduce ``x`` with psum over the DP axis, then over the FSDP axis.

    Axes that are not configured (None) are skipped.

    Args:
        x: Input tensor to reduce
        mesh: JAX mesh for distributed computation

    Returns:
        Reduced tensor
    """
    resources = global_mesh_resource()
    for axis in (resources.dp_resource, resources.fsdp_resource):
        x = lax_paral_op(x, jax.lax.psum, axis, mesh)
    return x
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408


def all_reduce_sum_along_dp_fsdp_tpsp(x: jnp.array, mesh: jax.sharding.Mesh):
    """All-reduce ``x`` with psum over the TPSP, DP and FSDP axes, in that order.

    Axes that are not configured (None) are skipped.

    Args:
        x: Input tensor to reduce
        mesh: JAX mesh for distributed computation

    Returns:
        Reduced tensor
    """
    resources = global_mesh_resource()
    reduce_axes = (resources.tpsp_resource, resources.dp_resource, resources.fsdp_resource)
    for axis in reduce_axes:
        x = lax_paral_op(x, jax.lax.psum, axis, mesh)
    return x
409
410


411
def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
    """All-reduce ``x`` with pmax over every mesh axis except the pipeline-parallel one.

    The axes are taken from the current mesh (get_all_mesh_axes()); each one is
    reduced via lax_paral_op against ``mesh``.
    NOTE(review): this assumes ``mesh`` has the same axes as the current mesh.

    Args:
        x: Input tensor to reduce
        mesh: JAX mesh for distributed computation

    Returns:
        Reduced tensor
    """
    axes = get_all_mesh_axes()
    if not axes:
        # Nothing to reduce; the global MeshResource is not consulted.
        return x
    pp_axis = global_mesh_resource().pp_resource
    for axis in axes:
        if axis == pp_axis:
            continue
        x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
    return x
Phuong Nguyen's avatar
Phuong Nguyen committed
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443


def tpsp_axis_size():
    """
    Return the size of the tensor sequence parallelism (TPSP) axis,
    or 1 if no TPSP axis is configured.
    """
    tpsp_axis = global_mesh_resource().tpsp_resource
    return get_mesh_axis_size(tpsp_axis)


def dp_or_fsdp_axis_size():
    """
    Return the DP axis size when it exceeds 1, otherwise the FSDP axis size.
    Unconfigured axes count as size 1, so the result is 1 when neither is set.
    """
    resources = global_mesh_resource()
    dp_size = get_mesh_axis_size(resources.dp_resource)
    fsdp_size = get_mesh_axis_size(resources.fsdp_resource)
    return fsdp_size if dp_size <= 1 else dp_size