# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Sharding utilities for Transformer Engine in JAX.

This module provides utilities for managing tensor sharding in distributed training,
including support for various parallelism strategies such as data parallelism (DP),
tensor parallelism (TP), pipeline parallelism (PP), context parallelism (CP), and
fully sharded data parallelism (FSDP). It includes functions for sharding constraints,
mesh management, and collective operations.
"""
12
import os
13
14
15
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
16
from typing import Callable
17
18
19
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
20
from jax.sharding import PartitionSpec
21
22
23

# Handle to JAX's pxla thread resources. Functions in this module read
# `.env.physical_mesh` from it to find the mesh currently in effect
# (presumably the one entered via `with mesh:` — confirm); that mesh may be empty.
_PXLA_THREAD_RESOURCES = pxla.thread_resources

# Axis Names
#
# Logical axis names used throughout Transformer Engine. Each one is translated
# into a physical mesh axis (or None, meaning "not sharded") by
# get_sharding_map_logic_axis_to_mesh_axis(), based on the global MeshResource.

# Non-weight tensor axes (presumably activations).
BATCH_AXES = "nvte_batch"
SEQLEN_AXES = "nvte_seqlen"
SEQLEN_TP_AXES = "nvte_seqlen_tp"
SEQLEN_CP_AXES = "nvte_seqlen_cp"
HEAD_AXES = "nvte_head"
HIDDEN_AXES = "nvte_hidden"
HIDDEN_TP_AXES = "nvte_hidden_tp"
JOINED_AXES = "nvte_joined"

# Axes carrying the W_ prefix (apparently intended for weights).
W_NO_SHARD_AXES = "nvte_w_no_shard"
W_FSDP_AXES = "nvte_w_fsdp"
W_TP_AXES = "nvte_w_tp"
W_JOINED_AXES = "nvte_w_joined"


def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
    """Return ``(size, name)`` for the mesh axis called ``resource``.

    The axis name is handed back unchanged so callers can pass it directly to
    ``jax.lax`` collectives. Raises AssertionError when ``resource`` is not one
    of ``mesh.axis_names``.
    """
    known_axes = mesh.axis_names
    assert resource in known_axes, f"{resource} is not in the axis_names of Mesh {mesh}."
    axis_size = mesh.shape[resource]
    return axis_size, resource


def get_sharding_map_logic_axis_to_mesh_axis():
    """
    Generate a dict to map logical axes to mesh axes.

    The table is derived from the current global ``MeshResource`` returned by
    ``global_mesh_resource()``. The batch axis may be sharded over both the
    data-parallel and the FSDP mesh axes. The environment variable
    ``NVTE_OUTER_BATCH_FSDP_DIM`` is parsed as an integer, where non-zero means
    true. It selects whether the FSDP axis is listed first in the batch rule.
    An unset or blank value is treated as 0.

    Returns:
        dict mapping each logical axis name to a mesh axis name, a tuple of mesh
        axis names (batch axis only, when dp and fsdp are both set and differ),
        or None (not sharded).
    """
    gsr = global_mesh_resource()

    # A blank value would make int() raise ValueError; treat it the same as unset.
    fsdp_outer_env = os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", "").strip()
    is_fsdp_outer = bool(int(fsdp_outer_env)) if fsdp_outer_env else False

    batch_resources = (
        [gsr.fsdp_resource, gsr.dp_resource]
        if is_fsdp_outer
        else [gsr.dp_resource, gsr.fsdp_resource]
    )

    # Keep the configured order while skipping unset (None) resources and a
    # duplicate entry when dp and fsdp name the same mesh axis.
    batch_axes = []
    for resource in batch_resources:
        if resource is not None and resource not in batch_axes:
            batch_axes.append(resource)

    if not batch_axes:
        batch_dim_rule = None
    elif len(batch_axes) == 1:
        batch_dim_rule = batch_axes[0]
    else:
        batch_dim_rule = tuple(batch_axes)

    te_logical_axis_to_mesh_axis = {
        BATCH_AXES: batch_dim_rule,
        SEQLEN_AXES: None,
        SEQLEN_TP_AXES: gsr.tp_resource,
        SEQLEN_CP_AXES: gsr.cp_resource,
        HEAD_AXES: gsr.tp_resource,
        HIDDEN_AXES: None,
        HIDDEN_TP_AXES: gsr.tp_resource,
        JOINED_AXES: None,
        W_NO_SHARD_AXES: None,
        W_FSDP_AXES: gsr.fsdp_resource,
        W_TP_AXES: gsr.tp_resource,
        W_JOINED_AXES: None,
    }
    return te_logical_axis_to_mesh_axis


def generate_pspec(logical_axis_names):
    """
    Convert a sequence of logical axis names into a ``PartitionSpec``.

    Each name is looked up in the table returned by
    ``get_sharding_map_logic_axis_to_mesh_axis()``; an unknown name raises KeyError.
    """
    axis_rules = get_sharding_map_logic_axis_to_mesh_axis()
    return jax.sharding.PartitionSpec(*(axis_rules[axis] for axis in logical_axis_names))


def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
    """
    Apply ``jax.lax.with_sharding_constraint`` only when it can take effect.

    ``x`` is returned untouched when ``pspec`` is None or when the physical mesh
    held by the pxla thread resources is empty. Otherwise the constrained array
    is returned.
    """
    if pspec is not None and not _PXLA_THREAD_RESOURCES.env.physical_mesh.empty:
        return jax.lax.with_sharding_constraint(x, pspec)
    return x


def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple | list):
    """
    A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
    """
    if logical_axis_names is None:
        return x

    assert len(x.shape) == len(logical_axis_names)
    pspec = generate_pspec(logical_axis_names)
    return with_sharding_constraint(x, pspec)


123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def get_all_mesh_axes():
    """
    Return the axis names of the physical mesh held by the pxla thread resources.
    """
    return _PXLA_THREAD_RESOURCES.env.physical_mesh.axis_names


def get_padded_spec(spec, ndim):
    """
    Get padded spec for partitioning from arguments' information.

    Args:
        spec: per-dimension sharding spec (tuple, list, PartitionSpec, ...) or
            None. It may cover fewer than ``ndim`` leading dimensions. A list is
            converted to a tuple first.
        ndim: number of dimensions of the array.

    Returns:
        ``spec`` extended with None entries up to length ``ndim`` (a tuple for
        tuple, list, or None input). A None ``spec`` yields ``ndim`` Nones.

    Raises:
        AssertionError: if ``spec`` has more entries than ``ndim``.
    """
    if spec is None:
        return (None,) * ndim
    # `list + tuple` raises TypeError, so normalise lists; other types are left
    # as-is so their existing `+` behaviour is preserved.
    if isinstance(spec, list):
        spec = tuple(spec)
    assert len(spec) <= ndim, f"spec {spec} has more entries than ndim={ndim}"
    return spec + (None,) * (ndim - len(spec))


def lax_paral_op(
    x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs
):
    """
    A wrapper function to invoke lax.p* operations, like psum.

    With ``mesh_resource`` set to None the input is returned unchanged.
    Otherwise the axis is validated against ``mesh`` and
    ``ops(x, axis_name, **kwargs)`` is returned.
    """
    if mesh_resource is None:
        return x
    _, axis_name = _get_mesh_info(mesh_resource, mesh)
    return ops(x, axis_name, **kwargs)


def num_of_devices():
    """
    Return how many devices ``jax.devices()`` reports.
    """
    visible_devices = jax.devices()
    return len(visible_devices)


def get_mesh_axis_size(axis, mesh=None):
    """
    Get the axis size of the given mesh.

    Args:
        axis: mesh axis name, or None for an unsharded dimension.
        mesh: mesh to query. If None, the physical mesh from the pxla thread
            resources (the global mesh) is used.

    Returns:
        ``mesh.shape[axis]``, or 1 when ``axis`` is None.

    Raises:
        AssertionError: if ``axis`` is not an axis of the mesh.
    """
    # An unsharded axis has size 1 whatever the mesh, so skip the mesh lookup.
    if axis is None:
        return 1

    if mesh is None:
        mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh

    assert axis in mesh.shape, f"{axis} is not an axis of the given mesh {mesh.shape}"
    return mesh.shape[axis]


def get_mesh_axis_rank(axis: str, mesh=None):
    """
    Gets the local axis rank of the `axis` of the array.
    If the mesh is None the rank is 0.
    """
    if mesh is None:
        return 0
    _, axis_name = _get_mesh_info(axis, mesh)
    return jax.lax.axis_index(axis_name)


187
@dataclass
188
class MeshResource:
189
190
191
192
193
194
195
196
197
198
199
    """A data container for managing mesh resources in distributed training.

    This class defines the mapping between logical axes and physical mesh axes
    for different types of parallelism in distributed training.

    Attributes:
        dp_resource: Axis name for data parallelism (batch sharding), default is None
        tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None
        fsdp_resource: Axis name for full-sharded data parallelism, default is None
        pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
        cp_resource: Axis name for context parallelism (sequence sharding), default is None
