# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""

import math
import operator
from collections.abc import Iterable
from typing import Tuple, Sequence, Union
from functools import partial, reduce

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.custom_partitioning import SdyShardingRule

import transformer_engine_jax as tex
from transformer_engine_jax import get_num_compute_streams

from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize
from ..quantize import (
    ScaledTensor,
    ScaledTensor2x,
    GroupedScaledTensor1x,
    ScalingMode,
    Quantizer,
    GroupedQuantizer,
    QuantizeConfig,
    QuantizerSet,
    QuantizeLayout,
    noop_quantizer_set,
    is_fp8_gemm_with_all_layouts_supported,
    apply_padding_to_scale_inv,
)
from .misc import get_padded_spec


__all__ = [
    "gemm",
    "grouped_gemm",
    "gemm_uses_jax_dot",
    "sanitize_dims",
    "get_non_contracting_dims",
    "transpose_dims",
]


num_cublas_streams = get_num_compute_streams()


def get_cublas_workspace_size_bytes() -> int:
    """Return 32 MiB if the device compute capability is 9.0 (Hopper) or newer, 4 MiB otherwise."""
    if tex.get_device_compute_capability(0) >= 90:
        return 33_554_432
    return 4_194_304


def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]:
    """Convert relative (negative) indexes to absolute dimension numbers."""
    dims_ = dims if isinstance(dims, Iterable) else (dims,)
    if len(dims_) == 0:
        return dims_
    return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None)


def get_non_contracting_dims(ndim, contracting_dims):
    """Return a tuple of dimensions not included in the contracting dimensions."""
    contracting_dims = sanitize_dims(ndim, contracting_dims)
    return tuple(dim for dim in range(ndim) if dim not in contracting_dims)


def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1):
    """Compute the new dimension numbers after transpose."""
    if len(dims_to_transpose) == 0:
        return dims_to_transpose
    flatten_axis = ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis))
    return tuple(transposed_dims.index(dim) for dim in dims_to_transpose)
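
# Illustrative examples for the dimension helpers above (hypothetical rank-3 tensor):
#   sanitize_dims(3, (-1,))                  -> (2,)
#   get_non_contracting_dims(3, (2,))        -> (0, 1)
#   transpose_dims(3, (2,), flatten_axis=2)  -> (0,)   # "T" layout stores dims as (2, 0, 1)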


def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool:
    lhs, rhs, e4m3, e5m2 = map(
        dtypes.canonicalize_dtype,
        (
            lhs_dtype,
            rhs_dtype,
            jnp.float8_e4m3fn,
            jnp.float8_e5m2,
        ),
    )

    # FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3)
    if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3):
        return True

    # Any other combination of data types is not supported
    return False


def _get_gemm_layout(
    operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]]
) -> Tuple[bool, bool]:
    lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims)
    lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting
    rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting
    return lhs_is_transposed, rhs_is_transposed
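
# Illustrative example: a forward GEMM `x @ w` with contracting_dims=((-1,), (0,)) yields
# _get_gemm_layout((x.ndim, w.ndim), ((-1,), (0,))) == (False, False) ("NN" layout), while a
# 2D wgrad-style GEMM contracting over the leading dims of both operands, ((0,), (0,)),
# yields (True, False) ("TN" layout).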


def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims):
    lhs_q = lhs
    rhs_q = rhs

    if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None:
        lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0])
        lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims
        need_lhs_colwise = lhs_is_transposed and (
            lhs_quantizer.scaling_mode.is_1d_block_scaling()
            or not is_fp8_gemm_with_all_layouts_supported()
        )
        flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims)
        lhs_q = lhs_quantizer.quantize(
            lhs,
            is_rowwise=not need_lhs_colwise,
            is_colwise=need_lhs_colwise,
            flatten_axis=flatten_axis,
        )

    if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None:
        rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1])
        rhs_is_transposed = rhs.ndim - 1 in rhs_cdims
        need_rhs_colwise = not rhs_is_transposed and (
            rhs_quantizer.scaling_mode.is_1d_block_scaling()
            or not is_fp8_gemm_with_all_layouts_supported()
        )
        flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1
        rhs_q = rhs_quantizer.quantize(
            rhs,
            is_rowwise=not need_rhs_colwise,
            is_colwise=need_rhs_colwise,
            flatten_axis=flatten_axis,
        )

    assert not isinstance(lhs_q, ScaledTensor2x)
    assert not isinstance(rhs_q, ScaledTensor2x)

    return lhs_q, rhs_q
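
# Illustrative note: when an operand is "transposed" (its contracting dims exclude the last
# axis), 1D block-scaled tensors are quantized column-wise above so that the stored data
# layout keeps the contracting axis contiguous, as expected by the cuBLAS GEMM backend.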


class GemmPrimitive(BasePrimitive):
    """
    Primitive for cuBLAS GEMM
    """

    name = "te_gemm_ffi"
    multiple_results = True
    impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        lhs,
        lhs_scale_inv,
        rhs,
        rhs_scale_inv,
        bias,
        gelu_input,
        out_dtype,
        contracting_dims,
        batched_dims,
        lhs_quantized_colwise,
        rhs_quantized_colwise,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        sequence_parallel_output,
        sequence_dim,
    ):
        del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator
        del (
            sequence_parallel_output,
            sequence_dim,
        )

        def _dims_are_consecutive(dims):
            if len(dims) <= 1:
                return True
            return sorted(dims) == list(range(min(dims), max(dims) + 1))

        # Sanity-check operand layouts and types
        operand_ndims = (lhs.ndim, rhs.ndim)

        (
            lhs_contracting_dims,
            rhs_contracting_dims,
        ) = map(sanitize_dims, operand_ndims, contracting_dims)
        assert _dims_are_consecutive(lhs_contracting_dims), (
            "cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got "
            f"{lhs_contracting_dims}."
        )
        assert _dims_are_consecutive(rhs_contracting_dims), (
            "cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got "
            f"{rhs_contracting_dims}."
        )

        (
            lhs_batch_dims,
            rhs_batch_dims,
        ) = map(sanitize_dims, operand_ndims, batched_dims)
        assert _dims_are_consecutive(lhs_batch_dims), (
            "cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got "
            f"{lhs_batch_dims}."
        )
        assert _dims_are_consecutive(rhs_batch_dims), (
            "cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got "
            f"{rhs_batch_dims}."
        )
        if len(lhs_batch_dims) == 0:
            assert (
                len(rhs_batch_dims) == 0
            ), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched."
        elif len(rhs_batch_dims) != 0:
            assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all(
                bdim in rhs_contracting_dims for bdim in rhs_batch_dims
            ), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched."

        lhs_contracting_size, rhs_contracting_size = map(
            lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]),
            (lhs.shape, rhs.shape),
            (lhs_contracting_dims, rhs_contracting_dims),
        )
        assert lhs_contracting_size == rhs_contracting_size, (
            "cuBLAS GEMM operands have incompatible contracting dimensions: "
            f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}."
        )

        lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims)
        if scaling_mode != ScalingMode.NO_SCALING:
            assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), (
                "cuBLAS GEMM quantized operands have incompatible data types: "
                f"{lhs.dtype} x {rhs.dtype}."
            )
            assert (
                lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0
            ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands."
            if (
                scaling_mode != ScalingMode.MXFP8_1D_SCALING
                and not tex.is_non_nt_fp8_gemm_supported()
            ):
                assert not lhs_is_transposed and rhs_is_transposed, (
                    "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) "
                    "require non-transposed LHS and transposed RHS operands "
                    "(`contracting_dims=((-1, ), (-1, ))`)."
                )

        # Determine output shape and dtype
        assert (
            dtypes.canonicalize_dtype(out_dtype).itemsize > 1
        ), "cuBLAS GEMM custom op does not support 8-bit quantized output types."
        lhs_non_contracting_shape, rhs_non_contracting_shape = map(
            lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims],
            (lhs.shape, rhs.shape),
            (lhs_contracting_dims, rhs_contracting_dims),
        )
        out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape)
        output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)

        # Validate bias
        bias_shape = (0,)
        bias_dtype = out_dtype
        if fuse_bias:
            expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape)
            if not grad:
                assert bias.size == expected_bias_size, (
                    "cuBLAS GEMM bias tensor has incorrect shape, "
                    f"expected ({expected_bias_size}, ) but found {bias.shape}."
                )
                assert bias.dtype == out_dtype, (
                    "cuBLAS GEMM bias tensor has incorrect data type, "
                    f"expected {bias_dtype} but found {bias.dtype}."
                )
                bias_shape = bias.shape
            else:
                bias_shape = rhs_non_contracting_shape
        bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype)

        # Validate pre-GeLU
        pre_gelu_shape = (0,)
        pre_gelu_dtype = out_dtype
        if fuse_gelu:
            pre_gelu_shape = out_shape
            if grad:
                pre_gelu_ndim = len(pre_gelu_shape)
                assert gelu_input.ndim == pre_gelu_ndim and all(
                    gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim)
                ), (
                    "cuBLAS GEMM pre-GeLU tensor has incorrect shape, "
                    f"expected {pre_gelu_shape} but found {gelu_input.shape}."
                )
                assert gelu_input.dtype == out_dtype, (
                    "cuBLAS GEMM pre-GeLU tensor has incorrect data type, "
                    f"expected {pre_gelu_dtype} but found {gelu_input.dtype}."
                )
        pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)

        # Need extra workspace for swizzled scale factors
        lhs_swizzle_size = 0
        rhs_swizzle_size = 0
        swizzle_dtype = jnp.uint8
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            lhs_swizzle_size = lhs_scale_inv.size
            rhs_swizzle_size = rhs_scale_inv.size
        lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype)
        rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype)

        # Declare cuBLAS workspace
        # The cuBLAS workspace pointer must be 256-byte aligned, but JAX buffers are not
        # necessarily 256-byte aligned, so we add padding to ensure alignment.
        workspace_size = get_cublas_workspace_size_bytes() + 256
        workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)

        return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace

    @staticmethod
    def outer_abstract(*args, **kwargs):
        outputs = GemmPrimitive.abstract(*args, **kwargs)
        return outputs[:-3]  # discard workspace arrays

    @staticmethod
    def lowering(
        ctx,
        lhs,
        lhs_scale_inv,
        rhs,
        rhs_scale_inv,
        bias,
        gelu_input,
        out_dtype,
        contracting_dims,
        batched_dims,
        lhs_quantized_colwise,
        rhs_quantized_colwise,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        sequence_parallel_output,
        sequence_dim,
    ):
        del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype
        del sequence_parallel_output, sequence_dim

        lhs_aval, _, rhs_aval, *_ = ctx.avals_in
        lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
        lhs_transposed, rhs_transposed = _get_gemm_layout(
            (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
        )

        args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input)
        kwargs = {
            "scaling_mode": int(scaling_mode.value),
            "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
            "rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
            "lhs_transposed": lhs_transposed,
            "rhs_transposed": rhs_transposed,
            "fuse_bias": fuse_bias,
            "fuse_gelu": fuse_gelu,
            "grad": grad,
            "use_split_accumulator": use_split_accumulator,
        }

        operand_output_aliases = {}
        if fuse_bias and not grad:
            operand_output_aliases.update({4: 1})  # bias <-> bias_grad
        if fuse_gelu and grad:
            operand_output_aliases.update({5: 2})  # gelu_input <-> pre_gelu_out

        return jax.ffi.ffi_lowering(
            GemmPrimitive.name,
            operand_output_aliases=operand_output_aliases,
        )(ctx, *args, **kwargs)

    @staticmethod
    def impl(
        lhs,
        lhs_scale_inv,
        rhs,
        rhs_scale_inv,
        bias,
        gelu_input,
        out_dtype,
        contracting_dims,
        batched_dims,
        lhs_quantized_colwise,
        rhs_quantized_colwise,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        sequence_parallel_output,
        sequence_dim,
    ):
        lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
        lhs_transposed, rhs_transposed = _get_gemm_layout(
            (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
        )
        lhs_scale_inv = apply_padding_to_scale_inv(
            lhs_scale_inv,
            scaling_mode,
            lhs.shape,
            is_colwise=lhs_quantized_colwise,
            flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
        )
        rhs_scale_inv = apply_padding_to_scale_inv(
            rhs_scale_inv,
            scaling_mode,
            rhs.shape,
            is_colwise=rhs_quantized_colwise,
            flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
        )

        outputs = GemmPrimitive.inner_primitive.bind(
            lhs,
            lhs_scale_inv,
            rhs,
            rhs_scale_inv,
            bias,
            gelu_input,
            out_dtype=out_dtype,
            contracting_dims=contracting_dims,
            batched_dims=batched_dims,
            lhs_quantized_colwise=lhs_quantized_colwise,
            rhs_quantized_colwise=rhs_quantized_colwise,
            scaling_mode=scaling_mode,
            fuse_bias=fuse_bias,
            fuse_gelu=fuse_gelu,
            grad=grad,
            use_split_accumulator=use_split_accumulator,
            sequence_parallel_output=sequence_parallel_output,
            sequence_dim=sequence_dim,
        )
        return outputs[:-3]  # discard workspace arrays

    @staticmethod
    def batcher(
        batched_args,
        jax_batch_dims,
        out_dtype,
        contracting_dims,
        batched_dims,
        lhs_quantized_colwise,
        rhs_quantized_colwise,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        sequence_parallel_output,
        sequence_dim,
    ):
        assert GemmPrimitive.outer_primitive is not None
        lhs, _, rhs, *_ = batched_args
        lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims
        arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
        arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims
        assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), (
            "User-specified batch dimension(s) for cuBLAS GEMM LHS operand do not match batch "
            f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
        )
        arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims
        assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), (
            "User-specified batch dimension(s) for cuBLAS GEMM RHS operand do not match batch "
            f"dimensions inferred by JAX/XLA, expected {rhs_bdims} but got {arg_rhs_bdims}."
        )

        # Output is batched like the non-contracting batch dimensions of the LHS operand
        lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0])
        lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims)
        out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims

        # Bias gradient is never batched
        bias_bdims = (None,)

        # Pre-GeLU output, if exists, is batched like GEMM output
        pre_gelu_bdims = (None,)
        if fuse_gelu and not grad:
            pre_gelu_bdims = out_bdims

        return (
            GemmPrimitive.outer_primitive.bind(
                *batched_args,
                out_dtype=out_dtype,
                contracting_dims=contracting_dims,
                batched_dims=batched_dims,
                lhs_quantized_colwise=lhs_quantized_colwise,
                rhs_quantized_colwise=rhs_quantized_colwise,
                scaling_mode=scaling_mode,
                fuse_bias=fuse_bias,
                fuse_gelu=fuse_gelu,
                grad=grad,
                use_split_accumulator=use_split_accumulator,
                sequence_parallel_output=sequence_parallel_output,
                sequence_dim=sequence_dim,
            ),
            (out_bdims, bias_bdims, pre_gelu_bdims),
        )

    @staticmethod
    def _parse_operand_output_specs(
        arg_infos,
        contracting_dims,
        batched_dims,
        sequence_parallel_output,
        sequence_dim,
    ):
        del sequence_dim, sequence_parallel_output, batched_dims
        lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)

        lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
        lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims)
        lhs_non_cdims, rhs_non_cdims = map(
            lambda ndim, cdims: tuple(i for i in range(ndim) if i not in cdims),
            (lhs_ndim, rhs_ndim),
            (lhs_cdims, rhs_cdims),
        )
        lhs_non_cspecs, lhs_cspecs, rhs_non_cspecs, rhs_cspecs = map(
            lambda specs, dims: tuple(specs[i] for i in dims),
            (lhs_specs, lhs_specs, rhs_specs, rhs_specs),
            (lhs_non_cdims, lhs_cdims, rhs_non_cdims, rhs_cdims),
        )

        reduce_spec = None
        for l in lhs_cspecs:
            for r in rhs_cspecs:
                if l is not None and l == r:
                    assert reduce_spec is None, "Multiple reduce dimensions detected!"
                    reduce_spec = l
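
        # Illustrative example: in a Megatron-style row-parallel GEMM where both contracting
        # dims carry the 'tp' mesh axis (lhs_cspecs == rhs_cspecs == ('tp',)), reduce_spec
        # becomes 'tp' and the partial GEMM output is later finalized with a psum over it.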

        if reduce_spec is not None:
            # Other non-reduce contracting dims (if any) need to be unsharded
            lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs)
            rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs)

            # Non-batched non-contracting dims of RHS need to be unsharded (i.e. FSDP).
            # No batch-dim check is needed here because rhs_non_cspecs never includes the
            # batch-dim; rhs_specs only carries a batch-dim in the Wgrad GEMM, where it
            # belongs to rhs_cspecs.
            rhs_non_cspecs = tuple(
                None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs
            )
        else:
            # Otherwise, require contracting dims of both operands to be unsharded
            lhs_cspecs = (None,) * len(lhs_cspecs)
            rhs_cspecs = (None,) * len(rhs_cspecs)

        # Non-batched non-contracting dims of LHS need to be unsharded, i.e. gather the
        # sequence-parallel dim. The spec for the batch-dim in lhs_non_cspecs never appears
        # in rhs_non_cspecs (which has no batch-dim), so the batch-dim spec of lhs_non_cspecs
        # is never overwritten.
        # Minor note: this causes MaxText TP (= Megatron TP + activation_hidden sharding) to
        # gather x for dW1 = x^T * dY1, which is unexpected. This is a known issue with no
        # solution found yet.
        lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs)

        out_specs = lhs_non_cspecs + rhs_non_cspecs

        # specs = merge(cspecs, non_cspecs)
        lhs_specs, rhs_specs = map(
            lambda cdims, cspecs, non_cspecs: (
                cspecs + non_cspecs if cdims[0] == 0 else non_cspecs + cspecs
            ),
            (lhs_cdims, rhs_cdims),
            (lhs_cspecs, rhs_cspecs),
            (lhs_non_cspecs, rhs_non_cspecs),
        )

        # Bias and Pre-GeLU sharding is based on GEMM output before any scatter
        bias_specs = tuple(rhs_non_cspecs)
        gelu_specs = tuple(out_specs)

        return (
            (lhs_specs, rhs_specs, bias_specs, gelu_specs),
            (out_specs, bias_specs, gelu_specs),
            reduce_spec,
            0,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        contracting_dims,
        batched_dims,
        lhs_quantized_colwise,
        rhs_quantized_colwise,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        sequence_parallel_output,
        sequence_dim,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            lhs_quantized_colwise,
            rhs_quantized_colwise,
            scaling_mode,
            grad,
        )
        del use_split_accumulator, result_infos

        (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
            GemmPrimitive._parse_operand_output_specs(
                arg_infos,
                contracting_dims,
                batched_dims,
                sequence_parallel_output,
                sequence_dim,
            )
        )
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))

        # Discard bias gradient spec if there is no bias fusion
        if not fuse_bias:
            dbias_specs = (None,)
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs))

        # Discard pre-GeLU output spec if there is no GeLU fusion
        if not fuse_gelu:
            pre_gelu_specs = (None,)
        pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))

        return [out_sharding, dbias_sharding, pre_gelu_sharding]

    @staticmethod
    def partition(
        out_dtype,
        contracting_dims,
        batched_dims,
        lhs_quantized_colwise,
        rhs_quantized_colwise,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        sequence_parallel_output,
        sequence_dim,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos

        (
            (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
            (out_specs, dbias_specs, pre_gelu_specs),
            reduce_spec,
            _,
        ) = GemmPrimitive._parse_operand_output_specs(
            arg_infos,
            contracting_dims,
            batched_dims,
            sequence_parallel_output,
            sequence_dim,
        )

        # Assemble argument shardings
        # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
        none_sharding = NamedSharding(mesh, PartitionSpec(None))
        lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs))
        rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs))
        arg_shardings = (
            lhs_sharding,
            lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding,
            rhs_sharding,
            rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding,
        )

        # Discard bias input spec if there is no bias fusion
        if not fuse_bias:
            bias_input_specs = (None,)
        arg_shardings += (NamedSharding(mesh, PartitionSpec(*bias_input_specs)),)

        # Discard pre-GeLU input spec if there is no GeLU fusion
        if not fuse_gelu:
            gelu_input_specs = (None,)
        arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),)

        # Assemble output shardings
        out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))]

        # Discard bias gradient spec if there is no bias fusion
        if not fuse_bias:
            dbias_specs = (None,)
        out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs)))

        # Discard pre-GeLU output spec if there is no GeLU fusion
        if not fuse_gelu:
            pre_gelu_specs = (None,)
        out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)))

        def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input):
            outputs = GemmPrimitive.impl(
                lhs,
                lhs_scale_inv,
                rhs,
                rhs_scale_inv,
                bias,
                gelu_input,
                out_dtype=out_dtype,
                contracting_dims=contracting_dims,
                batched_dims=batched_dims,
                lhs_quantized_colwise=lhs_quantized_colwise,
                rhs_quantized_colwise=rhs_quantized_colwise,
                scaling_mode=scaling_mode,
                fuse_bias=fuse_bias,
                fuse_gelu=fuse_gelu,
                grad=grad,
                use_split_accumulator=use_split_accumulator,
                sequence_parallel_output=sequence_parallel_output,
                sequence_dim=sequence_dim,
            )

            # All-Reduce/Reduce-Scatter GEMM output
            if reduce_spec is not None:
                outputs[0] = jax.lax.psum(outputs[0], reduce_spec)

            return outputs

        return mesh, _sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        contracting_dims,
        batched_dims,
        lhs_quantized_colwise,
        rhs_quantized_colwise,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        sequence_parallel_output,
        sequence_dim,
        mesh,
        operand_types,
        result_types,
    ):
        del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator
        del sequence_parallel_output, sequence_dim, mesh, result_types

        prefix = "GemmPrimitive_"

        def _generate_operand_rules(name, ndim, cdims, bdims):
            specs = []
            ldims = tuple(i for i in range(ndim) if i not in bdims + cdims)
            for i in range(ndim):
                dim_name = None
                if i in bdims:
                    dim_idx = bdims.index(i) if len(bdims) > 1 else ""
                    dim_name = f"b{dim_idx}"
                elif i in cdims:
                    dim_idx = cdims.index(i) if len(cdims) > 1 else ""
                    dim_name = f"k{dim_idx}"
                else:
                    dim_idx = ldims.index(i) if len(ldims) > 1 else ""
                    dim_name = f"{name}_l{dim_idx}"
                specs.append(prefix + dim_name)
            return specs
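
        # Illustrative example: an LHS operand with ndim=3, cdims=(2,) and bdims=() yields
        # the factor names ("GemmPrimitive_lhs_l0", "GemmPrimitive_lhs_l1", "GemmPrimitive_k").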

        lhs, _, rhs, *_ = operand_types
        operand_ndims = (len(lhs.shape), len(rhs.shape))
        (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map(
            lambda dims: map(sanitize_dims, operand_ndims, dims),
            (contracting_dims, batched_dims),
        )
        lhs_specs, rhs_specs = map(
            _generate_operand_rules,
            ("lhs", "rhs"),
            operand_ndims,
            (lhs_cdims, rhs_cdims),
            (lhs_bdims, rhs_bdims),
        )
        lhs_scale_specs = ("…1",)
        rhs_scale_specs = ("…2",)
        if scaling_mode.is_1d_block_scaling():
            # Shardy rules for MXFP8 scales cannot be related to the operands because of the
            # global-unpadding and local-padding workflow. This can potentially insert expensive
            # re-shards in the partition call later if the scales are not already sharded correctly.
            lhs_scale_specs, rhs_scale_specs = map(
                lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs),
                (lhs_specs, rhs_specs),
            )

        lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims)
        rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
        out_spec = (*lhs_non_cspec, *rhs_non_cspec)
        bias_spec = rhs_non_cspec if fuse_bias else ("…4",)
        gelu_spec = out_spec if fuse_gelu else ("…5",)

        return SdyShardingRule(
            operand_mappings=(
                lhs_specs,
                lhs_scale_specs,
                rhs_specs,
                rhs_scale_specs,
                bias_spec,
                gelu_spec,
            ),
            result_mappings=(
                out_spec,
                bias_spec,
                gelu_spec,
            ),
        )


register_primitive(GemmPrimitive)


def gemm_uses_jax_dot() -> bool:
    """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot."""
    return not GemmPrimitive.enabled()
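

# Minimal usage sketch for the public `gemm()` API defined later in this module
# (illustrative shapes/dtypes; assumes a CUDA device with the TE cuBLAS custom op available):
#
#   x = jnp.zeros((batch, seqlen, hidden), dtype=jnp.bfloat16)
#   w = jnp.zeros((hidden, ffn), dtype=jnp.bfloat16)
#   y = gemm(x, w, contracting_dims=((-1,), (0,)))  # y.shape == (batch, seqlen, ffn)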


def _te_gemm(
    lhs: Union[jax.Array, ScaledTensor],
    rhs: Union[jax.Array, ScaledTensor],
    bias: jax.Array = None,
    gelu_input: jax.Array = None,
    lhs_quantizer: Quantizer = None,
    rhs_quantizer: Quantizer = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
    batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
    fuse_bias: bool = False,
    fuse_gelu: bool = False,
    grad: bool = False,
    use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
    sequence_parallel_output: bool = False,
    sequence_dim: int = None,
) -> Tuple[jax.Array, ...]:

    # Prepare non-quantized GEMM operands
    lhs_data = lhs
    rhs_data = rhs
    lhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
    rhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
    scaling_mode = ScalingMode.NO_SCALING
    lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims)
    lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
    lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)

    # Quantize operands (if necessary)
    lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)

    # Extract GEMM custom op inputs from quantized operands
    if isinstance(lhs_q, ScaledTensor):
        assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, (
            "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid "
            "`Quantizer` object to quantize the RHS operand."
        )
        if isinstance(lhs_q, ScaledTensor2x):
            # Choose the quantization of the contracting dimension(s)
            lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor()
        scaling_mode = lhs_q.scaling_mode
        lhs_data = lhs_q.data
        lhs_scale_inv = lhs_q.scale_inv
        if lhs_q.data_layout == "T":
            lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
            lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis)

    if isinstance(rhs_q, ScaledTensor):
        assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
            "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid "
            "`Quantizer` object to quantize the LHS operand."
        )
        if isinstance(rhs_q, ScaledTensor2x):
            # Choose the quantization of the contracting dimension(s)
            rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor()
        assert rhs_q.scaling_mode == lhs_q.scaling_mode, (
            "cuBLAS GEMM quantized operands have mismatched scaling types, "
            f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}."
        )
        rhs_data = rhs_q.data
        rhs_scale_inv = rhs_q.scale_inv
        if rhs_q.data_layout == "T":
            rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
            rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis)

    # Dummy empties for bias and gelu
    out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype
    if bias is None or not (fuse_bias and not grad):
        bias = jnp.empty(0, dtype=out_dtype)
    if gelu_input is None or not (fuse_gelu and grad):
        gelu_input = jnp.empty(0, dtype=out_dtype)

    return GemmPrimitive.outer_primitive.bind(
        lhs_data,
        lhs_scale_inv,
        rhs_data,
        rhs_scale_inv,
        bias,
        gelu_input,
        out_dtype=out_dtype,
        contracting_dims=(lhs_cdims, rhs_cdims),
        batched_dims=(lhs_bdims, rhs_bdims),
        lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False,
        rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False,
        scaling_mode=scaling_mode,
        fuse_bias=fuse_bias,
        fuse_gelu=fuse_gelu,
        grad=grad,
        use_split_accumulator=use_split_accumulator,
        sequence_parallel_output=sequence_parallel_output,
        sequence_dim=sequence_dim,
    )


class GroupedGemmPrimitive(BasePrimitive):
    """
    Primitive for grouped GEMM
    """

    name = "te_grouped_gemm_ffi"
    multiple_results = True
    impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        lhs_data_aval,
        lhs_scale_inv_aval,
        rhs_data_aval,
        rhs_scale_inv_aval,
        bias_aval,
        group_sizes_aval,
        group_offset_aval,
        *,
        M,
        N,
        K,
        lhs_is_trans,
        rhs_is_trans,
        scaling_mode,
        out_dtype,
        has_bias,
        is_grouped_dense_wgrad,
    ):
        """
        Grouped GEMM operation.

        Args:
            lhs_data: Left-hand side input matrix data, 1D flattened array
            lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array
            rhs_data: Right-hand side input matrix data, 1D flattened array
            rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array
            bias: Bias matrix of shape (G, N)
            group_sizes: 1D array containing the sizes of each group
            group_offset: 1D array containing offsets for each group (not yet implemented)
            M: Number of rows in the output matrix
            N: Number of columns in the output matrix
            K: Number of columns in the left-hand side matrix
            lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed
            rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed
            scaling_mode: Scaling mode for the GEMM operations
            out_dtype: Data type of the output tensors
            has_bias: Boolean indicating if bias tensors are provided
            is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation
                                    where both lhs and rhs are 2D matrices and output is (G, M, N)

        Returns:
            A jnp.ndarray containing the result of the grouped GEMM operation
        """
        del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
        del K, lhs_is_trans, rhs_is_trans, has_bias
        # TODO(Phuong): move some shape checks from Cpp to here
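
        # Illustrative example (assumed semantics): with group_sizes=(3, 5, 2), the flattened
        # LHS packs 3 + 5 + 2 = 10 rows (M=10), and each row block is multiplied by its own
        # group's weight slice from the flattened RHS.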
        workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
        workspace_alignment_padding = 256
        tensor_scaling_sinv_alignment = 16
        mxfp8_scaling_sinv_alignment_padding = 256
        # The cuBLAS workspace pointer must be 256-byte aligned, but JAX buffers are not
        # necessarily 256-byte aligned, so we add padding to ensure alignment.
        workspace_size += workspace_alignment_padding
        if scaling_mode in (
            ScalingMode.DELAYED_TENSOR_SCALING.value,
            ScalingMode.CURRENT_TENSOR_SCALING.value,
        ):
            # For tensor scaling, each matrix has a single scale value, but it
            # needs to be aligned to 16 bytes for CUDA 12.9.1 and later.
            workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_alignment
            workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_alignment
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            # We also pad scale_inv swizzle buffers size for 256 bytes alignment.
            workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
            workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
        workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)

        out_shape = (M, N)
        if is_grouped_dense_wgrad:
            out_shape = (group_sizes_aval.size, M, N)
        out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        return (out_aval, workspace_aval)

    @staticmethod
    def outer_abstract(*args, **kwargs):
        (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
        return (out_aval,)

    @staticmethod
    def lowering(
        ctx,
        *args,
        M,
        N,
        K,
        lhs_is_trans,
        rhs_is_trans,
        scaling_mode,
        out_dtype,
        has_bias,
        is_grouped_dense_wgrad,
    ):
        del out_dtype
        return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
            ctx,
            *args,
            M=M,
            N=N,
            K=K,
            lhs_is_trans=lhs_is_trans,
            rhs_is_trans=rhs_is_trans,
            scaling_mode=scaling_mode.value,
            has_bias=has_bias,
            is_grouped_dense_wgrad=is_grouped_dense_wgrad,
        )

    @staticmethod
    def impl(
        lhs_data,
        lhs_scale_inv,
        rhs_data,
        rhs_scale_inv,
        bias,
        group_sizes,
        group_offset,
        M,
        N,
        K,
        lhs_is_trans,
        rhs_is_trans,
        scaling_mode,
        out_dtype,
        has_bias,
        is_grouped_dense_wgrad,
    ):
        assert GroupedGemmPrimitive.inner_primitive is not None
        (out, _) = GroupedGemmPrimitive.inner_primitive.bind(
            lhs_data,
            lhs_scale_inv,
            rhs_data,
            rhs_scale_inv,
            bias,
            group_sizes,
            group_offset,
            M=M,
            N=N,
            K=K,
            lhs_is_trans=lhs_is_trans,
            rhs_is_trans=rhs_is_trans,
            scaling_mode=scaling_mode,
            out_dtype=out_dtype,
            has_bias=has_bias,
            is_grouped_dense_wgrad=is_grouped_dense_wgrad,
        )
        return (out,)


register_primitive(GroupedGemmPrimitive)


def _shape_normalization(x, dimension_numbers, already_transposed: bool = False):
    orig_order = list(range(x.ndim))
    contracting_dims, batch_dims = dimension_numbers
    contracting_order = [d for d in orig_order if d in contracting_dims]
    batch_order = [d for d in orig_order if d in batch_dims]
    non_contracting_order = [
        d for d in orig_order if d not in contracting_dims and d not in batch_dims
    ]
    batch_shape = [x.shape[d] for d in batch_order]
    rows_shape = [x.shape[d] for d in non_contracting_order]
    cols_shape = [x.shape[d] for d in contracting_order]
    new_order = batch_order + non_contracting_order + contracting_order
    rows, cols, batches = (
        reduce(operator.mul, rows_shape, 1),
        reduce(operator.mul, cols_shape, 1),
        reduce(operator.mul, batch_shape, 1),
    )
    # Remove this transpose when non-TN dot is supported
    if not already_transposed:
        t = jnp.transpose(x, new_order)
    else:
        t = x
    return jnp.reshape(t, (batches, rows, cols))
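
# Illustrative example: x.shape == (2, 3, 4) with contracting_dims=(2,) and batch_dims=(0,)
# keeps the order (batch, rows, cols) and returns shape (2, 3, 4); with contracting_dims=(1,)
# it transposes to order (0, 2, 1) and returns shape (2, 4, 3).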


def _calculate_remaining_shape(shape, contracting_dims):
    contracting_dims_ = sanitize_dims(len(shape), contracting_dims)
    return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims_)


# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(2, 3))
def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
    if lhs.data_layout == "T":
        lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis)
        lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis)
    if rhs.data_layout == "T":
        rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis)
        rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis)

    dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)

    out_fp8 = jax.lax.dot_general(
        lhs.data, rhs.data, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
    )
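    # With per-tensor scaling, dequant(x) = x_fp8 * x_scale_inv, hence
    # dequant(lhs) @ dequant(rhs) == (lhs_fp8 @ rhs_fp8) * (lhs.scale_inv * rhs.scale_inv).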
    scale_inv = lhs.scale_inv * rhs.scale_inv
    out = (out_fp8 * scale_inv).astype(lhs.dq_dtype)

    return out


@partial(jax.jit, static_argnums=(2,))
def _jax_gemm_mxfp8_1d(
    lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
    """
    JAX GEMM for MXFP8 via scaled_matmul
    """
    assert (
        rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
    ), "rhs does not have MXFP8 1D scaling mode"

    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums

    expected_lhs_is_colwise = lhs_contract[-1] != lhs.data.ndim - 1
    expected_rhs_is_colwise = rhs_contract[-1] != rhs.data.ndim - 1
    assert lhs.is_colwise is expected_lhs_is_colwise, (
        f"LHS with unexpected quantize dimension.\nExpect is_colwise={expected_lhs_is_colwise}, got"
        f" {lhs.is_colwise}"
    )
    assert rhs.is_colwise is expected_rhs_is_colwise, (
        f"RHS with unexpected quantize dimension.\nExpect is_colwise={expected_rhs_is_colwise}, got"
        f" {rhs.is_colwise}"
    )

    # Reshape + Transpose (if needed)
    # [..., M, K] -> [1, reduce(..., M), K]
    # [..., K, M] -> [1, reduce(..., M), K]
    lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch))
    rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch))
    lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch))
    rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch))

    # JAX scaled_matmul only supports NT now (TN-gemm)
    # * Expected shape:
    # * lhs_data  (B, M, K)           * rhs_data  (B, N, K)
    # * lhs_scale (B, M, K_block)     * rhs_scale (B, N, K_block)
    out_3d = jax.nn.scaled_matmul(
        lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
    )
    # Reshape [1, reduce(..., M), N] -> [..., M, N]
    lhs_remain_shape = tuple(
        lhs.data.shape[dim] for dim in range(len(lhs.data.shape)) if dim not in lhs_contract
    )
    rhs_remain_shape = tuple(
        rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract
    )
    out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
    return out


def _jax_gemm(
    lhs: Union[jnp.ndarray, ScaledTensor],
    rhs: Union[jnp.ndarray, ScaledTensor],
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    lhs_quantizer: Quantizer = None,
    rhs_quantizer: Quantizer = None,
) -> jnp.ndarray:
    """
    GEMM via native JAX operations, with optional FP8 quantization
    """
    dim_nums = (contracting_dims, ((), ()))

    def _jax_gemm_fp8_impl(lhs, rhs):
        if lhs.scaling_mode.is_tensor_scaling():
            assert (
                rhs.scaling_mode == lhs.scaling_mode
            ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
            precision = (
                jax.lax.Precision.HIGHEST
                if QuantizeConfig.FP8_2X_ACC_FPROP
                else jax.lax.Precision.DEFAULT
            )
            return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)

        if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)

        raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}")

    lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)

    if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor):
        return _jax_gemm_fp8_impl(lhs_q, rhs_q)

    if (
        isinstance(lhs, jnp.ndarray)
        and isinstance(rhs, jnp.ndarray)
        and lhs_quantizer is None
        and rhs_quantizer is None
    ):
        return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype)

    raise NotImplementedError("Mixed ScaledTensor and jnp.ndarray multiplication is not supported")


def gemm(
    lhs: Union[jnp.ndarray, ScaledTensor],
    rhs: Union[jnp.ndarray, ScaledTensor],
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
    batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
    lhs_quantizer: Quantizer = None,
    rhs_quantizer: Quantizer = None,
    **kwargs,
) -> Tuple[jnp.ndarray, ...]:
    r"""General matrix multiplication with optional quantization.

    Parameters
    ----------
    lhs: Union[jax.Array, ScaledTensor]
        Left-hand side operand in the matrix multiplication.
    rhs: Union[jax.Array, ScaledTensor]
        Right-hand side operand in the matrix multiplication.
    lhs_quantizer: Quantizer, default = None
        Object for down-casting the LHS operand for quantized GEMM.
    rhs_quantizer: Quantizer, default = None
        Object for down-casting the RHS operand for quantized GEMM.
    contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, ))
        Tuple of sequences representing the contracting dimensions of the operands.
    batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()),
        Tuple of sequences representing the batched dimensions of the operands. This is *not* used
        to perform a batched matrix multiplication, but it is required for TE's custom cuBLAS GEMM
        call to avoid a potentially undesirable reduction in any batched contracting dimensions
        when invoked with sharded operands (e.g. when computing weight gradients in a Flax module).
    bias: jax.Array, default = None
        Optional additive bias term, required for forward GEMM with bias fusion. Only supported
        with TE's custom call to cuBLAS GEMM.
    gelu_input: jax.Array, default = None
        Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only
        supported with TE's custom call to cuBLAS GEMM.
    fuse_bias: bool, default = False
        Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with
        TE's custom call to cuBLAS GEMM.
    fuse_gelu: bool, default = False
        Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. Only supported
        with TE's custom call to cuBLAS GEMM.
    grad: bool, default = False
        Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with
        TE's custom call to cuBLAS GEMM.
    use_split_accumulator: bool, default = True
        Enable promoting some intermediate sums to higher precision when accumulating the result in
        the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only
        supported with TE's custom call to cuBLAS GEMM.
    sequence_parallel_output: bool, default = False
        Produces an output with the first non-batched non-contracting dimension sharded with the
        same spec as operand contracting dimensions. This effectively converts the `jax.lax.psum`
        for the GEMM output into a `jax.lax.psum_scatter`. Only supported with TE's custom call to
        cuBLAS GEMM.
    sequence_dim: int, default = None
        Index of the sequence dimension for the LHS operand. This controls which dimension of the
        GEMM output is scattered when `sequence_parallel_output=True`. When `None`, the first
        non-batched non-contracting dimension is assumed to be the sequence dimension.

    Returns
    -------
    jax.Array:
        Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the
        GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution
        when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and
        `grad=False`.
    Optional[jax.Array]:
        Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call
        to cuBLAS GEMM.
    Optional[jax.Array]:
        Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input
        to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to
        compute the GeLU contribution to the gradient. Only supported with TE's custom call to
        cuBLAS GEMM.
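
    Examples
    --------
    A minimal sketch of an unquantized, fusion-free call (shapes are illustrative)::

        x = jnp.zeros((32, 128, 512), jnp.bfloat16)     # (batch, seq, hidden)
        w = jnp.zeros((512, 1024), jnp.bfloat16)        # (hidden, ffn)
        y = gemm(x, w, contracting_dims=((-1,), (0,)))  # y.shape == (32, 128, 1024)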
    """
    # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility
    if lhs_quantizer is None or rhs_quantizer is None:
        quantizer_set = kwargs.get("quantizer_set", None)
        if quantizer_set is not None:
            lhs_quantizer = quantizer_set.x
            rhs_quantizer = quantizer_set.kernel

    # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled
    fuse_bias = kwargs.get("fuse_bias", False)
    fuse_gelu = kwargs.get("fuse_gelu", False)
    if not GemmPrimitive.enabled():
        assert kwargs.get("bias", None) is None and not fuse_gelu, (
            "TE GEMM was invoked with bias fusion options that are not supported by the "
            "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
            "GEMM primitive is disabled."
        )
        assert kwargs.get("gelu_input", None) is None and not fuse_bias, (
            "TE GEMM was invoked with GeLU fusion options that are not supported by the "
            "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
            "GEMM primitive is disabled."
        )
        assert (
            not kwargs.get("sequence_parallel_output", False)
            and kwargs.get("sequence_dim", None) is None
        ), (
            "TE GEMM was invoked with sequence-parallelism options that are not supported by the "
            "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backedns used when the custom cuBLAS "
            "GEMM primitive is disabled."
        )
        return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)

    outputs = _te_gemm(
        lhs,
        rhs,
        lhs_quantizer=lhs_quantizer,
        rhs_quantizer=rhs_quantizer,
        contracting_dims=contracting_dims,
        batched_dims=batched_dims,
        **kwargs,
    )

    # Discard empty outputs
    grad = kwargs.get("grad", False)
    clean_outputs = outputs[0]  # first output is the final result and is never empty
    if (fuse_bias and grad) or (fuse_gelu and not grad):
        clean_outputs = (outputs[0],)
        if fuse_bias and grad:  # only return bias gradient if it exists
            clean_outputs += (outputs[1],)
        if fuse_gelu and not grad:  # only return pre-GeLU output if it exists
            clean_outputs += (outputs[2],)
    return clean_outputs


def grouped_gemm(
    lhs: Union[jnp.ndarray, GroupedScaledTensor1x],
    rhs: Union[jnp.ndarray, GroupedScaledTensor1x],
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)),
    bias: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.ndarray = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
    """
    Grouped GEMM operation.

    Args:
        lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
        rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
        group_sizes: 1D array containing the sizes of each group
        contracting_dims: Tuple of two sequences representing the contracting dimensions
        bias: Bias tensor of shape (G, N)
        precision: JAX precision for the GEMM operation
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array containing offsets for each group (not yet implemented)
        quantizer_set: Set of quantizers for FP8 quantization of the input and output

    Returns:
        A jnp.ndarray containing the result of the grouped GEMM operation

    Note:
        Tested shapes:
        lhs: [M, K] or [K, M]
        rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K]
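
    Example:
        A minimal sketch of an unquantized grouped call (shapes are illustrative)::

            G, M, K, N = 4, 256, 128, 64
            group_sizes = jnp.array([64, 64, 64, 64], jnp.int32)  # sums to M
            lhs = jnp.zeros((M, K), jnp.bfloat16)
            rhs = jnp.zeros((G, N, K), jnp.bfloat16)
            out = grouped_gemm(lhs, rhs, group_sizes)  # default contracting_dims=((1,), (2,))
            # out.shape == (M, N)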
    """
    # TODO(Phuong): implement the group_offset
    if group_offset is None:
        group_offset = jnp.zeros((1,), jnp.int32)

    # TODO(Phuong): implement the precision
    del precision

    if isinstance(lhs, jnp.ndarray):
        assert isinstance(rhs, jnp.ndarray)
        out_dtype = lhs.dtype
        lhs_shape = lhs.shape
        rhs_shape = rhs.shape
        lhs_data = lhs
        rhs_data = rhs
        lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32)
        scaling_mode = ScalingMode.NO_SCALING
    elif isinstance(lhs, GroupedScaledTensor1x):
        assert isinstance(rhs, GroupedScaledTensor1x)
        out_dtype = lhs.dq_dtype
        lhs_shape = lhs.original_shape
        rhs_shape = rhs.original_shape
        lhs_data = lhs.data
        rhs_data = rhs.data
        lhs_scale_inv = lhs.scale_inv
        rhs_scale_inv = rhs.scale_inv
        assert lhs.scaling_mode == rhs.scaling_mode
        scaling_mode = lhs.scaling_mode
    else:
        raise TypeError(f"Unsupported lhs type: {type(lhs)}")

    out_dtype = preferred_element_type or out_dtype

    lhs_contract_dim, rhs_contract_dim = contracting_dims

    lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1
    lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1)
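    # e.g. lhs [M, K] with contracting dims (1,): the last contracting dim is the trailing
    # axis, so lhs is in the non-transposed "N" layout and flattens from the right.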

    # rhs_shape [G, K, N]
    rhs_is_trans = rhs_contract_dim[0] != 1
    rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim)

    is_grouped_dense_wgrad = False
    if len(rhs_shape) == 2:
        rhs_is_trans = rhs_contract_dim[0] != 0
        is_grouped_dense_wgrad = True

    # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this?
    if (
        is_grouped_dense_wgrad
        and not isinstance(lhs, ScaledTensor)
        and not isinstance(rhs, ScaledTensor)
    ):
        lhs_is_trans = True
        rhs_is_trans = False
        lhs_flatten_axis = 1
        rhs_flatten_axis = 1

    if (
        not isinstance(lhs, ScaledTensor)
        and not isinstance(rhs, ScaledTensor)
        and quantizer_set != noop_quantizer_set
    ):
        assert isinstance(quantizer_set.x, GroupedQuantizer)
        assert type(quantizer_set.x) is type(quantizer_set.kernel)
        scaling_mode = quantizer_set.x.scaling_mode
        if (
            quantizer_set.x.scaling_mode.is_tensor_scaling()
            and is_fp8_gemm_with_all_layouts_supported()
        ):
            lhs_is_rowwise = rhs_is_rowwise = True
        else:
            lhs_is_rowwise = not lhs_is_trans
            rhs_is_rowwise = rhs_is_trans
        quantizer_set.x.q_layout = (
            QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
        )
        quantizer_set.kernel.q_layout = (
            QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE
        )
        lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis)
        rhs_q = grouped_quantize(
            rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis
        )
        lhs_data = lhs_q.data
        rhs_data = rhs_q.data
        lhs_scale_inv = lhs_q.scale_inv
        rhs_scale_inv = rhs_q.scale_inv
        lhs_shape = lhs_q.original_shape
        rhs_shape = rhs_q.original_shape

    assert not (
        lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
    ), "FP8 GEMM does not support E5M2 * E5M2"

    # FP8 GEMM supports only the NT layout on Hopper and earlier GPUs,
    # so an additional transpose may be required.
    if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported():
        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
            lhs_layout_is_T = lhs.data_layout == "T"
            rhs_layout_is_T = rhs.data_layout == "T"
        else:
            lhs_layout_is_T = lhs_q.data_layout == "T"
            rhs_layout_is_T = rhs_q.data_layout == "T"
        # _shape_normalization cannot be applied to grouped inputs, so lhs must already
        # be in the "N" layout and rhs in the "T" layout.
        assert (
            lhs_is_trans == lhs_layout_is_T
        ), "lhs input must be transposed before calling grouped_gemm"
        assert (
            rhs_is_trans != rhs_layout_is_T
        ), "rhs input must be transposed before calling grouped_gemm"
        lhs_is_trans = False
        rhs_is_trans = True
        lhs_ndim = len(lhs_shape)
        rhs_ndim = len(rhs_shape)
        if lhs_layout_is_T:
            lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
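            # e.g. a 2D lhs with contracting dim (1,) maps to (0,) after the "T"-layout swap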
        if rhs_layout_is_T:
            # For rhs [G, K, N], need to exclude the G dim from contract_dim
            if group_sizes.size == rhs_shape[0]:
                rhs_contract_dim = tuple(
                    (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim
                )
            else:
                rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)
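            # e.g. grouped rhs [G, N, K] (ndim 3) with contracting dim (2,):
            # (3 - 1 - 2) % 2 + 1 == 1, i.e. the contraction lands on the K axis
            # of the transposed [G, K, N] layout while G stays excluded.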

    # Calling GroupedGEMM Custom Call
    K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim)
    K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim)
    assert K_lhs == K_rhs
    M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim))
    N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:])  # Exclude G
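    # e.g. lhs [M, K] and rhs [G, N, K] with contracting dims ((1,), (2,)) give
    # K_lhs == K_rhs == K, M as the total (grouped) row count, and N excluding G.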

    if is_grouped_dense_wgrad:
        N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim))
    else:
        assert group_sizes.size == rhs_shape[0]

    assert group_offset.size == 1

    has_bias = bias is not None
    assert not has_bias or bias.shape == (group_sizes.size, N)
    bias = jnp.empty((), jnp.float32) if bias is None else bias

    (out,) = GroupedGemmPrimitive.outer_primitive.bind(
        lhs_data,
        lhs_scale_inv,
        rhs_data,
        rhs_scale_inv,
        bias,
        group_sizes,
        group_offset,
        M=M,
        N=N,
        K=K_lhs,
        lhs_is_trans=lhs_is_trans,
        rhs_is_trans=rhs_is_trans,
        scaling_mode=scaling_mode.value,
        out_dtype=out_dtype,
        has_bias=has_bias,
        is_grouped_dense_wgrad=is_grouped_dense_wgrad,
    )
    return out