# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Sharding Meta for xmap with CustomCall
"""

from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Callable
import jax
import jax.numpy as jnp
from jax.interpreters import pxla
from jax.sharding import PartitionSpec

_PXLA_THREAD_RESOURCES = pxla.thread_resources


def _get_mesh_info(resource: str):
    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    assert resource in mesh.axis_names, \
        f"{resource} is not in the axis_names of Mesh {mesh}."
    return mesh.shape[resource], resource
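
# A minimal usage sketch (not part of the original module): the thread-local
# physical mesh read above is populated by JAX's Mesh context manager, so
# this helper is expected to run under one. Assumes 8 local devices:
#
#   import numpy as np
#   from jax.sharding import Mesh
#   devices = np.array(jax.devices()).reshape(2, 4)
#   with Mesh(devices, ("dp", "tp")):
#       size, name = _get_mesh_info("tp")   # -> (4, "tp")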


def get_all_mesh_axes():
    """
    Get the names of all mesh axes
    """
    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    return mesh.axis_names


def get_padded_spec(spec, ndim):
    """
    Pad the spec with None entries so its length matches the array's ndim
    """
    if spec is None:
        return (None,) * ndim
    assert len(spec) <= ndim
    return spec + (None,) * (ndim - len(spec))
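
# Example (illustrative): padding a length-2 spec to a rank-4 array leaves
# the trailing dimensions unsharded (None), and a None spec is fully padded.
#
#   get_padded_spec(("dp", None), 4)   # -> ("dp", None, None, None)
#   get_padded_spec(None, 3)           # -> (None, None, None)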


def with_sharding_constraint(x: jnp.ndarray, pspec: PartitionSpec):
    """
    A wrapper around jax.lax.with_sharding_constraint that supports
    the case where the Mesh is empty.
    """
    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    if mesh.empty:
        return x
    return jax.lax.with_sharding_constraint(x, pspec)
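
# Example usage (a minimal sketch; assumes a Mesh with axes ("dp", "tp") is
# active, as in the sketch above):
#
#   y = with_sharding_constraint(x, PartitionSpec("dp", None))
#
# Outside of any Mesh context, the call is a no-op and returns x unchanged.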


def lax_paral_op(x: jnp.ndarray, ops: Callable, mesh_resource: str):
    """
    A wrapper function to invoke lax.p* collectives, such as psum.
    If mesh_resource is None, x is returned unchanged.
    """
    if mesh_resource is not None:
        _, resource = _get_mesh_info(mesh_resource)
        return ops(x, resource)
    return x
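
# Example (illustrative): inside a context where the mesh axis name is bound
# to a collective (e.g. under shard_map/xmap), this reduces across that axis;
# passing mesh_resource=None makes it an identity.
#
#   total = lax_paral_op(x, jax.lax.psum, "tp")   # psum over the "tp" axis
#   same = lax_paral_op(x, jax.lax.psum, None)    # returns x unchanged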


def num_of_devices():
    """
    Get the total number of detected devices
    """
    return len(jax.devices())


@dataclass
class MeshResource:
    """
    A data container to indicate which mesh axis to use for each kind of
    parallelism: data, tensor, fully-sharded data, and pipeline parallelism.

    Parameters
    ----------
    dp_resource : str, default = None
        The axis name in Mesh used to shard batches along.
        If it is None, then data parallelism is disabled.
    tp_resource : str, default = None
        The axis name in Mesh used to split the hidden dimensions along.
        If it is None, then tensor parallelism is disabled.
    fsdp_resource : str, default = None
        The axis name in Mesh used to split the batch and weights along.
        If it is None, then fully-sharded data parallelism is disabled.
    pp_resource : str, default = None
        The axis name in Mesh used to split model layers along.
        If it is None, then pipeline parallelism is disabled.
    """
    dp_resource: str = None
    tp_resource: str = None
    fsdp_resource: str = None
    pp_resource: str = None


_GLOBAL_MESH_RESOURCE = MeshResource()


@contextmanager
def global_shard_guard(resource: MeshResource):
    """
    A context manager to switch the global MeshResource
    """
    global _GLOBAL_MESH_RESOURCE
    prev_gmr = _GLOBAL_MESH_RESOURCE
    try:
        _GLOBAL_MESH_RESOURCE = resource
        yield
    finally:
        _GLOBAL_MESH_RESOURCE = prev_gmr


def global_mesh_resource() -> MeshResource:
    """
    A getter of the global MeshResource
    """
    return _GLOBAL_MESH_RESOURCE
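
# Example (illustrative): configure the global MeshResource for a run that
# maps mesh axis "dp" to data parallelism and "tp" to tensor parallelism.
#
#   with global_shard_guard(MeshResource(dp_resource="dp", tp_resource="tp")):
#       assert global_mesh_resource().dp_resource == "dp"
#   # On exit, the previous MeshResource is restored.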


def all_reduce_sum_along_dp_fsdp(x: jnp.ndarray):
    """
    All-Reduce (Sum) along DP and FSDP mesh axes.
    """
    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource)
    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource)
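
# Example (illustrative): with MeshResource(dp_resource="dp") active and "dp"
# bound as a collective axis, this sums gradients across data-parallel ranks;
# resources left as None in the MeshResource are skipped.
#
#   grads = all_reduce_sum_along_dp_fsdp(grads)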


def all_reduce_max_along_all_axes_except_PP(x: jnp.ndarray):
    """
    All-Reduce (Max) along all mesh axes except the PP axis.
    """
    all_axes = get_all_mesh_axes()
    for axis in all_axes:
        if axis != global_mesh_resource().pp_resource:
            x = lax_paral_op(x, jax.lax.pmax, axis)
    return x


# Deprecated items ----------------------------------------------------------------
ShardingResource = MeshResource

global_shard_resource = global_mesh_resource


class MajorShardingType(Enum):
    r"""
    The major sharding type to indicate the sharding pattern.

    .. warning::
        MajorShardingType will be deprecated in the near future.

    Values
    ------
    SINGLE:
        Single-process training.
    DP:
        Data-parallel training.
    TP:
        Standard tensor-parallel training.
    DPTP:
        Data- and standard tensor-parallel training.
    """
    SINGLE = 0
    DP = 1
    TP = 2
    DPTP = 3


class ShardingType(Enum):
    """
    The sharding type to indicate the sharding pattern.

    .. warning::
        ShardingType will be deprecated in the near future.

    Values
    ------
    SINGLE:
        No sharding.
    DP:
        Sharding along data parallelism.
    TP_COL:
        Sharding along column-split tensor parallelism.
    TP_ROW:
        Sharding along row-split tensor parallelism.
    DP_TP_COL:
        Sharding along data and column-split tensor parallelism.
    DP_TP_ROW:
        Sharding along data and row-split tensor parallelism.
    """
    SINGLE = (MajorShardingType.SINGLE, "single")
    DP = (MajorShardingType.DP, "dp")
    TP_COL = (MajorShardingType.TP, "tp_col")
    TP_ROW = (MajorShardingType.TP, "tp_row")
    DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
    DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")