amax.py 12.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for amax calculation"""
from enum import Enum


import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    NamedSharding,
)
from ..sharding import (
    global_mesh_resource,
    lax_paral_op,
)
from ..quantize import (
    get_wgrad_sign_vector,
    get_sign_from_vector,
)


# Public API of this module; the primitive classes are internal.
__all__ = ["AmaxScope", "calculate_amax", "calculate_post_rht_amax"]


class AmaxScope(Enum):
    """
    Scope over which a locally computed amax is all-reduced (with ``jax.lax.pmax``).

    LOCAL: no cross-device reduction; the per-shard amax is returned unchanged.
    TPSP:  reduce over the tensor-sequence-parallel mesh axis, but only when the
           input's sequence dimension is sharded along that axis.
    FSDP:  reduce over the FSDP mesh axis.
    """

    LOCAL = 1
    TPSP = 2
    FSDP = 3

    def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh):
        """Reduce the amax based on its scope.

        Args:
            amax: Per-shard amax value to reduce.
            data_spec: Per-dimension partition spec of the tensor the amax was computed from.
            transpose_batch_sequence: If True the sequence dimension is dim 0, otherwise dim 1.
            mesh: Device mesh, forwarded to ``lax_paral_op``.

        Returns:
            ``amax`` max-reduced along the mesh axis selected by this scope, or ``amax``
            unchanged when no reduction applies.
        """
        gmesh = global_mesh_resource()
        sequence_dim = 0 if transpose_batch_sequence else 1
        # Run AR across TPSP only when tensor-sequence is detected in the input spec.
        # The length check protects low-rank inputs (e.g. 1-D tensors) whose spec has no
        # entry at ``sequence_dim``; indexing it directly would raise IndexError.
        if (
            self is AmaxScope.TPSP
            and sequence_dim < len(data_spec)
            and data_spec[sequence_dim] == gmesh.tpsp_resource
        ):
            return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
        # Run AR across FSDP
        if self is AmaxScope.FSDP:
            return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
        return amax


class AmaxCalculationPrimitive(BasePrimitive):
    """
    Custom-partitioned primitive that reduces a tensor to its absolute maximum (amax).

    The result is a float32 array of shape (1,). Under sharding, every shard computes a
    local amax which is then all-reduced according to ``amax_scope``.
    """

    name = "jax_local_amax"
    multiple_results = False
    # Positions of the static (non-array) arguments of ``impl``:
    # 1 -> amax_scope, 2 -> transpose_batch_sequence.
    impl_static_args = (1, 2)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        *,
        amax_scope,
        transpose_batch_sequence,
    ):
        """
        Shape/dtype rule: a float32/float16/bfloat16 input yields a float32 amax of shape (1,).
        """
        del amax_scope, transpose_batch_sequence

        in_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert in_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        return jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

    @staticmethod
    def impl(
        x,
        amax_scope,
        transpose_batch_sequence,
    ):
        """
        Reference implementation: max(|x|) over all elements, as a float32 array of shape (1,).
        """
        del amax_scope, transpose_batch_sequence
        abs_max = jnp.amax(jnp.abs(x), keepdims=True)
        return abs_max.astype(jnp.float32).reshape((1,))

    @staticmethod
    def infer_sharding_from_operands(
        amax_scope,
        transpose_batch_sequence,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Sharding inference: the amax output is never partitioned (``PartitionSpec(None)``).
        """
        del amax_scope, transpose_batch_sequence, arg_infos, result_infos  # Unused.
        return NamedSharding(
            mesh,
            PartitionSpec(None),
            desc="AmaxCalculationPrimitive.out_sharding",
        )

    @staticmethod
    def partition(
        amax_scope,
        transpose_batch_sequence,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Partition rule: compute a per-shard amax, then all-reduce it as ``amax_scope`` dictates.

        The input keeps the sharding it arrived with; the output is unpartitioned.
        """
        del result_infos
        input_spec = get_padded_spec(arg_infos[0])
        unpartitioned_out = NamedSharding(
            mesh,
            PartitionSpec(None),
            desc="AmaxCalculation.amax_sharding",
        )
        in_shardings = tuple(info.sharding for info in arg_infos)

        def sharded_impl(x):
            local_amax = AmaxCalculationPrimitive.impl(
                x,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
            )
            # Combine per-shard maxima across the mesh axes selected by the scope.
            return amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
                local_amax, input_spec, transpose_batch_sequence, mesh
            )

        return mesh, sharded_impl, unpartitioned_out, in_shardings

    @staticmethod
    def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types):
        """
        Shardy rule: a distinct factor name per input dimension, and a separate one for the amax.
        """
        del amax_scope, transpose_batch_sequence, mesh, result_types
        tag = "AmaxCal"
        rank = len(value_types[0].shape)
        in_factors = tuple(f"{tag}_{dim}" for dim in range(rank))
        out_factors = (f"{tag}_amax",)
        return SdyShardingRule((in_factors,), (out_factors,))


register_primitive(AmaxCalculationPrimitive, outer_only=True)


class RHTAmaxCalculationPrimitive(BasePrimitive):
    """
    Amax Calculation Primitive with custom_partitioning for calculating the regular and the
    post-Random Hadamard Transform (RHT) amax using TE's fused kernels.

    Both outputs are float32 arrays of shape (1,). The primitive lowers to the
    ``te_rht_amax_ffi`` FFI custom call; under sharding, each shard computes local amaxes
    that are then all-reduced according to ``amax_scope``.
    """

    name = "te_rht_amax_ffi"
    multiple_results = True
    impl_static_args = (
        1,  # amax_scope
        2,  # transpose_batch_sequence
        3,  # rht_matrix_random_sign_mask_t
        4,  # produce_regular_amax
        5,  # flatten_axis
    )
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        *,
        amax_scope,
        transpose_batch_sequence,
        rht_matrix_random_sign_mask_t,
        produce_regular_amax,
        flatten_axis,
    ):
        """
        amax calculation abstract: a bfloat16 input yields two float32 outputs of shape (1,),
        ``(amax, post_rht_amax)``.
        """
        del (
            amax_scope,
            transpose_batch_sequence,
            rht_matrix_random_sign_mask_t,
            produce_regular_amax,
            flatten_axis,
        )

        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.bfloat16], f"RHT requires input to be bfloat16, but got {dtype}"

        amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
        post_rht_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        return amax_aval, post_rht_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        *,
        amax_scope,
        transpose_batch_sequence,
        rht_matrix_random_sign_mask_t,
        produce_regular_amax,
        flatten_axis,
    ):
        """
        te_rht_amax_ffi lowering rules.

        Emits the FFI custom call with the RHT sign mask, the ``produce_regular_amax`` flag and
        the normalized (non-negative) ``flatten_axis`` as attributes. ``amax_scope`` and
        ``transpose_batch_sequence`` are only used by the partitioning rules and are not
        forwarded to the kernel.
        """
        del amax_scope, transpose_batch_sequence
        (x_aval,) = ctx.avals_in
        # Keep in sync with ``abstract``: the fused RHT amax kernel only accepts bfloat16 input.
        assert x_aval.dtype in [
            jnp.bfloat16
        ], f"RHT requires input to be bfloat16, but got {x_aval.dtype}"

        # A negative flatten_axis counts from the end; it must split x into two non-empty parts.
        flatten_axis = flatten_axis if flatten_axis >= 0 else flatten_axis + len(x_aval.shape)
        assert 0 < flatten_axis < len(x_aval.shape), "Flatten axis out of bounds!"

        return ffi.ffi_lowering(
            RHTAmaxCalculationPrimitive.name,
        )(
            ctx,
            x,
            rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
            produce_regular_amax=produce_regular_amax,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        amax_scope,
        transpose_batch_sequence,
        rht_matrix_random_sign_mask_t,
        produce_regular_amax,
        flatten_axis,
    ):
        """
        amax calculation implementation: binds the inner primitive and returns
        ``(amax, post_rht_amax)``.
        """
        assert RHTAmaxCalculationPrimitive.inner_primitive is not None
        (
            amax,
            post_rht_amax,
        ) = RHTAmaxCalculationPrimitive.inner_primitive.bind(
            x,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
            produce_regular_amax=produce_regular_amax,
            flatten_axis=flatten_axis,
        )
        return amax, post_rht_amax

    @staticmethod
    def infer_sharding_from_operands(
        amax_scope,
        transpose_batch_sequence,
        rht_matrix_random_sign_mask_t,
        produce_regular_amax,
        flatten_axis,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        amax calculation infer_sharding_from_operands: both outputs are unpartitioned.
        """
        del (
            amax_scope,
            transpose_batch_sequence,
            rht_matrix_random_sign_mask_t,
            produce_regular_amax,
            flatten_axis,
            arg_infos,
            result_infos,
        )  # Unused.
        amax_sharding = NamedSharding(
            mesh,
            PartitionSpec(None),
            desc="RHTAmaxCalculationPrimitive.out_sharding",
        )
        return amax_sharding, amax_sharding

    @staticmethod
    def partition(
        amax_scope,
        transpose_batch_sequence,
        rht_matrix_random_sign_mask_t,
        produce_regular_amax,
        flatten_axis,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        amax calculation partition: compute both amaxes per shard, then all-reduce each one
        as ``amax_scope`` dictates. The input keeps its incoming sharding.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        amax_sharding = NamedSharding(
            mesh,
            PartitionSpec(None),
            desc="RHTAmaxCalculationPrimitive.amax_sharding",
        )
        out_shardings = (amax_sharding, amax_sharding)

        def sharded_impl(x):
            amax, post_rht_amax = RHTAmaxCalculationPrimitive.impl(
                x,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
                produce_regular_amax=produce_regular_amax,
                flatten_axis=flatten_axis,
            )
            amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
                amax, x_spec, transpose_batch_sequence, mesh
            )
            post_rht_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
                post_rht_amax, x_spec, transpose_batch_sequence, mesh
            )

            return amax, post_rht_amax

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        amax_scope,
        transpose_batch_sequence,
        rht_matrix_random_sign_mask_t,
        produce_regular_amax,
        flatten_axis,
        mesh,
        value_types,
        result_types,
    ):
        """
        amax calculation shardy_sharding_rule: a distinct factor per input dimension and a
        separate factor for each of the two outputs.
        """
        del (
            amax_scope,
            transpose_batch_sequence,
            rht_matrix_random_sign_mask_t,
            produce_regular_amax,
            flatten_axis,
            mesh,
            result_types,
        )
        prefix = "RHTAmaxCal"
        input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
        output_amax_spec = (f"{prefix}_amax",)
        output_post_rht_amax_spec = (f"{prefix}_post_rht_amax",)
        return SdyShardingRule((input_spec,), (output_amax_spec, output_post_rht_amax_spec))


