sharding.py 11.7 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
5
6
7
8
9
10
"""Sharding utilities for Transformer Engine in JAX.

This module provides utilities for managing tensor sharding in distributed training,
including support for data parallelism (DP), tensor parallelism (TP), pipeline
parallelism (PP), context parallelism (CP), and fully sharded data parallelism (FSDP).
It includes functions for sharding constraints, mesh management, and collective
operations.
"""
12
import os
13
14
15
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
16
from typing import Callable, Optional
17
18
19
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
20
from jax.sharding import PartitionSpec
21
import numpy as np
22
23
24

# JAX's thread-local resource environment. Its ``env.physical_mesh`` is the mesh read
# by the helpers below (presumably the one installed by an enclosing ``with mesh:``
# block; it may be empty).
_PXLA_THREAD_RESOURCES = pxla.thread_resources

# Logical axis names. ``get_sharding_map_logic_axis_to_mesh_axis`` translates each of
# these into a physical mesh axis (or ``None``).

# Batch / sequence axes
BATCH_AXES = "nvte_batch"
SEQLEN_AXES = "nvte_seqlen"
SEQLEN_TP_AXES = "nvte_seqlen_tp"
SEQLEN_CP_AXES = "nvte_seqlen_cp"

# Head / hidden axes
HEAD_AXES = "nvte_head"
HIDDEN_AXES = "nvte_hidden"
HIDDEN_TP_AXES = "nvte_hidden_tp"
JOINED_AXES = "nvte_joined"

# ``W_``-prefixed axes (presumably for weight tensors)
W_NO_SHARD_AXES = "nvte_w_no_shard"
W_FSDP_AXES = "nvte_w_fsdp"
W_TP_AXES = "nvte_w_tp"
W_JOINED_AXES = "nvte_w_joined"
38

39

40
def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
    """Return ``(axis_size, axis_name)`` for the mesh axis named ``resource``.

    Raises:
        AssertionError: if ``resource`` is not one of ``mesh.axis_names``.
    """
    assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
    axis_size = mesh.shape[resource]
    return axis_size, resource


45
46
47
48
49
50
51
52
def get_sharding_map_logic_axis_to_mesh_axis():
    """
    Build the dict that maps TE logical axis names to mesh axis names.

    Mesh axes come from ``global_mesh_resource()``. The batch axis maps to the DP
    and FSDP resources:

    * ``None`` if neither is set;
    * the single axis name if only one is set, or if both name the same axis;
    * otherwise a tuple of both.

    The tuple is ordered DP-then-FSDP. If the environment variable
    ``NVTE_OUTER_BATCH_FSDP_DIM`` parses as a non-zero integer, FSDP comes first.
    """
    resources = global_mesh_resource()

    # An unset variable counts as 0; a non-integer value raises ValueError in int().
    fsdp_outer = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", "0")))
    if fsdp_outer:
        candidates = (resources.fsdp_resource, resources.dp_resource)
    else:
        candidates = (resources.dp_resource, resources.fsdp_resource)

    # Drop unset entries and de-duplicate while preserving first-seen order.
    batch_axes = tuple(dict.fromkeys(c for c in candidates if c is not None))
    if not batch_axes:
        batch_dim_rule = None
    elif len(batch_axes) == 1:
        batch_dim_rule = batch_axes[0]
    else:
        batch_dim_rule = batch_axes

    tp = resources.tp_resource
    return {
        BATCH_AXES: batch_dim_rule,
        SEQLEN_AXES: None,
        SEQLEN_TP_AXES: tp,
        SEQLEN_CP_AXES: resources.cp_resource,
        HEAD_AXES: tp,
        HIDDEN_AXES: None,
        HIDDEN_TP_AXES: tp,
        JOINED_AXES: None,
        W_NO_SHARD_AXES: None,
        W_FSDP_AXES: resources.fsdp_resource,
        W_TP_AXES: tp,
        W_JOINED_AXES: None,
    }


