# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Sharding Meta for xmap with CustomCall
"""
import os
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Callable
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec

_PXLA_THREAD_RESOURCES = pxla.thread_resources

# Axis Names
BATCH_AXES = "nvte_batch"
SEQLEN_AXES = "nvte_seqlen"
SEQLEN_TP_AXES = "nvte_seqlen_tp"
HEAD_AXES = "nvte_head"
HIDDEN_AXES = "nvte_hidden"
HIDDEN_TP_AXES = "nvte_hidden_tp"
JOINED_AXES = "nvte_joined"
W_NO_SHARD_AXES = "nvte_w_no_shard"
W_FSDP_AXES = "nvte_w_fsdp"
W_TP_AXES = "nvte_w_tp"
W_JOINED_AXES = "nvte_w_joined"
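
# These logical axis names are consumed by with_sharding_constraint_by_logical_axes below.
# Illustration (a sketch, assuming a [batch, seqlen, hidden] activation): annotating it with
# (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES) shards the batch dim over the DP/FSDP mesh axes
# and the hidden dim over the TP mesh axis, per get_sharding_map_logic_axis_to_mesh_axis().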


def _get_mesh_info(resource: str):
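    """Return the size of the given mesh axis along with its name."""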
    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
    return mesh.shape[resource], resource


def get_sharding_map_logic_axis_to_mesh_axis():
    """
    Generate a dict to map logical axes to mesh axes.
    """
    gsr = global_mesh_resource()

    IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", False)))

    batch_resources = (
        [gsr.fsdp_resource, gsr.dp_resource]
        if IS_FSDP_OUTER
        else [gsr.dp_resource, gsr.fsdp_resource]
    )

    batch_dim_rule = []
    for resource in batch_resources:
        if resource is not None and resource not in batch_dim_rule:
            batch_dim_rule.append(resource)

    if len(batch_dim_rule) <= 0:
        batch_dim_rule = None
    elif len(batch_dim_rule) == 1:
        batch_dim_rule = batch_dim_rule[0]
    else:
        batch_dim_rule = tuple(batch_dim_rule)

    te_logical_axis_to_mesh_axis = {
        BATCH_AXES: batch_dim_rule,
        SEQLEN_AXES: None,
        SEQLEN_TP_AXES: gsr.tp_resource,
        HEAD_AXES: gsr.tp_resource,
        HIDDEN_AXES: None,
        HIDDEN_TP_AXES: gsr.tp_resource,
        JOINED_AXES: None,
        W_NO_SHARD_AXES: None,
        W_FSDP_AXES: gsr.fsdp_resource,
        W_TP_AXES: gsr.tp_resource,
        W_JOINED_AXES: None,
    }
    return te_logical_axis_to_mesh_axis
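
# Usage sketch (hypothetical resource names, for illustration only):
#
#     with global_shard_guard(MeshResource(dp_resource="dp", tp_resource="tp")):
#         rules = get_sharding_map_logic_axis_to_mesh_axis()
#         # rules[BATCH_AXES] == "dp"; rules[HIDDEN_TP_AXES] == "tp"; rules[HIDDEN_AXES] is None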


def generate_pspec(logical_axis_names):
    """
    Convert logical axes to PartitionSpec
    """
    rules = get_sharding_map_logic_axis_to_mesh_axis()
    mesh_axis_names = [rules[name] for name in logical_axis_names]
    pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
    return pspec
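
# Example sketch (assuming the same hypothetical "dp"/"tp" resources as above):
#
#     generate_pspec((BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES))
#     # -> PartitionSpec('dp', None, 'tp')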


def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
    """
    A wrapper around jax.lax.with_sharding_constraint that
    also supports the case where the Mesh is empty.
    """
    if pspec is None:
        return x

    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    if mesh.empty:
        return x
    return jax.lax.with_sharding_constraint(x, pspec)
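
# Minimal sketch (hypothetical pspec; intended for use inside a jit-compiled function with a
# Mesh entered as a context manager; with no active Mesh the input is returned unchanged):
#
#     x = with_sharding_constraint(x, PartitionSpec("dp", None))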


def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple | list):
    """
    A wrapper around jax.lax.with_sharding_constraint that accepts logical axes.
    """
    if logical_axis_names is None:
        return x

    assert len(x.shape) == len(logical_axis_names)
    pspec = generate_pspec(logical_axis_names)
    return with_sharding_constraint(x, pspec)
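
# Usage sketch (hypothetical logical axes for a [batch, seqlen, hidden] activation):
#
#     x = with_sharding_constraint_by_logical_axes(x, (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES))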


def get_all_mesh_axes():
    """
    Get the names of all mesh axes
    """
    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    return mesh.axis_names


def get_padded_spec(spec, ndim):
    """
    Get padded spec for partitioning from arguments' information
    """
    if spec is None:
        return (None,) * ndim
    assert len(spec) <= ndim
    return spec + (None,) * (ndim - len(spec))
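
# Example (illustrative): get_padded_spec(("dp", None), 4) -> ("dp", None, None, None)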


def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
    """
    A wrapper to invoke lax.p* collective operations, such as psum, along the given mesh resource.
    """
    if mesh_resource is not None:
        _, resource = _get_mesh_info(mesh_resource)
        return ops(x, resource)
    return x
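
# Sketch (illustrative): inside a mapped region (e.g. shard_map) where the mesh axis name is
# bound, lax_paral_op(x, jax.lax.psum, "tp") performs a psum over that axis; with
# mesh_resource=None the input is returned unchanged.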


def num_of_devices():
    """
    Get total number of detected devices
    """
    return len(jax.devices())


@dataclass
class MeshResource:
    """
    A data container to indicate which axis in Mesh for data parallelism and
    which for tensor parallelism.

    Parameters
    ----------
    dp_resource : str, default = None
161
162
        The axis name in Mesh used to shard batches along.
        If it is None, then data parallelism is disabled.
163
    tp_resource : str, default = None
164
165
        The axis name in Mesh used to split the hidden dimensions along.
        If it is None, then tensor parallelism is disabled.
166
167
168
169
170
171
    fsdp_resource : str, default = None
        The axis name in Mesh used to split the batch and weights along.
        If it is None, then full-sharded data parallelism is disabled.
    pp_resource : str, default = None
        The axis name in Mesh used to split model layers. along.
        If it is None, then pipeline parallelism is disabled.
172
    """

    dp_resource: str = None
    tp_resource: str = None
    fsdp_resource: str = None
    pp_resource: str = None


_GLOBAL_MESH_RESOURCE = MeshResource()


@contextmanager
def global_shard_guard(resource: MeshResource):
    """
    A context manager to switch the global MeshResource
    """
    global _GLOBAL_MESH_RESOURCE
    prev_gmr = _GLOBAL_MESH_RESOURCE
    try:
        _GLOBAL_MESH_RESOURCE = resource
        yield
    finally:
        _GLOBAL_MESH_RESOURCE = prev_gmr


def global_mesh_resource() -> MeshResource:
    """
    A getter of the global MeshResource
    """
    return _GLOBAL_MESH_RESOURCE
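
# Typical setup sketch (hypothetical 8-device layout and axis names; numpy import assumed):
#
#     mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(2, 4), ("dp", "tp"))
#     with mesh, global_shard_guard(MeshResource(dp_resource="dp", tp_resource="tp")):
#         ...  # TE sharding helpers resolve "dp"/"tp" through the guarded MeshResource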


def all_reduce_sum_along_dp_fsdp(x: jnp.array):
    """
    All-Reduce (Sum) along DP and FSDP mesh axes.
    """
    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource)
    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource)
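
# Sketch (illustrative): with dp_resource="dp" and fsdp_resource=None, this reduces to a single
# jax.lax.psum over the "dp" axis, e.g. to accumulate values across data-parallel replicas
# inside a mapped computation.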


def all_reduce_max_along_all_axes_except_PP(x: jnp.array):
    """
    All-Reduce (Max) along all mesh axes except the pipeline-parallel (PP) axis.
    """
    all_axes = get_all_mesh_axes()
    for axis in all_axes:
        if axis != global_mesh_resource().pp_resource:
            x = lax_paral_op(x, jax.lax.pmax, axis)
    return x


# Deprecating Items ---------------------------------------------------------------
ShardingResource = MeshResource

global_shard_resource = global_mesh_resource


class MajorShardingType(Enum):
    r"""
    The major sharding type to indicate sharding pattern.

    .. warning::
        MajorShardingType will be deprecated in the near future.

    Values
    ----------
    SINGLE:
        Single process training.
    DP:
        Data parallel training.
    TP:
        Standard tensor parallel training.
    DPTP:
        Data and Standard tensor parallel training.
    """

    SINGLE = 0
    DP = 1
    TP = 2
    DPTP = 3


class ShardingType(Enum):
    """
    The sharding type to indicate sharding pattern.

    .. warning::
        ShardingType will be deprecated in the near future.

    Values
    ----------
    SINGLE:
        No sharding.
    DP:
        Sharding along data parallelism.
    TP_COL:
        Sharding along column-split tensor parallelism.
    TP_ROW:
        Sharding along row-split tensor parallelism.
    DP_TP_COL:
        Sharding along data and column-split tensor parallelism.
    DP_TP_ROW:
        Sharding along data and row-split tensor parallelism.
    """

    SINGLE = (MajorShardingType.SINGLE, "single")
    DP = (MajorShardingType.DP, "dp")
    TP_COL = (MajorShardingType.TP, "tp_col")
    TP_ROW = (MajorShardingType.TP, "tp_row")
    DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
    DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")