# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Tuple, Sequence, Union, Callable
import operator
from functools import reduce, partial

import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    get_padded_spec,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP


__all__ = ["act_lu", "dact_lu", "act_lu_fp8"]


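# Maps tuples of activation names (as passed to act_lu below) onto the fused TE
# kernels; e.g. ActivationEnum[("silu", "linear")] selects the SwiGLU kernel.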
ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: jnp.square(jax.nn.relu(x))
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported {fn_or_string} to an activation function")


def _jax_act_lu(inputs, activation_type):
    """
    JAX native activation implementation
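
    For activation_type ("silu", "linear") and inputs of shape (..., 2, H) this
    computes silu(inputs[..., 0, :]) * inputs[..., 1, :] (SwiGLU) and returns
    shape (..., H); a single activation takes (..., 1, H) to (..., H).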
    """
    x = jnp.split(inputs, len(activation_type), axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    return x


class ActLuPrimitive(BasePrimitive):
    """
    Activation Forward Primitive
    """

    name = "te_act_lu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = (1,)

    @staticmethod
    def abstract(x_aval, *, act_enum):  # pylint: disable=unused-argument
        """
        act_lu abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        x_shape = x_aval.shape
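        # The -2 axis carries the gating operands: 2 for gated activations
        # (GEGLU, SWIGLU, ...), 1 for plain activations.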
        assert x_shape[-2] == 2 or x_shape[-2] == 1
        hidden_size = x_shape[-1]
        batch_shapes = x_shape[:-2]
        out_aval = core.raise_to_shaped(x_aval)
        out_shape = batch_shapes + (hidden_size,)
        out_aval = out_aval.update(shape=out_shape, dtype=dtype)

        return out_aval

    @staticmethod
    def lowering(ctx, x, *, act_enum):
        """
        act_lu lowering rules
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-2])
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor(
            (batch_size, hidden_size), in_dtype, in_dtype, act_enum
        )

        out = custom_caller(ActLuPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, act_enum):
        assert ActLuPrimitive.inner_primitive is not None
        out = ActLuPrimitive.inner_primitive.bind(x, act_enum=act_enum)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims, *, act_enum):
        """
        act_lu batcher
        """
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        (inputs,) = batched_args
        (inputs_bdim,) = batch_dims

        out_bdims = inputs_bdim
        return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
        """
        act_lu infer_sharding_from_operands
        """
        del result_infos, act_enum  # Unused.
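        # The gate axis (-2) is squeezed in the output, so its spec entry is
        # dropped; all other axes keep the input sharding.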
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        return out_sharding

    @staticmethod
    def partition(act_enum, mesh, arg_infos, result_infos):
        """
        act_lu partitioning
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))

        def sharded_impl(x):
            return ActLuPrimitive.impl(x, act_enum=act_enum)

        return mesh, sharded_impl, out_sharding, arg_shardings


register_primitive(ActLuPrimitive)


def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
    """
    act_lu wrapper
    Return act_lu(inputs)
    Input shape: (N, 1, H) for non-gated activations
                 (N, 2, H) for gated activations
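
    Example (illustrative; assumes bf16 inputs)::

        x = jnp.ones((4, 2, 128), jnp.bfloat16)
        y = act_lu(x, ("silu", "linear"))  # fused SwiGLU; y.shape == (4, 128)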
    """
    if not ActLuPrimitive.enabled():
        return _jax_act_lu(inputs, activation_type)

    act_type_id = ActivationEnum[activation_type]
    return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)


class DActLuPrimitive(BasePrimitive):
    """
    DActLu (activation backward) Primitive
    """

    name = "te_dact_lu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = (2,)

    @staticmethod
    def abstract(dz_aval, x_aval, *, act_enum):  # pylint: disable=unused-argument
        """
        dact_lu abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        for axis in range(len(dz_aval.shape) - 1):
            assert dz_aval.shape[axis] == x_aval.shape[axis]
        assert x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1

        i_hidden_size = dz_aval.shape[-1]
        g_hidden_size = x_aval.shape[-1]
        assert i_hidden_size == g_hidden_size
        out_aval = core.raise_to_shaped(x_aval)

        return out_aval

    @staticmethod
    def lowering(ctx, dz, x, *, act_enum):
        """
        dact_lu lowering rules
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
        # dz has one fewer dimension than x: the forward pass squeezes the gate
        # axis, so only the leading (batch) axes are compared here.
        for axis in range(len(ir_in_shape) - 1):
            assert ir_in_shape[axis] == gi_shape[axis]

        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
        i_hidden_size = ir_in_shape[-1]
        g_hidden_size = gi_shape[-1]
        assert i_hidden_size == g_hidden_size
        out_dtype = ir_in_type.element_type
        out_shape = gi_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor(
            (ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum
        )

        out = custom_caller(DActLuPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, act_enum):
        """
        dact_lu implementation
        """
        assert DActLuPrimitive.inner_primitive is not None
        dx = DActLuPrimitive.inner_primitive.bind(dz, x, act_enum=act_enum)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims, *, act_enum):
        """
        dact_lu batcher
        """
        check_valid_batch_dims(batch_dims)
        assert DActLuPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims

        out_bdims = x_bdim
        return DActLuPrimitive.outer_primitive.bind(dz, x, act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
        """
        dact_lu infer_sharding_from_operands
        """
        del result_infos, act_enum  # Unused.
        act_lu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec))
        return dx_sharding

    @staticmethod
    def partition(act_enum, mesh, arg_infos, result_infos):
        """
        dact_lu partition
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding

        def sharded_impl(dz, x):
            return DActLuPrimitive.impl(dz, x, act_enum=act_enum)

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DActLuPrimitive)


def dact_lu(
    inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]
) -> jnp.ndarray:
    """
    dact_lu wrapper
    Return the gradient of act_lu with respect to act_lu_inputs, given the
    upstream gradient `inputs`.
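
    Example (illustrative)::

        x = jnp.ones((4, 2, 128), jnp.bfloat16)   # forward-pass inputs
        dz = jnp.ones((4, 128), jnp.bfloat16)     # gradient of act_lu's output
        dx = dact_lu(dz, x, ("silu", "linear"))   # dx.shape == x.shape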
    """

    if not DActLuPrimitive.enabled():
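        # Fall back to JAX autodiff: the VJP of the native forward pass maps the
        # upstream gradient back to the shape of act_lu_inputs.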
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs)
        return vjp_func(inputs)[0]

    act_type_id = ActivationEnum[activation_type]
    return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)


class ActLuFp8Primitive(BasePrimitive):
    """
    ActLu FP8 Primitive
    """

    name = "te_act_lu_fp8"
    multiple_results = True
    impl_static_args = (4, 5)  # out_dtype, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, act_enum
    ):  # pylint: disable=unused-argument
        """
        te_act_lu_fp8 abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently the C side only supports casting to E4M3.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        assert x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2
        hidden_size = x_aval.shape[-1]
        batch_shape = x_aval.shape[:-2]
        out_shape = batch_shape + (hidden_size,)
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum):
        """
        te_act_lu_fp8 lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        hidden_size = ir_x_shape[-1]
        batch_shape = ir_x_shape[:-2]
        batch_size = reduce(operator.mul, batch_shape)
        out_shape = batch_shape + [hidden_size]
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_common_descriptor(
            (batch_size, hidden_size),
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            act_enum,
        )

        out = custom_caller(
            ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
        )

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, act_enum):
        """
        te_act_lu_fp8 implementation
        """
        assert ActLuFp8Primitive.inner_primitive is not None
        out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(
            x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
        )
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, act_enum):
        """
        te_act_lu_fp8 batcher (batch rule for vmap)
        """
        check_valid_batch_dims(batch_dims)
        assert ActLuFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return (
            ActLuFp8Primitive.outer_primitive.bind(
                x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos):
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, act_enum, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = ActLuFp8Primitive.impl(
                x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
            )
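
            # Each shard computes a local amax; reduce with max across every mesh
            # axis except pipeline-parallel so all replicas share the global amax.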
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(ActLuFp8Primitive)


def act_lu_fp8(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    act_lu_fp8 wrapper
    Return FP8(act_lu(x))
    Input shape: (N, 1, H) for non-gated activations
                 (N, 2, H) for gated activations
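
    Example (illustrative; in practice amax/scale/scale_inv come from the FP8
    recipe state)::

        amax = jnp.zeros((1,), jnp.float32)
        scale = jnp.ones((1,), jnp.float32)
        scale_inv = jnp.ones((1,), jnp.float32)
        y, updated_amax = act_lu_fp8(
            x, amax, scale, scale_inv, jnp.float8_e4m3fn, ("gelu",)
        )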
    """
    if not ActLuFp8Primitive.enabled():
        act_lu_output = _jax_act_lu(x, activation_type)
        casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
        return casted_output, updated_amax

    act_type_id = ActivationEnum[activation_type]
    return ActLuFp8Primitive.outer_primitive.bind(
        x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id
    )