# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Tuple, Sequence, Union, Callable
import operator
from functools import reduce, partial
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    get_padded_spec,
    is_ffi_enabled,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = ["act_lu", "dact_lu", "act_lu_fp8"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}
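# Example lookups (illustrative): keys are tuples of activation names, so a
# single non-gated activation is keyed as a 1-tuple:
#   ActivationEnum[("gelu",)]           # -> NVTE_Activation_Type.GELU
#   ActivationEnum[("silu", "linear")]  # -> NVTE_Activation_Type.SWIGLU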


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported activation function: {fn_or_string}")
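
# Example conversions (illustrative): "linear" becomes the identity, other
# strings resolve against jax.nn, and callables pass through unchanged:
#   _convert_to_activation_function("gelu")       # -> jax.nn.gelu
#   _convert_to_activation_function(lambda x: x)  # -> the callable itself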


def _jax_act_lu(inputs, activation_type):
    """
    JAX native activation implementation
    """
    x = jnp.split(inputs, len(activation_type), axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    return x
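
# Shape sketch (illustrative): gated inputs pack two halves along axis -2,
# which is split, activated, multiplied elementwise, and squeezed away:
#   _jax_act_lu(jnp.ones((4, 2, 8)), ("silu", "linear")).shape  # (4, 8)
#   _jax_act_lu(jnp.ones((4, 1, 8)), ("gelu",)).shape           # (4, 8)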


class ActLuPrimitive(BasePrimitive):
    """
    Activation Forward Primitive
    """

    name = "te_act_lu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = (1,)

    @staticmethod
    def abstract(x_aval, *, act_enum):  # pylint: disable=unused-argument
        """
        act_lu abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        x_shape = x_aval.shape
        assert x_shape[-2] == 2 or x_shape[-2] == 1
        hidden_size = x_shape[-1]
        batch_shapes = x_shape[:-2]
        out_aval = x_aval
        out_shape = batch_shapes + (hidden_size,)
        out_aval = out_aval.update(shape=out_shape, dtype=dtype)

        return out_aval

    @staticmethod
    def lowering(ctx, x, *, act_enum):
        """
        act_lu lowering rules
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        if is_ffi_enabled():
            name = "te_act_lu_ffi"
            out = ffi.ffi_lowering(name)(ctx, x, act_enum=act_enum)
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]

            out_types = [
                ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
            ]
            operands = [x]
            operand_shapes = [ir_x_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            hidden_size = ir_x_shape[-1]
            batch_size = reduce(operator.mul, ir_x_shape[:-2])
            in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
            opaque = transformer_engine_jax.pack_common_descriptor(
                (batch_size, hidden_size), in_dtype, in_dtype, act_enum
            )

            out = custom_caller(ActLuPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, act_enum):
        assert ActLuPrimitive.inner_primitive is not None
        out = ActLuPrimitive.inner_primitive.bind(x, act_enum=act_enum)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims, *, act_enum):
        """
        act_lu batcher
        """
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        (inputs,) = batched_args
        (inputs_bdim,) = batch_dims

        out_bdims = inputs_bdim
        return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
        """
        act_lu infer_sharding_from_operands
        """
        del result_infos, act_enum  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        return out_sharding

    @staticmethod
    def partition(act_enum, mesh, arg_infos, result_infos):
        """
        act_lu partitioning
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))

        def sharded_impl(x):
            return ActLuPrimitive.impl(x, act_enum=act_enum)

        return mesh, sharded_impl, out_sharding, arg_shardings


register_primitive(ActLuPrimitive)


def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
    """
    act_lu wrapper
    Return act_lu(inputs)
    Input shape: (N, 1, H) for non-gated activations
                 (N, 2, H) for gated activations
    """
    if not ActLuPrimitive.enabled():
        return _jax_act_lu(inputs, activation_type)

    act_type_id = ActivationEnum[activation_type].value
    return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
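
# Usage sketch (illustrative): gated SiLU over a (batch, 2, hidden) input; the
# gating axis -2 is consumed, so the output shape is (batch, hidden):
#   x = jax.random.normal(jax.random.PRNGKey(0), (16, 2, 1024), jnp.bfloat16)
#   y = act_lu(x, ("silu", "linear"))  # y.shape == (16, 1024)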


class DActLuPrimitive(BasePrimitive):
    """
    ActLu Backward Primitive (gradient of act_lu)
    """

    name = "te_dact_lu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = (2,)

    @staticmethod
    def abstract(dz_aval, x_aval, *, act_enum):  # pylint: disable=unused-argument
        """
        dact_lu abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        for axis in range(len(dz_aval.shape) - 1):
            assert dz_aval.shape[axis] == x_aval.shape[axis]
        assert x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1

        i_hidden_size = dz_aval.shape[-1]
        g_hidden_size = x_aval.shape[-1]
        assert i_hidden_size == g_hidden_size
        out_aval = x_aval

        return out_aval

    @staticmethod
    def lowering(ctx, dz, x, *, act_enum):
        """
        dact_lu lowering rules
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        if is_ffi_enabled():
            name = "te_dact_lu_ffi"
            out = ffi.ffi_lowering(name)(ctx, dz, x, act_enum=act_enum)
        else:
            ir_in_type = ir.RankedTensorType(dz.type)
            ir_in_shape = ir_in_type.shape
            gi_type = ir.RankedTensorType(x.type)
            gi_shape = gi_type.shape
            # dz and x must agree on all leading dimensions; hidden sizes are checked below.
            for axis in range(len(ir_in_shape) - 1):
                assert ir_in_shape[axis] == gi_shape[axis]

            ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
            i_hidden_size = ir_in_shape[-1]
            g_hidden_size = gi_shape[-1]
            assert i_hidden_size == g_hidden_size
            out_dtype = ir_in_type.element_type
            out_shape = gi_shape

            out_types = [
                ir.RankedTensorType.get(out_shape, out_dtype),
            ]
            operands = [dz, x]
            operand_shapes = [ir_in_shape, gi_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
            opaque = transformer_engine_jax.pack_common_descriptor(
                (ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum
            )

            out = custom_caller(DActLuPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, act_enum):
        """
        dact_lu implementation
        """
        assert DActLuPrimitive.inner_primitive is not None
        dx = DActLuPrimitive.inner_primitive.bind(dz, x, act_enum=act_enum)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims, *, act_enum):
        """
        dact_lu batcher
        """
        check_valid_batch_dims(batch_dims)
        assert DActLuPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims

        out_bdims = x_bdim
        return DActLuPrimitive.outer_primitive.bind(dz, x, act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
        """
        dact_lu infer_sharding_from_operands
        """
        del result_infos, act_enum  # Unused.
        act_lu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec))
        return dx_sharding

    @staticmethod
    def partition(act_enum, mesh, arg_infos, result_infos):
        """
        dact_lu partition
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding

        def sharded_impl(dz, x):
            return DActLuPrimitive.impl(dz, x, act_enum=act_enum)

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DActLuPrimitive)


def dact_lu(
    inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]
) -> jnp.ndarray:
    """
    dact_lu fusion wrapper
    Return dgated_act_lu(inputs)
    """
    if not DActLuPrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs)
        return vjp_func(inputs)[0]

    act_type_id = ActivationEnum[activation_type].value
    return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)
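
# Usage sketch (illustrative): dact_lu is the VJP of act_lu, so for an
# incoming gradient dz of shape (batch, hidden) and the original act_lu
# input x of shape (batch, 2, hidden), the result matches x's shape:
#   dx = dact_lu(dz, x, ("silu", "linear"))  # dx.shape == x.shape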


class ActLuFp8Primitive(BasePrimitive):
    """
    ActLu FP8 Primitive
    """

    name = "te_act_lu_fp8"
    multiple_results = True
    impl_static_args = (4, 5)  # out_dtype, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, act_enum
    ):  # pylint: disable=unused-argument
        """
        act_lu_fp8 abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently the C side only supports casting to E4M3.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        assert x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2
        hidden_size = x_aval.shape[-1]
        batch_shape = x_aval.shape[:-2]
        out_shape = batch_shape + (hidden_size,)
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum):
        """
        act_lu_fp8 lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_act_lu_fp8_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
                ctx, x, amax, scale, scale_inv, act_enum=act_enum
            )
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape

            hidden_size = ir_x_shape[-1]
            batch_shape = ir_x_shape[:-2]
            batch_size = reduce(operator.mul, batch_shape)
            out_shape = batch_shape + [hidden_size]
            out_types = [
                ir.RankedTensorType.get(out_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [x, amax, scale, scale_inv]
            operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_common_descriptor(
                (batch_size, hidden_size),
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                act_enum,
            )

            out = custom_caller(
                ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
            )

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, act_enum):
        """
        act_lu_fp8 implementation
        """
        assert ActLuFp8Primitive.inner_primitive is not None
        out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(
            x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
        )
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, act_enum):
        """
        act_lu_fp8 batcher (batch rules for vmap)
        """
        check_valid_batch_dims(batch_dims)
        assert ActLuFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return (
            ActLuFp8Primitive.outer_primitive.bind(
                x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos):
        """
        act_lu_fp8 infer_sharding_from_operands
        """
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, act_enum, mesh, arg_infos, result_infos):
        """
        act_lu_fp8 partition
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = ActLuFp8Primitive.impl(
                x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)

            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(ActLuFp8Primitive)


def act_lu_fp8(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    act_lu_fp8 wrapper
    Return FP8(act_lu(x)) and the updated amax
    Input shape: (N, 1, H) for non-gated activations
                 (N, 2, H) for gated activations
    """
    if not ActLuFp8Primitive.enabled():
        act_lu_output = _jax_act_lu(x, activation_type)
        casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
        return casted_output, updated_amax

    act_type_id = ActivationEnum[activation_type].value
    return ActLuFp8Primitive.outer_primitive.bind(
        x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id
    )
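
# Usage sketch (assuming (1,)-shaped fp32 amax/scale buffers, as in TE's
# delayed-scaling recipe; only E4M3 output is supported by the C side):
#   amax = jnp.zeros((1,), jnp.float32)
#   scale = jnp.ones((1,), jnp.float32)
#   scale_inv = jnp.ones((1,), jnp.float32)
#   x = jax.random.normal(jax.random.PRNGKey(0), (16, 1, 1024), jnp.bfloat16)
#   y, new_amax = act_lu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn, ("gelu",))
#   # y.shape == (16, 1024); y.dtype == jnp.float8_e4m3fn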