def generate_pspec(logical_axis_names):
    """
    Translate a sequence of TE logical axis names into a ``PartitionSpec``.

    Each name is looked up in ``get_sharding_map_logic_axis_to_mesh_axis()``.
    Names that are absent from that mapping (including ``None``) become ``None``
    entries, i.e. that dimension is left unpartitioned.
    """
    rules = get_sharding_map_logic_axis_to_mesh_axis()
    return jax.sharding.PartitionSpec(*(rules.get(name) for name in logical_axis_names))


def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
    """
    Apply ``jax.lax.with_sharding_constraint`` only when it can take effect.

    ``x`` is returned unchanged when ``pspec`` is None, or when the thread-local
    physical mesh is empty.
    """
    # Short-circuit: the mesh is only consulted when a pspec was actually given.
    if pspec is not None and not _PXLA_THREAD_RESOURCES.env.physical_mesh.empty:
        return jax.lax.with_sharding_constraint(x, pspec)
    return x


116
117
118
def with_sharding_constraint_by_logical_axes(
    x: jnp.array, logical_axis_names: Optional[tuple | list]
):
119
120
    """
    A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
121
122
123
124
125
126
127
128
129
130
131

    If logical_axis_names = None, this means no sharding constraint is applied.

    If logical_axis_names = (None, None, ...), this means a sharding constraint is applied and the tensor is replicated across all devices.

    Args:
        x: Input tensor to apply sharding constraint
        logical_axis_names: Logical axis names to apply sharding constraint
    Returns:
        Tensor with sharding constraint applied, or the original tensor if no logical axes are provided.

132
    """
133
    if not logical_axis_names:
134
135
136
137
138
139
140
        return x

    assert len(x.shape) == len(logical_axis_names)
    pspec = generate_pspec(logical_axis_names)
    return with_sharding_constraint(x, pspec)


141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def get_all_mesh_axes():
    """
    Return the axis names of the current thread-local physical mesh.
    """
    return _PXLA_THREAD_RESOURCES.env.physical_mesh.axis_names


def get_padded_spec(spec, ndim):
    """
    Right-pad a partition spec with ``None`` entries so it has ``ndim`` entries.

    Args:
        spec: A tuple-like spec (e.g. a ``PartitionSpec`` or a tuple), a list, or
            ``None``. ``None`` is treated as fully unpartitioned.
        ndim: Rank of the array the spec describes.

    Returns:
        A tuple-like spec of length ``ndim``.

    Raises:
        AssertionError: if ``spec`` has more than ``ndim`` entries.
    """
    if spec is None:
        return (None,) * ndim
    if isinstance(spec, list):
        # A list cannot be concatenated with a tuple, so normalize it first.
        spec = tuple(spec)
    assert len(spec) <= ndim, f"Spec {spec} has more entries ({len(spec)}) than ndim ({ndim})."
    return spec + (None,) * (ndim - len(spec))


159
160
161
def lax_paral_op(
    x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs
):
    """
    Invoke a ``jax.lax`` parallel operation (e.g. ``psum``) along one mesh axis.

    If ``mesh_resource`` is None, ``x`` is returned unchanged. Otherwise
    ``mesh_resource`` must be an axis of ``mesh``, and the result of
    ``ops(x, axis_name, **kwargs)`` is returned.
    """
    if mesh_resource is None:
        return x
    _, axis_name = _get_mesh_info(mesh_resource, mesh)
    return ops(x, axis_name, **kwargs)


def num_of_devices():
    """
    Return the number of devices reported by ``jax.devices()``.
    """
    all_devices = jax.devices()
    return len(all_devices)


178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def get_mesh_axis_size(axis, mesh=None):
    """
    Return the size of ``axis`` in ``mesh``.

    ``mesh`` defaults to the current thread-local physical mesh. An ``axis`` of
    None has size 1.
    """
    target_mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh if mesh is None else mesh

    if axis is None:
        return 1

    axis_sizes = target_mesh.shape
    assert axis in axis_sizes, f"{axis} is not a axis of the given mesh {axis_sizes}"
    return axis_sizes[axis]