register_primitive(RHTAmaxCalculationPrimitive)


def calculate_amax(x: jnp.ndarray, amax_scope: AmaxScope, transpose_batch_sequence: bool):
    """
    Compute the maximum absolute value (amax) of the input tensor.

    Args:
        x: Input tensor (float32, float16 or bfloat16).
        amax_scope: Scope across which the per-shard amaxes are all-reduced.
        transpose_batch_sequence: If True the sequence dimension is dim 0, otherwise dim 1;
            consulted when ``amax_scope`` is ``AmaxScope.TPSP``.

    Returns:
        A float32 array of shape (1,) holding the amax.
    """
    primitive = AmaxCalculationPrimitive.outer_primitive
    assert primitive is not None
    return primitive.bind(
        x,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
    )


def calculate_post_rht_amax(
    x: jnp.ndarray,
    amax_scope: AmaxScope,
    transpose_batch_sequence: bool,
    produce_regular_amax: bool,
    flatten_axis: int,
):
    """Compute the post-Random Hadamard Transform (RHT) amax and, optionally, the regular amax.

    Args:
        x: Input tensor; must be bfloat16.
        amax_scope: The scope for amax reduction (local, TPSP, or FSDP).
        transpose_batch_sequence: Whether the input tensor has its batch and sequence
            dimensions transposed.
        produce_regular_amax: Whether to compute and return the regular amax alongside the
            post-RHT amax.
        flatten_axis: The axis at which to flatten the input tensor before applying RHT;
            negative values count from the end.

    Returns:
        A tuple containing:
            - The regular amax if `produce_regular_amax` is True, otherwise None.
            - The post-RHT amax.
    """
    # Same precondition as calculate_amax: fail clearly instead of with an AttributeError on
    # ``None.bind`` if the primitive was not registered.
    assert RHTAmaxCalculationPrimitive.outer_primitive is not None
    amax, post_rht_amax = RHTAmaxCalculationPrimitive.outer_primitive.bind(
        x,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        rht_matrix_random_sign_mask_t=get_sign_from_vector(get_wgrad_sign_vector()),
        produce_regular_amax=produce_regular_amax,
        flatten_axis=flatten_axis,
    )

    if produce_regular_amax:
        return amax, post_rht_amax
    # The kernel's first output is not meaningful when it was not requested.
    return None, post_rht_amax