# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Sharding Meta for xmap with CustomCall
"""
import os
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Callable
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec

_PXLA_THREAD_RESOURCES = pxla.thread_resources

# Axis Names
BATCH_AXES = "nvte_batch"
SEQLEN_AXES = "nvte_seqlen"
SEQLEN_TP_AXES = "nvte_seqlen_tp"
SEQLEN_CP_AXES = "nvte_seqlen_cp"
HEAD_AXES = "nvte_head"
HIDDEN_AXES = "nvte_hidden"
HIDDEN_TP_AXES = "nvte_hidden_tp"
JOINED_AXES = "nvte_joined"
W_NO_SHARD_AXES = "nvte_w_no_shard"
W_FSDP_AXES = "nvte_w_fsdp"
W_TP_AXES = "nvte_w_tp"
W_JOINED_AXES = "nvte_w_joined"

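# Illustrative convention (not enforced by this module): these logical names are
# attached to tensor dimensions and later translated to mesh axes via
# get_sharding_map_logic_axis_to_mesh_axis(). For example, an activation of shape
# [batch, seqlen, hidden] may be annotated as (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES),
# and a tensor-parallel weight as (W_FSDP_AXES, W_TP_AXES).
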

def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
    assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
    return mesh.shape[resource], resource


def get_sharding_map_logic_axis_to_mesh_axis():
    """
    Generate a dict to map logical axes to mesh axes.
    """
    gsr = global_mesh_resource()

    IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", False)))

    batch_resources = (
        [gsr.fsdp_resource, gsr.dp_resource]
        if IS_FSDP_OUTER
        else [gsr.dp_resource, gsr.fsdp_resource]
    )

    batch_dim_rule = []
    for resource in batch_resources:
        if resource is not None and resource not in batch_dim_rule:
            batch_dim_rule.append(resource)

    if len(batch_dim_rule) <= 0:
        batch_dim_rule = None
    elif len(batch_dim_rule) == 1:
        batch_dim_rule = batch_dim_rule[0]
    else:
        batch_dim_rule = tuple(batch_dim_rule)

    te_logical_axis_to_mesh_axis = {
        BATCH_AXES: batch_dim_rule,
        SEQLEN_AXES: None,
        SEQLEN_TP_AXES: gsr.tp_resource,
        SEQLEN_CP_AXES: gsr.cp_resource,
        HEAD_AXES: gsr.tp_resource,
        HIDDEN_AXES: None,
        HIDDEN_TP_AXES: gsr.tp_resource,
        JOINED_AXES: None,
        W_NO_SHARD_AXES: None,
        W_FSDP_AXES: gsr.fsdp_resource,
        W_TP_AXES: gsr.tp_resource,
        W_JOINED_AXES: None,
    }
    return te_logical_axis_to_mesh_axis

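# Example of the resulting mapping (illustrative), assuming an active
# MeshResource(dp_resource="dp", tp_resource="tp", fsdp_resource="fsdp") and
# NVTE_OUTER_BATCH_FSDP_DIM unset:
#   BATCH_AXES     -> ("dp", "fsdp")   # batch dim sharded over DP, then FSDP
#   SEQLEN_TP_AXES -> "tp"
#   W_FSDP_AXES    -> "fsdp"
#   SEQLEN_AXES    -> None             # replicated
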

def generate_pspec(logical_axis_names):
    """
    Convert logical axes to PartitionSpec
    """
    rules = get_sharding_map_logic_axis_to_mesh_axis()
    mesh_axis_names = [rules[name] for name in logical_axis_names]
    pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
    return pspec

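# Usage sketch (illustrative), with the same assumed MeshResource as above:
#   generate_pspec((BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES))
#   # -> PartitionSpec(("dp", "fsdp"), None, "tp")
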

def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
    """
    A wrapper around jax.lax.with_sharding_constraint that also
    supports the case where the Mesh is empty.
    """
    if pspec is None:
        return x

    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    if mesh.empty:
        return x
    return jax.lax.with_sharding_constraint(x, pspec)

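# Minimal usage sketch (illustrative): outside of any mesh context the physical
# mesh is empty, so the call below is a no-op and returns `x` unchanged; under an
# active mesh it behaves like jax.lax.with_sharding_constraint.
#   x = with_sharding_constraint(x, PartitionSpec("dp", None))
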

def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple | list):
    """
    A wrapper around jax.lax.with_sharding_constraint that accepts logical axes.
    """
    if logical_axis_names is None:
        return x

    assert len(x.shape) == len(logical_axis_names)
    pspec = generate_pspec(logical_axis_names)
    return with_sharding_constraint(x, pspec)

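# Usage sketch (illustrative): constrain an activation of shape
# [batch, seqlen, hidden] by naming each dimension logically; the names are
# resolved to mesh axes through the active MeshResource.
#   x = with_sharding_constraint_by_logical_axes(x, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
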

def get_all_mesh_axes():
    """
    Get the names of all axes of the global mesh.
    """
    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    return mesh.axis_names


def get_padded_spec(spec, ndim):
    """
    Pad a partition spec with None entries so that it matches the array's rank (ndim).
    """
    if spec is None:
        return (None,) * ndim
    assert len(spec) <= ndim
    return spec + (None,) * (ndim - len(spec))

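# Examples (illustrative): pad a partial spec out to the rank of the array.
#   get_padded_spec(("dp", None), ndim=4)  # -> ("dp", None, None, None)
#   get_padded_spec(None, ndim=3)          # -> (None, None, None)
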

def lax_paral_op(
    x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs
):
    """
    A wrapper function to invoke lax.p* operations, like psum.
    """
    if mesh_resource is not None:
        _, resource = _get_mesh_info(mesh_resource, mesh)
        return ops(x, resource, **kwargs)
    return x

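# Usage sketch (illustrative): inside a mapped context over `mesh` (e.g. under
# shard_map), reduce across the configured tensor-parallel axis; when
# mesh_resource is None the input is returned unchanged.
#   y = lax_paral_op(x, jax.lax.psum, global_mesh_resource().tp_resource, mesh)
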

def num_of_devices():
    """
    Get total number of detected devices
    """
    return len(jax.devices())


def get_mesh_axis_size(axis, mesh=None):
    """
    Get the size of the given axis in the mesh.
    If the mesh is None, the global mesh is used.
    """
    if mesh is None:
        mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh

    if axis is None:
        return 1

    assert axis in mesh.shape, f"{axis} is not an axis of the given mesh {mesh.shape}"
    return mesh.shape[axis]

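# Examples (illustrative), assuming the global mesh has axes ("dp", "tp") of
# shape (4, 2):
#   get_mesh_axis_size("tp")   # -> 2
#   get_mesh_axis_size(None)   # -> 1
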

def get_mesh_axis_rank(axis: str, mesh=None):
    """
    Get the rank (index) of the local device along the given mesh `axis`.
    If the mesh is None, the rank is 0.
    """
    if mesh is None:
        return 0
    _, axis_name = _get_mesh_info(axis, mesh)
    return jax.lax.axis_index(axis_name)

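# Note (illustrative): jax.lax.axis_index requires `axis` to be bound in the
# current mapped context (e.g. under shard_map/xmap over the mesh), so the
# returned rank is a traced per-device value rather than a Python int.
#   rank = get_mesh_axis_rank("tp", mesh)   # 0 .. tp_size - 1 on each device
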

@dataclass
class MeshResource:
    """
    A data container indicating which mesh axis is used for data parallelism and
    which for tensor parallelism.

    Parameters
    ----------
    dp_resource : str, default = None
        The axis name in Mesh used to shard batches along.
        If it is None, then data parallelism is disabled.
    tp_resource : str, default = None
        The axis name in Mesh used to split the hidden dimensions along.
        If it is None, then tensor parallelism is disabled.
    fsdp_resource : str, default = None
        The axis name in Mesh used to split the batch and weights along.
        If it is None, then fully-sharded data parallelism is disabled.
    pp_resource : str, default = None
        The axis name in Mesh used to split model layers along.
        If it is None, then pipeline parallelism is disabled.
    cp_resource : str, default = None
        The axis name in Mesh used to split sequence (context) dimensions along
        in the attention. If it is None, then context parallelism is disabled.
    """

    dp_resource: str = None
    tp_resource: str = None
    fsdp_resource: str = None
    pp_resource: str = None
    cp_resource: str = None


_GLOBAL_MESH_RESOURCE = MeshResource()


@contextmanager
def global_shard_guard(resource: MeshResource):
    """
    A context manager to switch the global MeshResource
    """
    global _GLOBAL_MESH_RESOURCE
    prev_gmr = _GLOBAL_MESH_RESOURCE
    try:
        _GLOBAL_MESH_RESOURCE = resource
        yield
    finally:
        _GLOBAL_MESH_RESOURCE = prev_gmr

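# Usage sketch (illustrative), assuming 8 devices arranged as a ("dp", "tp") mesh:
#   import numpy as np
#   devices = np.asarray(jax.devices()).reshape(4, 2)
#   mesh = jax.sharding.Mesh(devices, ("dp", "tp"))
#   with global_shard_guard(MeshResource(dp_resource="dp", tp_resource="tp")), mesh:
#       ...  # TE/JAX logical axes now resolve to the "dp" and "tp" mesh axes
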

def global_mesh_resource() -> MeshResource:
    """
    A getter of the global MeshResource
    """
    return _GLOBAL_MESH_RESOURCE


def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
    """
    All-Reduce (Sum) along DP and FSDP mesh axes.
    """
    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)


def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
    """
    All-Reduce (Max) along all mesh axes except the pipeline-parallel axis.
    """
    all_axes = get_all_mesh_axes()
    for axis in all_axes:
        if axis != global_mesh_resource().pp_resource:
            x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
    return x


# Deprecated items ---------------------------------------------------------------
ShardingResource = MeshResource

global_shard_resource = global_mesh_resource


class MajorShardingType(Enum):
    r"""
    The major sharding type to indicate sharding pattern.
    .. warning::
        MajorShardingType will be deprecated in the near future.

    Values
    ----------
    SINGLE:
        Single process training.
    DP:
        Data parallel training.
    TP:
        Standard tensor parallel training.
    DPTP:
        Data and Standard tensor parallel training.
    """

    SINGLE = 0
    DP = 1
    TP = 2
    DPTP = 3


class ShardingType(Enum):
    """
    The sharding type to indicate sharding pattern.
    .. warning::
        ShardingType will be deprecated in the near future.

    Values
    ----------
    SINGLE:
        No sharding.
    DP:
        Sharding along data parallelism.
    TP_COL:
        Sharding along column-split tensor parallelism.
    TP_ROW:
        Sharding along row-split tensor parallelism.
    DP_TP_COL:
        Sharding along data and column-split tensor parallelism.
    DP_TP_ROW:
        Sharding along data and row-split tensor parallelism.
    """

    SINGLE = (MajorShardingType.SINGLE, "single")
    DP = (MajorShardingType.DP, "dp")
    TP_COL = (MajorShardingType.TP, "tp_col")
    TP_ROW = (MajorShardingType.TP, "tp_row")
    DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
    DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")