def get_mesh_axis_rank(axis: str, mesh=None):
    """
    Return the index along mesh axis ``axis`` as a traced value
    (via ``jax.lax.axis_index``).

    Without a mesh the rank is reported as 0.
    """
    if mesh is not None:
        _, axis_name = _get_mesh_info(axis, mesh)
        return jax.lax.axis_index(axis_name)
    return 0


205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def get_mesh_axis_rank_host(axis, mesh) -> int:
    """
    Same as get_mesh_axis_rank(), but return a host value instead of a
    traced device value.

    The rank is read from the mesh coordinates of the first device that is local to
    this process and appears in ``mesh.devices``.

    Raises:
        ValueError: if ``axis`` is not a mesh axis, or if none of this process's
            local devices belong to ``mesh``.
    """
    if axis not in mesh.axis_names:
        raise ValueError(f"Axis {axis} not found in mesh axis names: {mesh.axis_names}")

    axis_index = mesh.axis_names.index(axis)

    # ndarray of Device objects; compared element-wise below.
    devices = mesh.devices

    # jax.devices() lists the devices of *all* processes, so indexing it with
    # jax.process_index() can select a device owned by another host. Search this
    # process's own devices instead.
    local_devices = jax.local_devices()
    for local_device in local_devices:
        coords = np.argwhere(devices == local_device)
        if coords.size:
            # NOTE(review): if this host's devices span several positions along
            # ``axis``, only the first one found is reported.
            return int(coords[0][axis_index])

    raise ValueError(
        f"None of the local devices {local_devices} of process "
        f"{jax.process_index()} were found in mesh.devices."
    )


230
@dataclass
class MeshResource:
    """A data container for managing mesh resources in distributed training.

    This class maps each kind of parallelism to the name of the physical mesh axis
    that implements it. A field left as ``None`` means that kind of parallelism is
    not used.

    Attributes:
        dp_resource: Axis name for data parallelism (batch sharding), default is None
        tp_resource: Axis name for tensor parallelism (sequence/head/hidden sharding),
            default is None
        fsdp_resource: Axis name for fully sharded data parallelism (batch and weight
            sharding), default is None
        pp_resource: Axis name for pipeline parallelism, default is None
        cp_resource: Axis name for context parallelism (sequence sharding), default is None
    """

    dp_resource: Optional[str] = None
    tp_resource: Optional[str] = None
    fsdp_resource: Optional[str] = None
    pp_resource: Optional[str] = None
    cp_resource: Optional[str] = None


# Module-wide default configuration (no parallelism); replaced temporarily by
# global_shard_guard and read via global_mesh_resource().
_GLOBAL_MESH_RESOURCE = MeshResource()
253
254
255


@contextmanager
def global_shard_guard(resource: MeshResource):
    """Temporarily install ``resource`` as the global mesh resource.

    Inside the ``with`` block, ``global_mesh_resource()`` returns ``resource``. The
    previous value is restored on exit, even if the block raises.

    Args:
        resource: MeshResource describing which mesh axis backs each kind of
            parallelism.
    """
    global _GLOBAL_MESH_RESOURCE
    saved_resource = _GLOBAL_MESH_RESOURCE
    _GLOBAL_MESH_RESOURCE = resource
    try:
        yield
    finally:
        _GLOBAL_MESH_RESOURCE = saved_resource
272

273

274
def global_mesh_resource() -> MeshResource:
    """Return the MeshResource currently installed as the global sharding configuration.

    This is the module default ``MeshResource()`` unless a ``global_shard_guard``
    block has installed a different one.
    """
    current = _GLOBAL_MESH_RESOURCE
    return current
281

282

