# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl

from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]

ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}
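
# Example (sketch): a SwiGLU MLP passes activation_type=("silu", "linear"), which maps
# to NVTE_Activation_Type.SWIGLU above; act_len = len(activation_type) == 2, and the
# activation input is expected to stack the two branches along its -2 axis.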


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported activation function: {fn_or_string}")
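
# Example (sketch): strings resolve to JAX callables, e.g.
# _convert_to_activation_function("gelu") returns jax.nn.gelu, and
# _convert_to_activation_function("linear") returns an identity lambda.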


class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, scale_shapes, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum, scale_shapes

        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused activation and quantization. Please"
            " do activation in higher-precision then quantize with current tensor scaling."
        )

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if not is_2x:
            out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(ActLuPrimitive.name)(
            ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        te_act_lu_p impl
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        Batching rule for vmap
        """
        del act_len, is_outer
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            scale_shapes,
            act_len,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale):
            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
                ActLuPrimitive.impl(
                    x,
                    scale,
                    out_dtype=out_dtype,
                    act_enum=act_enum,
                    act_len=act_len,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    scale_shapes=scale_shapes,
                    is_outer=True,
                )
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types

        x_rank = len(value_types[0].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank - 1, unique_var="i", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
        out = (*x_axes[:-2], x_axes[-1])
        scale_inv = scale_rules.rowwise_rule
        colwise_scale_inv = scale_rules.colwise_rule

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(
                    multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
                )
            else:
                colwise_out = out
        else:
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)

        # amax is always a unit tensor.
        amax = ("l",)

        return SdyShardingRule(
            (
                x_axes,
                "…1",
            ),
            (out, colwise_out, scale_inv, colwise_scale_inv, amax),
            **scale_rules.factor_sizes,
        )


register_primitive(ActLuPrimitive)
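
# Note: ActLuPrimitive returns (rowwise_out, colwise_out, rowwise_scale_inv,
# colwise_scale_inv, updated_amax); when is_2x is False the colwise slots are
# shape-(1,) placeholders (see ActLuPrimitive.abstract above).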


# TODO(Jeremy): replace is_2x with q_layout
class DActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Quantize Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, scale_shapes, is_dbias, act_enum, act_len, is_outer
    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum, scale_shapes

        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused dact and quantization. Please do"
            " dact in higher-precision then quantize with current tensor scaling."
        )

        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if is_2x:
            if scaling_mode in (
                ScalingMode.DELAYED_TENSOR_SCALING.value,
                ScalingMode.CURRENT_TENSOR_SCALING.value,
            ):
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
        dz_aval, x_aval, scale_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)(
            ctx,
            dz,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert DActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            DActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        Batching rule for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale = batched_args
        _, x_bdim, scale_bdim = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            x_bdim,  # colwise output
            scale_bdim,  # rowwise scale_inv
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            DActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum
        del scale_dtype, scale_shapes, act_len, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Partitioned current tensor scaling is not yet supported."

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
        )

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
                DActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    scale_shapes=scale_shapes,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types

        x_rank = len(value_types[1].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank, unique_var="i", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
        out = x_axes
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
                colwise_out = tuple(x_axes)
        else:
            colwise_out = ("j",)

        dbias = x_axes[-2:] if is_dbias else ("k",)
        amax = ("…4",)

        return SdyShardingRule(
            (("…0",), tuple(x_axes), ("…2",)),
            (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(DActLuDBiasQuantizePrimitive)
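
# Note: the outer DActLuDBiasQuantizePrimitive returns (rowwise_out, colwise_out,
# rowwise_scale_inv, colwise_scale_inv, updated_amax, dbias); the inner primitive
# additionally carries a workspace buffer that outer_abstract drops.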


def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )

    x = jnp.split(inputs, act_len, axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return x
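
# Shape sketch (illustrative, not part of the API): for a gated SiLU with
# activation_type=("silu", "linear"), an input of shape (batch, seq, 2, hidden)
# yields an output of shape (batch, seq, hidden); with a quantizer the result is
# returned as a ScaledTensor instead of a plain array.
#
#     x = jnp.ones((4, 128, 2, 1024), jnp.bfloat16)
#     y = _jax_act_lu(x, ("silu", "linear"))   # y.shape == (4, 128, 1024)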


def _jax_quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
    )
    (dx,) = vjp_func(dz.astype(jnp.float32))

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)

    return dx, dbias


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
    """
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])

    if quantizer is None:
        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            scale_shapes=((), ()),
            is_outer=True,
        )
        out = out.reshape(output_shape)
        return out

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
        out = act_lu(
            x=x.astype(jnp.float32),
            activation_type=activation_type,
            quantizer=None,
        )
        out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        # output does not have act axis
        scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1),
        is_outer=True,
    )

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )
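
# Usage sketch (hedged example; the quantizer construction below is assumed and
# depends on the quantization recipe in use):
#     x = jnp.ones((4, 128, 1, 1024), jnp.bfloat16)   # act_len == 1 for ("gelu",)
#     y = act_lu(x, ("gelu",))                        # bf16 output of shape (4, 128, 1024)
#     y_q = act_lu(x, ("gelu",), quantizer=my_quantizer)  # ScaledTensor (my_quantizer is hypothetical)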


def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient of the activation with respect to the bias.
    """

    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not DActLuDBiasQuantizePrimitive.enabled():
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
        )
        return _quantize_dbias_impl(
            out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )

    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
        )
        if war_output is not None:
            return war_output

    scale = jnp.empty((), jnp.float32)

    act_type_id = ActivationEnum[activation_type]

    if quantizer is None:
        output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            scale_shapes=((), ()),  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            is_outer=True,
        )
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
        return output.astype(x.dtype), dbias

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
        out = dact_lu(
            dz=dz.astype(jnp.float32),
            x=x.astype(jnp.float32),
            activation_type=activation_type,
            quantizer=None,
        )
        out, dbias = _quantize_dbias_impl(
            out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
        )
        out, dbias = _quantize_dbias_impl(
            dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    out_shape = x.shape

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
        dz,
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        # output has act axis
        scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias
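
# Sketch of the expected shapes (illustrative, assuming a gated activation such as
# ("silu", "linear")):
#     x     : (..., 2, K)   forward input, with the act axis at -2
#     dz    : (..., K)      incoming gradient, without the act axis
#     dx    : (..., 2, K)   returned gradient (a ScaledTensor when a quantizer is given)
#     dbias : (2, K)        only meaningful when is_dbias=True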


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.

    Returns:
        The gradient of the activation with respect to the input.
    """
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
    )
    return output