# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Tuple, Sequence, Union, Callable
import operator
from functools import reduce, partial

import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    get_padded_spec,
    is_ffi_enabled,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP


__all__ = ["act_lu", "dact_lu", "act_lu_fp8"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}
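
# Note: keys are tuples of activation names, matching the `activation_type`
# argument taken by the wrappers below. For example (illustrative):
#     ActivationEnum[("gelu",)]            # -> NVTE_Activation_Type.GELU
#     ActivationEnum[("silu", "linear")]   # -> NVTE_Activation_Type.SWIGLU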


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: jnp.square(jax.nn.relu(x))
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported activation function: {fn_or_string}")
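
# Illustrative resolution examples: the special cases are handled above, and any
# other string falls through to the identically named `jax.nn` function.
#     _convert_to_activation_function("gelu")    # -> jax.nn.gelu
#     _convert_to_activation_function("linear")  # -> identity lambda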


def _jax_act_lu(inputs, activation_type):
    """
    JAX native activation implementation
    """
    x = jnp.split(inputs, len(activation_type), axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    return x
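
# Sketch of the expected input layout (x_act and x_gate are hypothetical
# (..., H) tensors):
#     inputs = jnp.stack([x_act, x_gate], axis=-2)    # gated: (..., 2, H)
#     out = _jax_act_lu(inputs, ("silu", "linear"))   # silu(x_act) * x_gate -> (..., H)
# Non-gated activations use a single name, e.g. ("gelu",) with inputs of shape (..., 1, H).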


class ActLuPrimitive(BasePrimitive):
    """
    Activation Forward Primitive
    """

    name = "te_act_lu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = (1,)

    @staticmethod
    def abstract(x_aval, *, act_enum):  # pylint: disable=unused-argument
        """
        act_lu abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        x_shape = x_aval.shape
        assert x_shape[-2] == 2 or x_shape[-2] == 1
        hidden_size = x_shape[-1]
        batch_shapes = x_shape[:-2]
        out_aval = core.raise_to_shaped(x_aval)
        out_shape = batch_shapes + (hidden_size,)
        out_aval = out_aval.update(shape=out_shape, dtype=dtype)

        return out_aval

    @staticmethod
    def lowering(ctx, x, *, act_enum):
        """
        act_lu lowering rules
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        if is_ffi_enabled():
            name = "te_act_lu_ffi"
            out = ffi.ffi_lowering(name)(ctx, x, act_enum=act_enum)
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]

            out_types = [
                ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
            ]
            operands = [x]
            operand_shapes = [ir_x_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            hidden_size = ir_x_shape[-1]
            batch_size = reduce(operator.mul, ir_x_shape[:-2])
            in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
            opaque = transformer_engine_jax.pack_common_descriptor(
                (batch_size, hidden_size), in_dtype, in_dtype, act_enum
            )

            out = custom_caller(ActLuPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, act_enum):
        assert ActLuPrimitive.inner_primitive is not None
        out = ActLuPrimitive.inner_primitive.bind(x, act_enum=act_enum)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims, *, act_enum):
        """
        act_lu batcher
        """
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        (inputs,) = batched_args
        (inputs_bdim,) = batch_dims

        out_bdims = inputs_bdim
        return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
        """
        act_lu infer_sharding_from_operands
        """
        del result_infos, act_enum  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        return out_sharding

    @staticmethod
    def partition(act_enum, mesh, arg_infos, result_infos):
        """
        act_lu partitioning
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))

        def sharded_impl(x):
            return ActLuPrimitive.impl(x, act_enum=act_enum)

        return mesh, sharded_impl, out_sharding, arg_shardings


register_primitive(ActLuPrimitive)


def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
    """
    act_lu wrapper
    Return act_lu(inputs)
    Input shape: (N, 1, H) for non-gated activations
                 (N, 2, H) for gated activations
    """
    if not ActLuPrimitive.enabled():
        return _jax_act_lu(inputs, activation_type)

    act_type_id = ActivationEnum[activation_type].value
    return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
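
# Usage sketch (`key`, `batch`, and `hidden` are hypothetical):
#     x = jax.random.normal(key, (batch, 2, hidden))
#     y = act_lu(x, ("gelu", "linear"))   # GEGLU; y.shape == (batch, hidden)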


class DActLuPrimitive(BasePrimitive):
    """
    ActLu Backward (DActLu) Primitive
    """

    name = "te_dact_lu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = (2,)

    @staticmethod
    def abstract(dz_aval, x_aval, *, act_enum):  # pylint: disable=unused-argument
        """
        dact_lu abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        for axis in range(len(dz_aval.shape) - 1):
            assert dz_aval.shape[axis] == x_aval.shape[axis]
        assert x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1

        i_hidden_size = dz_aval.shape[-1]
        g_hidden_size = x_aval.shape[-1]
        assert i_hidden_size == g_hidden_size
        out_aval = core.raise_to_shaped(x_aval)

        return out_aval

    @staticmethod
    def lowering(ctx, dz, x, *, act_enum):
        """
        dact_lu lowering rules
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        if is_ffi_enabled():
            name = "te_dact_lu_ffi"
            out = ffi.ffi_lowering(name)(ctx, dz, x, act_enum=act_enum)
        else:
            ir_in_type = ir.RankedTensorType(dz.type)
            ir_in_shape = ir_in_type.shape
            gi_type = ir.RankedTensorType(x.type)
            gi_shape = gi_type.shape
            # dz and x must share all leading (batch) dimensions; the hidden
            # dimension is compared separately below.
            for axis in range(len(ir_in_shape) - 1):
                assert ir_in_shape[axis] == gi_shape[axis]

            ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
            i_hidden_size = ir_in_shape[-1]
            g_hidden_size = gi_shape[-1]
            assert i_hidden_size == g_hidden_size
            out_dtype = ir_in_type.element_type
            out_shape = gi_shape

            out_types = [
                ir.RankedTensorType.get(out_shape, out_dtype),
            ]
            operands = [dz, x]
            operand_shapes = [ir_in_shape, gi_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
            opaque = transformer_engine_jax.pack_common_descriptor(
                (ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum
            )

            out = custom_caller(DActLuPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, act_enum):
        """
        dact_lu implementation
        """
        assert DActLuPrimitive.inner_primitive is not None
        dx = DActLuPrimitive.inner_primitive.bind(dz, x, act_enum=act_enum)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims, *, act_enum):
        """
        dact_lu batcher
        """
        check_valid_batch_dims(batch_dims)
        assert DActLuPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims

        out_bdims = x_bdim
        return DActLuPrimitive.outer_primitive.bind(dz, x, act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
        """
        dact_lu infer_sharding_from_operands
        """
        del result_infos, act_enum  # Unused.
        act_lu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec))
        return dx_sharding

    @staticmethod
    def partition(act_enum, mesh, arg_infos, result_infos):
        """
        dact_lu partition
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding

        def sharded_impl(dz, x):
            return DActLuPrimitive.impl(dz, x, act_enum=act_enum)

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DActLuPrimitive)


def dact_lu(
    inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]
) -> jnp.ndarray:
    """
    dact_lu wrapper
    Return the gradient of act_lu with respect to its input, given the
    upstream gradient (`inputs`) and the forward input (`act_lu_inputs`).
    """
    if not DActLuPrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs)
        return vjp_func(inputs)[0]

    act_type_id = ActivationEnum[activation_type].value
    return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)
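
# Usage sketch (hypothetical shapes): given the forward input x of shape
# (batch, 2, hidden) and the upstream gradient dz of shape (batch, hidden),
#     dx = dact_lu(dz, x, ("silu", "linear"))
# returns dx with the same shape as x.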


class ActLuFp8Primitive(BasePrimitive):
    """
    ActLu FP8 Primitive
    """

    name = "te_act_lu_fp8"
    multiple_results = True
    impl_static_args = (4, 5)  # out_dtype, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, act_enum
    ):  # pylint: disable=unused-argument
        """
        te_act_lu_fp8 abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently the C side only supports casting to E4M3.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        assert x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2
        hidden_size = x_aval.shape[-1]
        batch_shape = x_aval.shape[:-2]
        out_shape = batch_shape + (hidden_size,)
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum):
        """
        te_act_lu_fp8 lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        hidden_size = ir_x_shape[-1]
        batch_shape = ir_x_shape[:-2]
        batch_size = reduce(operator.mul, batch_shape)
        out_shape = batch_shape + [hidden_size]
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_common_descriptor(
            (batch_size, hidden_size),
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            act_enum,
        )

        out = custom_caller(
            ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
        )

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, act_enum):
        """
        act_lu_fp8 implementation
        """
        assert ActLuFp8Primitive.inner_primitive is not None
        out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(
            x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
        )
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, act_enum):
        """
        act_lu_fp8 batch rule for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert ActLuFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return (
            ActLuFp8Primitive.outer_primitive.bind(
                x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos):
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, act_enum, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = ActLuFp8Primitive.impl(
                x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(ActLuFp8Primitive)


def act_lu_fp8(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    act_lu_fp8 wrapper
    Return FP8(act_lu(x))
    Input shape: (N, 1, H) for non-gated activations
                 (N, 2, H) for gated activations
    """
    if not ActLuFp8Primitive.enabled():
        act_lu_output = _jax_act_lu(x, activation_type)
        casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
        return casted_output, updated_amax

    act_type_id = ActivationEnum[activation_type].value
    return ActLuFp8Primitive.outer_primitive.bind(
        x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id
    )
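
# Usage sketch (hypothetical values; the abstract rule above currently requires
# out_dtype == jnp.float8_e4m3fn):
#     x = jax.random.normal(key, (batch, 1, hidden))   # non-gated: (N, 1, H)
#     amax = jnp.zeros((1,), jnp.float32)
#     scale = jnp.ones((1,), jnp.float32)
#     scale_inv = jnp.ones((1,), jnp.float32)
#     y_fp8, new_amax = act_lu_fp8(x, amax, scale, scale_inv,
#                                  jnp.float8_e4m3fn, ("gelu",))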