283
def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
    """Sum ``x`` across the data-parallel axis, then across the FSDP axis.

    An axis whose resource is unset (``None``) is skipped by ``lax_paral_op``.

    Args:
        x: Input tensor to reduce.
        mesh: JAX mesh for distributed computation.

    Returns:
        Reduced tensor.
    """
    resources = global_mesh_resource()
    for axis in (resources.dp_resource, resources.fsdp_resource):
        x = lax_paral_op(x, jax.lax.psum, axis, mesh)
    return x
295
296


297
def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
    """Perform an all-reduce max over every mesh axis except the pipeline-parallel one.

    Args:
        x: Input tensor to reduce.
        mesh: Mesh whose axes are reduced over. If None, the axes of the current
            thread-local physical mesh are used instead (legacy behavior).

    Returns:
        Reduced tensor.
    """
    # Reduce over the axes of the mesh we were handed. The thread-local physical
    # mesh can differ from ``mesh``; when it is empty, looping over its axes would
    # silently skip the reduction.
    all_axes = get_all_mesh_axes() if mesh is None else mesh.axis_names
    pp_axis = global_mesh_resource().pp_resource
    for axis in all_axes:
        if axis != pp_axis:
            x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
    return x


# Deprecated aliases of current names ----------------------------------------------
global_shard_resource = global_mesh_resource  # deprecated name of global_mesh_resource
ShardingResource = MeshResource  # deprecated name of MeshResource
318
319
320


class MajorShardingType(Enum):
    """Coarse-grained sharding patterns for distributed training.

    Deprecated: this class will be removed in the future.

    Members:
        SINGLE: Single-process training.
        DP: Data-parallel training.
        TP: Standard tensor-parallel training.
        DPTP: Combined data- and standard tensor-parallel training.
    """

    SINGLE = 0  # no parallelism
    DP = 1  # data parallelism only
    TP = 2  # tensor parallelism only
    DPTP = 3  # data + tensor parallelism


class ShardingType(Enum):
    """Fine-grained sharding patterns for distributed training.

    Deprecated: this class will be removed in the future.

    Each member's value is a ``(MajorShardingType, tag)`` pair: the coarse pattern
    the member belongs to, plus a short string tag.

    Members:
        SINGLE: No sharding.
        DP: Sharding along data parallelism.
        TP_COL: Sharding along column-split tensor parallelism.
        TP_ROW: Sharding along row-split tensor parallelism.
        DP_TP_COL: Sharding along data and column-split tensor parallelism.
        DP_TP_ROW: Sharding along data and row-split tensor parallelism.
    """

    SINGLE = (MajorShardingType.SINGLE, "single")  # no sharding
    DP = (MajorShardingType.DP, "dp")  # data parallel
    TP_COL = (MajorShardingType.TP, "tp_col")  # column-split TP
    TP_ROW = (MajorShardingType.TP, "tp_row")  # row-split TP
    DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")  # DP + column-split TP
    DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")  # DP + row-split TP
361
362


363
364
365
def get_non_contracting_logical_axes(
    ndim, logical_axes: tuple[Optional[str], ...], contracting_dims
) -> tuple[Optional[str], ...]:
    """Get logical axes for non-contracting dimensions.

    Args:
        ndim: Number of dimensions in the tensor.
        logical_axes: Tuple of logical axes, one per dimension.
        contracting_dims: Iterable of dimension indices being contracted. Negative
            indices count from the end (``-1`` is the last dimension). Indices outside
            ``[-ndim, ndim)`` match no dimension.

    Returns:
        Tuple of logical axes for the non-contracting dimensions, in original order.

    Raises:
        AssertionError: if ``logical_axes`` is None or its length differs from ``ndim``.
    """
    assert logical_axes is not None, "Logical axes must be a tuple and cannot be None."
    assert len(logical_axes) == ndim, "Logical axes must match the number of dimensions."

    # Normalize negative dims so e.g. -1 refers to the last axis instead of being
    # silently ignored; a set gives O(1) membership tests.
    contracting = {dim + ndim if dim < 0 else dim for dim in contracting_dims}
    return tuple(axis for i, axis in enumerate(logical_axes) if i not in contracting)