200
    """
201

202
203
    dp_resource: str = None
    tp_resource: str = None
204
    fsdp_resource: str = None
205
    pp_resource: str = None
206
    cp_resource: str = None
207
208


209
# Module-level default sharding configuration (every resource None). Returned by
# global_mesh_resource() and temporarily swapped out by global_shard_guard().
_GLOBAL_MESH_RESOURCE: MeshResource = MeshResource()
210
211
212


@contextmanager
def global_shard_guard(resource: MeshResource):
    """Temporarily install ``resource`` as the global mesh resource.

    While the ``with`` block runs, ``global_mesh_resource()`` returns
    ``resource``. The previously installed value is restored on exit, including
    when the block raises.

    Args:
        resource: MeshResource instance defining the sharding configuration
    """
    global _GLOBAL_MESH_RESOURCE
    saved_resource = _GLOBAL_MESH_RESOURCE
    _GLOBAL_MESH_RESOURCE = resource
    try:
        yield
    finally:
        _GLOBAL_MESH_RESOURCE = saved_resource


def global_mesh_resource() -> MeshResource:
    """Return the MeshResource currently installed as the global configuration.

    This is the module default unless an enclosing ``global_shard_guard``
    context has replaced it.

    Returns:
        The current MeshResource instance
    """
    current_resource = _GLOBAL_MESH_RESOURCE
    return current_resource


def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
    """Sum ``x`` over the data-parallel axis, then over the FSDP axis.

    The axes come from the global MeshResource. A step is skipped (``x`` passes
    through) when the corresponding resource is None.

    Args:
        x: Input tensor to reduce
        mesh: JAX mesh used to validate the axis names

    Returns:
        Reduced tensor
    """
    resources = global_mesh_resource()
    for axis_name in (resources.dp_resource, resources.fsdp_resource):
        x = lax_paral_op(x, jax.lax.psum, axis_name, mesh)
    return x


def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
    """Take the max of ``x`` over every mesh axis except the pipeline-parallel one.

    The axes iterated are those of the global physical mesh
    (``get_all_mesh_axes()``), not of ``mesh``. ``mesh`` is only used by
    ``lax_paral_op`` to validate each axis name.

    Args:
        x: Input tensor to reduce
        mesh: JAX mesh for distributed computation

    Returns:
        Reduced tensor
    """
    mesh_axes = get_all_mesh_axes()
    # The pp axis name does not change while looping, so look it up once.
    pp_axis = global_mesh_resource().pp_resource
    for axis_name in mesh_axes:
        if axis_name == pp_axis:
            continue
        x = lax_paral_op(x, jax.lax.pmax, axis_name, mesh)
    return x


# Deprecating Items ---------------------------------------------------------------
# Backward-compatible aliases kept for older call sites; new code should use
# MeshResource and global_mesh_resource directly.
ShardingResource = MeshResource

global_shard_resource = global_mesh_resource
275
276
277


class MajorShardingType(Enum):
    """Coarse sharding categories for distributed training.

    Deprecated: this enum is scheduled for removal in a future release.

    Members:
        SINGLE: single-process training, no sharding
        DP: data-parallel training
        TP: standard tensor-parallel training
        DPTP: combined data- and standard tensor-parallel training
    """

    SINGLE = 0  # no parallelism
    DP = 1  # data parallel
    TP = 2  # tensor parallel
    DPTP = 3  # data + tensor parallel


class ShardingType(Enum):
    """Fine-grained sharding patterns for distributed training.

    Deprecated: this enum is scheduled for removal in a future release.

    Each member's value is a ``(MajorShardingType, tag)`` pair: the coarse
    category plus a short string naming the specific layout.

    Members:
        SINGLE: no sharding
        DP: sharded along data parallelism
        TP_COL: sharded along column-split tensor parallelism
        TP_ROW: sharded along row-split tensor parallelism
        DP_TP_COL: sharded along data and column-split tensor parallelism
        DP_TP_ROW: sharded along data and row-split tensor parallelism
    """

    # No parallelism.
    SINGLE = (MajorShardingType.SINGLE, "single")
    # Data parallelism only.
    DP = (MajorShardingType.DP, "dp")
    # Tensor parallelism; the tag distinguishes column- from row-split.
    TP_COL = (MajorShardingType.TP, "tp_col")
    TP_ROW = (MajorShardingType.TP, "tp_row")
    # Data + tensor parallelism.
    DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
    DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")