gemm.py 77.2 KB
Newer Older
1
2
3
4
5
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""

6
import math
Alp Dener's avatar
Alp Dener committed
7
8
import operator
from collections.abc import Iterable
Phuong Nguyen's avatar
Phuong Nguyen committed
9
from dataclasses import dataclass
Alp Dener's avatar
Alp Dener committed
10
from functools import partial, reduce
Phuong Nguyen's avatar
Phuong Nguyen committed
11
12
from typing import Tuple, Sequence, Union
from enum import Enum
13
import warnings
Alp Dener's avatar
Alp Dener committed
14

15
16
import jax
import jax.numpy as jnp
Alp Dener's avatar
Alp Dener committed
17
18
19
20
from jax import dtypes
from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.custom_partitioning import SdyShardingRule

Phuong Nguyen's avatar
Phuong Nguyen committed
21
22
23
24
25
26
27
from transformer_engine_jax import (
    get_num_compute_streams,
    JAXX_Collective_Op,
    get_device_compute_capability,
    initialize_cgemm_communicator,
    get_cgemm_num_max_streams,
)
28
29

from .base import BasePrimitive, register_primitive
30
from .quantization import grouped_quantize
31
from ..quantize import (
32
33
    AbstractBaseTensor,
    NoScaleTensor,
34
    ScaledTensor,
35
    ScaledTensor1x,
Alp Dener's avatar
Alp Dener committed
36
    ScaledTensor2x,
37
    GroupedScaledTensor1x,
38
39
    ScalingMode,
    Quantizer,
40
    GroupedQuantizer,
41
    get_quantize_config,
42
43
    QuantizerSet,
    QuantizeLayout,
44
    noop_quantizer_set,
45
    is_fp8_gemm_with_all_layouts_supported,
Alp Dener's avatar
Alp Dener committed
46
    apply_padding_to_scale_inv,
47
)
Phuong Nguyen's avatar
Phuong Nguyen committed
48
49
50
51
52
53
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
    global_mesh_resource,
    tpsp_axis_size,
    dp_or_fsdp_axis_size,
)
54
55


Alp Dener's avatar
Alp Dener committed
56
__all__ = [
    # Collective GEMM configuration
    "CollectiveOp",
    "CollectiveOpSet",
    "collective_gemm_bootstrap",
    "noop_collective_op_set",
    # GEMM entry points
    "gemm",
    "grouped_gemm_copy_group_sizes",
    "grouped_gemm",
    "gemm_uses_jax_dot",
    # Dimension helpers
    "sanitize_dims",
    "get_non_contracting_dims",
    "transpose_dims",
]
69
70


71
# Number of compute streams reported by the transformer_engine_jax extension, queried once at
# module import time. NOTE(review): not referenced in the visible part of this file; presumably
# consumed by multi-stream (grouped) GEMM code elsewhere -- confirm.
num_cublas_streams = get_num_compute_streams()
72
73
74
75


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size in bytes.

    The size is 32 MiB when ``get_device_compute_capability(0)`` reports 90 or higher (Hopper and
    newer), and 4 MiB otherwise. Only device 0 is queried.
    """
    if get_device_compute_capability(0) >= 90:
        return 33_554_432  # 32 MiB
    return 4_194_304  # 4 MiB


Alp Dener's avatar
Alp Dener committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]:
    """Convert relative (negative) dimension indexes to absolute dimension numbers.

    Args:
        ndim: Rank of the tensor the dimension indexes refer to.
        dims: A single dimension index or an iterable of them. Negative values are counted
            from the end (``ndim + dim``); ``None`` entries are dropped.

    Returns:
        A tuple of non-negative dimension indexes in input order. An empty input always yields
        an empty tuple, regardless of the container type that was passed in.
    """
    # Materialize once so that any iterable (including one-shot generators) is accepted and the
    # result type is always a tuple, even for empty input.
    dims_ = tuple(dims) if isinstance(dims, Iterable) else (dims,)
    return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None)


def get_non_contracting_dims(ndim, contracting_dims):
    """List, in ascending order, the axes of an ``ndim``-rank tensor that are not contracted.

    ``contracting_dims`` may contain negative indexes; they are normalized before comparison.
    """
    contracted = sanitize_dims(ndim, contracting_dims)
    kept_axes = [axis for axis in range(ndim) if axis not in contracted]
    return tuple(kept_axes)


def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1):
    """Compute the new dimension numbers after a flatten-axis transpose.

    The transpose moves axes ``[f, ndim)`` in front of axes ``[0, f)`` (each group keeping its
    order), where ``f`` is the normalized flatten axis. Each entry of ``dims_to_transpose`` is
    mapped to its index in that new axis order.

    Args:
        ndim: Rank of the tensor.
        dims_to_transpose: Dimension indexes of the untransposed tensor to remap. Negative
            values are counted from the end.
        flatten_axis: Axis at which the tensor is split before transposing. Negative values are
            counted from the end (default ``-1``: only the last axis moves to the front).
            NOTE(review): a positive value ``k`` is normalized to ``ndim - k`` (i.e. it behaves
            like ``-k``), unlike the usual Python indexing convention. This is kept for backward
            compatibility -- confirm it is what callers intend.

    Returns:
        A tuple of new dimension indexes, or ``dims_to_transpose`` itself when it is empty.
    """
    if len(dims_to_transpose) == 0:
        return dims_to_transpose
    if flatten_axis > 0:
        # Historical behavior for positive flatten axes (see NOTE in the docstring).
        flatten_axis = ndim - flatten_axis
    elif flatten_axis < 0:
        # Normalize to an absolute split point so that range() below covers exactly ndim axes.
        flatten_axis = ndim + flatten_axis
    transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis))
    return tuple(
        transposed_dims.index(ndim + dim if dim < 0 else dim) for dim in dims_to_transpose
    )


def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool:
    lhs, rhs, e4m3, e5m2 = map(
        dtypes.canonicalize_dtype,
        (
            lhs_dtype,
            rhs_dtype,
            jnp.float8_e4m3fn,
            jnp.float8_e5m2,
        ),
    )

    # FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3)
    if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3):
        return True

    # Any other combination of data types is not supported
    return False


def _get_gemm_layout(
    operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]]
) -> Tuple[bool, bool]:
    """Classify each GEMM operand as transposed or not from its contracting dimensions.

    The LHS counts as transposed when its last axis is *not* a contracting axis; the RHS counts
    as transposed when its last axis *is* a contracting axis.

    Returns:
        ``(lhs_is_transposed, rhs_is_transposed)``.
    """
    lhs_ndim = operand_ndims[0]
    rhs_ndim = operand_ndims[1]
    lhs_cdims = sanitize_dims(lhs_ndim, contracting_dims[0])
    rhs_cdims = sanitize_dims(rhs_ndim, contracting_dims[1])
    lhs_last_axis = lhs_ndim - 1
    rhs_last_axis = rhs_ndim - 1
    return lhs_last_axis not in lhs_cdims, rhs_last_axis in rhs_cdims


def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims):
    """Quantize GEMM operands that are not already quantized.

    An operand is quantized only if it is not already a ``ScaledTensor`` and its quantizer is
    not ``None``. Exactly one layout (row-wise or column-wise) is produced per operand, chosen
    from the operand's contracting dimensions and the quantizer's scaling mode.

    Args:
        lhs: Left GEMM operand (raw array or ``ScaledTensor``).
        rhs: Right GEMM operand (raw array or ``ScaledTensor``).
        lhs_quantizer: Quantizer for ``lhs`` or ``None``.
        rhs_quantizer: Quantizer for ``rhs`` or ``None``.
        contracting_dims: Pair ``(lhs_contracting_dims, rhs_contracting_dims)``.

    Returns:
        ``(lhs_q, rhs_q)``: the operands after (optional) quantization.

    Raises:
        AssertionError: If either result is a ``ScaledTensor2x``, or if only one of the two
            operands reports ``has_rht_applied``.
    """
    lhs_q = lhs
    rhs_q = rhs

    if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None:
        lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0])
        lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims
        # A transposed LHS is quantized column-wise for 1D block scaling, NVFP4 scaling, or when
        # FP8 GEMM does not support all layouts on this device.
        need_lhs_colwise = lhs_is_transposed and (
            lhs_quantizer.scaling_mode.is_1d_block_scaling()
            or not is_fp8_gemm_with_all_layouts_supported()
            or lhs_quantizer.scaling_mode.is_nvfp4_scaling
        )
        lhs_q = lhs_quantizer.quantize(
            lhs,
            is_rowwise=not need_lhs_colwise,
            is_colwise=need_lhs_colwise,
            flatten_axis=get_lhs_axis_boundary(lhs_cdims, lhs_is_transposed),
        )

    if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None:
        rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1])
        rhs_is_transposed = rhs.ndim - 1 in rhs_cdims
        # Mirror image of the LHS rule: a non-transposed RHS needs column-wise data.
        need_rhs_colwise = not rhs_is_transposed and (
            rhs_quantizer.scaling_mode.is_1d_block_scaling()
            or not is_fp8_gemm_with_all_layouts_supported()
            or rhs_quantizer.scaling_mode.is_nvfp4_scaling
        )
        rhs_q = rhs_quantizer.quantize(
            rhs,
            is_rowwise=not need_rhs_colwise,
            is_colwise=need_rhs_colwise,
            flatten_axis=get_rhs_axis_boundary(rhs_cdims, rhs_is_transposed),
        )

    assert not isinstance(lhs_q, ScaledTensor2x)
    assert not isinstance(rhs_q, ScaledTensor2x)

    def has_rht_applied(q: AbstractBaseTensor) -> bool:
        return isinstance(q, ScaledTensor1x) and q.has_rht_applied

    assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), (
        "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized"
        " with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the"
        " GEMM."
    )

    return lhs_q, rhs_q


183
184
185
186
187
188
def _get_nvfp4_tensor_scale_inv(amax):
    DATA_DTYPE_MAX = jnp.finfo(jnp.float4_e2m1fn.dtype).max.astype(jnp.float32)
    SCALE_DTYPE_MAX = jnp.finfo(jnp.float8_e4m3fn.dtype).max.astype(jnp.float32)
    return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)


Phuong Nguyen's avatar
Phuong Nguyen committed
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
def collective_gemm_bootstrap(
    num_total_devices,
    num_devices_per_process,
    process_id,
    tensor_parallel_size,
    num_max_streams=3,
    compute_stream_priority=0,
    communication_stream_priority=0,
    num_sm_for_communication=2,
    use_ce=True,
    aggregate_all_gather=False,
):
    """Initialize NCCL communicators for Collective GEMM operations.

    This function sets up the distributed communication infrastructure needed for
    tensor parallel collective GEMM operations. It supports two main scenarios:

    1. **Multi-device per process**: TP domain = single process
       - Each process manages multiple GPUs (num_devices_per_process > 1)
       - TP group consists of GPUs within the same process
       - Example: 2 processes × 4 GPUs each = 8 total ranks, tp_size=4

    2. **Single device per process**: TP domain spans multiple processes
       - Each process manages one GPU (num_devices_per_process = 1)
       - TP group spans across multiple processes
       - Example: 8 processes × 1 GPU each = 8 total ranks, tp_size=4

    Only the single-device-per-process scenario is currently accepted (see Raises).

    Args:
        num_total_devices (int): Total number of ranks across all processes.
            Must be divisible by num_devices_per_process.
        num_devices_per_process (int): Number of GPUs per process.
            - For multi-device: equals tp_size (e.g., 4 GPUs per process)
            - For single-device: equals 1 (1 GPU per process)
        process_id (int): Process identifier (0-based).
            Must be in range [0, num_total_devices // num_devices_per_process).
        tensor_parallel_size (int): Size of tensor parallel groups.
            Must be positive and divide num_total_devices evenly.
        num_max_streams (int, optional): Maximum number of CUDA streams for overlap.
            Higher values enable more parallelism but use more GPU resources. Default: 3.
        compute_stream_priority (int, optional): Priority for GEMM computation streams.
            Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0.
        communication_stream_priority (int, optional): Priority for NCCL communication streams.
            Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0.
        num_sm_for_communication (int, optional): Number of streaming multiprocessors
            reserved for communication operations. Default: 2.
        use_ce (bool, optional): Enable CUDA copy engines for memory transfers.
            Can improve performance by offloading memory operations. Default: True.
        aggregate_all_gather (bool, optional): Aggregate multiple small all-gather operations
            into larger ones for better efficiency. Default: False.

    Raises:
        AssertionError: If num_devices_per_process is not 1 or more than one local device is
            visible (temporary: only single device per process is supported for now).
        AssertionError: If num_total_devices is not divisible by num_devices_per_process,
            if process_id is out of valid range, or if tensor_parallel_size is not a positive
            divisor of num_total_devices.
        RuntimeError: If NCCL initialization fails or if configuration
            is invalid (e.g., insufficient GPUs).

    Example:
        # Basic initialization (single device per process)
        collective_gemm_bootstrap(
            num_total_devices=8,
            num_devices_per_process=1,
            process_id=0,
            tensor_parallel_size=4,
        )

        # Advanced configuration with custom performance settings
        collective_gemm_bootstrap(
            num_total_devices=8,
            num_devices_per_process=1,
            process_id=0,
            tensor_parallel_size=4,
            num_max_streams=5,                # More parallelism
            compute_stream_priority=1,        # Lower compute priority
            communication_stream_priority=0,  # Higher comm priority
            num_sm_for_communication=4,       # More SMs for communication
            use_ce=True,                      # Enable copy engines
            aggregate_all_gather=True,        # Aggregate small operations
        )

    Note:
        This function must be called after JAX distributed initialization
        and before any collective GEMM operations. Each process should call
        this function with its own unique process_id.
    """

    assert (
        num_devices_per_process == 1 and jax.local_device_count() == 1
    ), "Only single device per process is supported at the moment!"
    assert num_total_devices % num_devices_per_process == 0, (
        f"Invalid num_total_devices={num_total_devices},"
        f" num_devices_per_process={num_devices_per_process}"
    )
    # Bound process_id by the number of processes, as documented above.
    num_processes = num_total_devices // num_devices_per_process
    assert 0 <= process_id < num_processes, f"Invalid process_id={process_id}"
    # Validate the documented TP-size contract before handing it to the native communicator.
    assert tensor_parallel_size > 0 and num_total_devices % tensor_parallel_size == 0, (
        f"Invalid tensor_parallel_size={tensor_parallel_size},"
        f" num_total_devices={num_total_devices}"
    )
    initialize_cgemm_communicator(
        num_total_devices,
        num_devices_per_process,
        process_id,
        tensor_parallel_size,
        num_max_streams,
        compute_stream_priority,
        communication_stream_priority,
        num_sm_for_communication,
        use_ce,
        aggregate_all_gather,
    )


class CollectiveOp(Enum):
    """Kind of collective communication fused into a Collective GEMM.

    Member values are taken from the native ``JAXX_Collective_Op`` enum.
    """

    NONE = JAXX_Collective_Op.NONE
    ALL_GATHER = JAXX_Collective_Op.ALL_GATHER
    REDUCE_SCATTER = JAXX_Collective_Op.REDUCE_SCATTER

    @property
    def is_none(self) -> bool:
        """True when no collective is fused."""
        return self is CollectiveOp.NONE

    @property
    def is_all_gather(self) -> bool:
        """True for the all-gather collective."""
        return self is CollectiveOp.ALL_GATHER

    @property
    def is_reduce_scatter(self) -> bool:
        """True for the reduce-scatter collective."""
        return self is CollectiveOp.REDUCE_SCATTER


@dataclass(frozen=True)
class CollectiveOpSet:
    """Immutable pair of CollectiveOps for the forward and backward passes of a dense layer.

    The backward op complements the forward one: all-gather pairs with reduce-scatter and vice
    versa, while NONE pairs with NONE.
    """

    forward: CollectiveOp
    backward: CollectiveOp

    @staticmethod
    def create(forward_collective_op: CollectiveOp):
        """Build a CollectiveOpSet whose backward op complements ``forward_collective_op``."""
        fwd = forward_collective_op
        bwd = (
            CollectiveOp.REDUCE_SCATTER
            if fwd.is_all_gather
            else CollectiveOp.ALL_GATHER if fwd.is_reduce_scatter else CollectiveOp.NONE
        )
        return CollectiveOpSet(forward=fwd, backward=bwd)


# Default CollectiveOpSet: no collective is fused in either the forward or the backward pass
# (CollectiveOpSet.create maps a NONE forward op to a NONE backward op).
noop_collective_op_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.NONE)


344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
@partial(jax.jit, static_argnums=(1, 2))
def swizzled_scale(scale_inv, flatten_axis, is_colwise):
    """Reorder a scale-inverse tensor into 128x4 tiles using pure JAX reshapes/transposes.

    The tensor is viewed as 2D, split at ``flatten_axis``. For column-wise scales that 2D view
    is transposed first. The (rows, cols) matrix is then reshaped to
    ``(rows // 128, 4, 32, cols // 4, 4)``, axes 1 and 3 are swapped, and the result is reshaped
    back to the input's original shape. NOTE(review): presumably this produces the swizzled
    block-scale layout cuBLAS expects -- confirm against the cuBLAS docs.

    ``flatten_axis`` and ``is_colwise`` are static (jit-compile-time) arguments.
    """
    input_shape = scale_inv.shape
    leading = math.prod(input_shape[:flatten_axis])
    trailing = math.prod(input_shape[flatten_axis:])
    if not is_colwise:
        n_rows, n_cols = leading, trailing
        matrix = scale_inv
    else:
        n_rows, n_cols = trailing, leading
        matrix = jnp.transpose(scale_inv.reshape((leading, trailing)))
    tiles = matrix.reshape(n_rows // 128, 4, 32, n_cols // 4, 4)
    return jnp.transpose(tiles, (0, 3, 2, 1, 4)).reshape(input_shape)


359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
def get_lhs_axis_boundary(lhs_cdims, is_transposed):
    """Return the axis index at which the LHS operand is split into two flattened halves.

    For a transposed LHS the split is just past its largest contracting axis; otherwise it is
    at its smallest contracting axis.
    """
    if is_transposed:
        return max(lhs_cdims) + 1
    return min(lhs_cdims)


def get_rhs_axis_boundary(rhs_cdims, is_transposed):
    """Return the axis index at which the RHS operand is split into two flattened halves.

    For a transposed RHS the split is at its smallest contracting axis; otherwise it is just
    past its largest contracting axis (the mirror image of the LHS rule).
    """
    if not is_transposed:
        return max(rhs_cdims) + 1
    return min(rhs_cdims)


def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name):
    """Check cuBLAS alignment of a quantized operand's contracting extent.

    Unscaled (``ScalingMode.NO_SCALING``) operands are not checked. For quantized operands the
    contracting size must be a multiple of 32 under NVFP4 scaling and of 16 otherwise.

    Raises:
        AssertionError: If the alignment requirement is not met.
    """
    if scaling_mode == ScalingMode.NO_SCALING:
        return
    # Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    required_multiple = 32 if scaling_mode.is_nvfp4_scaling else 16
    assert contracting_size % required_multiple == 0, (
        f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of"
        f" {required_multiple} when using quantized inputs. Got contracting_size={contracting_size}"
    )


Alp Dener's avatar
Alp Dener committed
381
382
383
384
385
386
387
class GemmPrimitive(BasePrimitive):
    """
    Primitive for cuBLAS GEMM
    """

    name = "te_gemm_ffi"
    multiple_results = True
388
    impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
Alp Dener's avatar
Alp Dener committed
389
390
391
392
393
394
395
396
397
398
399
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        lhs,
        lhs_scale_inv,
        rhs,
        rhs_scale_inv,
        bias,
        gelu_input,
        alpha,
        beta,
        out_dtype,
        contracting_dims,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        transpose_batch_sequence,
        sequence_dim,
        is_outer,
        collective_op,
    ):
        """Validate GEMM inputs and infer the abstract (shape/dtype) outputs.

        Returns:
            ``(output, bias_grad, pre_gelu_out, workspace)`` as ``jax.core.ShapedArray``s:
            - ``output``: lhs non-contracting dims followed by rhs non-contracting dims, with
              ``out_dtype``. For the inner primitive of a collective GEMM, the ``sequence_dim``
              entry is multiplied (all-gather) or divided (reduce-scatter) by the TP/SP axis size.
            - ``bias_grad``: ``bias.shape`` when ``grad`` is set, else ``(0,)``.
            - ``pre_gelu_out``: the (non-overlapped) output shape when ``fuse_gelu`` is set,
              else ``(0,)``.
            - ``workspace``: ``uint8`` cuBLAS scratch buffer (+256 bytes for alignment padding).

        Raises:
            AssertionError: If operand layouts, dtypes, scaling factors, bias, pre-GeLU input,
                alpha/beta or collective settings are inconsistent or unsupported.
        """
        del use_split_accumulator, transpose_batch_sequence

        def _dims_are_consecutive(dims):
            if len(dims) <= 1:
                return True
            return sorted(dims) == list(range(min(dims), max(dims) + 1))

        # Sanity-check operand layouts and types
        operand_ndims = (lhs.ndim, rhs.ndim)

        (
            lhs_contracting_dims,
            rhs_contracting_dims,
        ) = map(sanitize_dims, operand_ndims, contracting_dims)
        assert _dims_are_consecutive(lhs_contracting_dims), (
            "cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got "
            f"{lhs_contracting_dims}."
        )
        assert _dims_are_consecutive(rhs_contracting_dims), (
            "cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got "
            f"{rhs_contracting_dims}."
        )

        lhs_contracting_size, rhs_contracting_size = map(
            lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]),
            (lhs.shape, rhs.shape),
            (lhs_contracting_dims, rhs_contracting_dims),
        )
        assert lhs_contracting_size == rhs_contracting_size, (
            "cuBLAS GEMM operands have incompatible contracting dimensions: "
            f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}."
        )

        lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims)
        if scaling_mode != ScalingMode.NO_SCALING:
            assert scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes(
                lhs.dtype, rhs.dtype
            ), (
                "cuBLAS GEMM quantized operands have incompatible data types: "
                f"{lhs.dtype} x {rhs.dtype}."
            )
            assert (
                lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0
            ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands."
            if (
                scaling_mode != ScalingMode.MXFP8_1D_SCALING
                and not is_fp8_gemm_with_all_layouts_supported()
            ):
                assert not lhs_is_transposed and rhs_is_transposed, (
                    "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) "
                    "require non-transposed LHS and transposed RHS operands "
                    "(`contracting_dims=((-1, ), (-1, ))`)."
                )
        else:
            assert lhs.dtype == rhs.dtype, (
                "For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal."
                f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}"
            )

        # Determine output shape and dtype
        assert (
            dtypes.canonicalize_dtype(out_dtype).itemsize > 1
        ), "cuBLAS GEMM custom op does not support 8-bit quantized output types."
        lhs_non_contracting_shape, rhs_non_contracting_shape = map(
            lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims],
            (lhs.shape, rhs.shape),
            (lhs_contracting_dims, rhs_contracting_dims),
        )
        out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape)
        output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)

        # Adjust output shape for comm+GEMM overlap
        if not collective_op.is_none and not is_outer:  # Inner abstract
            assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
            overlap_out_shape = list(out_shape)
            if collective_op.is_all_gather:
                overlap_out_shape[sequence_dim] *= tpsp_axis_size()
            else:  # RS
                overlap_out_shape[sequence_dim] //= tpsp_axis_size()
            assert out_dtype == jnp.bfloat16, f"Unsupported out_dtype={out_dtype}"
            output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype)

        # Validate bias
        if fuse_bias:
            assert bias.shape == tuple(rhs_non_contracting_shape), (
                "cuBLAS GEMM bias tensor has incorrect shape, "
                f"expected {tuple(rhs_non_contracting_shape)} but found {bias.shape}."
            )
            assert bias.dtype == out_dtype, (
                "cuBLAS GEMM bias tensor has incorrect data type, "
                f"expected {out_dtype} but found {bias.dtype}."
            )
        # WAR: allocate dbias regardless of fuse_bias so that the sharding propagation works as we
        # change the fuse_bias value in the sharded_impl
        dbias_shape = bias.shape if grad else (0,)
        bias_grad = jax.core.ShapedArray(shape=dbias_shape, dtype=bias.dtype)

        # Validate pre-GeLU
        pre_gelu_shape = (0,)
        pre_gelu_dtype = out_dtype
        if fuse_gelu:
            pre_gelu_shape = out_shape
            if grad:
                # Compare full shapes (this also covers rank); comparing `ndim` against the
                # shape tuple itself would always be False.
                assert tuple(gelu_input.shape) == tuple(pre_gelu_shape), (
                    "cuBLAS GEMM pre-GeLU tensor has incorrect shape, "
                    f"expected {pre_gelu_shape} but found {gelu_input.shape}."
                )
                assert gelu_input.dtype == out_dtype, (
                    "cuBLAS GEMM pre-GeLU tensor has incorrect data type, "
                    f"expected {pre_gelu_dtype} but found {gelu_input.dtype}."
                )
        pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)

        assert alpha.size == 1 and alpha.dtype == jnp.float32
        assert beta.size == 1 and beta.dtype == jnp.float32

        # Declare cuBLAS workspace
        workspace_size = get_cublas_workspace_size_bytes()
        if not collective_op.is_none:
            # Collective GEMM scales the workspace by the native max stream count.
            workspace_size *= get_cgemm_num_max_streams()
        # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
        # necessarily 256 bytes aligned, we add some padding to ensure alignment.
        workspace_size += 256
        workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)

        return output, bias_grad, pre_gelu_out, workspace
Alp Dener's avatar
Alp Dener committed
544
545
546
547

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """Abstract evaluation for the outer primitive: ``abstract`` without the workspace."""
        *visible_outputs, _workspace = GemmPrimitive.abstract(*args, **kwargs)
        return tuple(visible_outputs)
Alp Dener's avatar
Alp Dener committed
549
550
551
552
553
554
555
556
557
558

    @staticmethod
    def lowering(
        ctx,
        lhs,
        lhs_scale_inv,
        rhs,
        rhs_scale_inv,
        bias,
        gelu_input,
        alpha,
        beta,
        out_dtype,
        contracting_dims,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        transpose_batch_sequence,
        sequence_dim,
        is_outer,
        collective_op,
    ):
        """Lower the GEMM primitive to the ``te_gemm_ffi`` FFI custom call.

        Checks cuBLAS alignment of each operand's contracting extent, then emits the custom call.
        When ``grad`` is set, ``bias`` (operand 4) is aliased to ``bias_grad`` (output 1). When
        ``fuse_gelu`` and ``grad`` are both set, ``gelu_input`` (operand 5) is aliased to
        ``pre_gelu_out`` (output 2).
        """
        del out_dtype, transpose_batch_sequence, sequence_dim, is_outer

        lhs_aval, _, rhs_aval, *_ = ctx.avals_in
        lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
        lhs_transposed, rhs_transposed = _get_gemm_layout(
            (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
        )
        lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed)
        rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed)

        # The contracting extent of each operand is the product of its contracting dims. It is
        # computed from the dims directly: the axis boundary puts the contracting dims on opposite
        # sides depending on the transpose flag, which is easy to get backwards.
        lhs_contracting_size = math.prod(lhs_aval.shape[dim] for dim in lhs_cdims)
        assert_cublas_requirements(
            scaling_mode,
            lhs_contracting_size,
            "LHS",
        )
        rhs_contracting_size = math.prod(rhs_aval.shape[dim] for dim in rhs_cdims)
        assert_cublas_requirements(
            scaling_mode,
            rhs_contracting_size,
            "RHS",
        )

        args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta)
        kwargs = {
            "scaling_mode": int(scaling_mode.value),
            "lhs_axis_boundary": lhs_axis_boundary,
            "rhs_axis_boundary": rhs_axis_boundary,
            "lhs_transposed": lhs_transposed,
            "rhs_transposed": rhs_transposed,
            "fuse_bias": fuse_bias,
            "fuse_gelu": fuse_gelu,
            "grad": grad,
            "use_split_accumulator": use_split_accumulator,
            "collective_op": int(collective_op.value),
        }

        operand_output_aliases = {}
        if grad:
            operand_output_aliases[4] = 1  # bias <-> bias_grad
        if fuse_gelu and grad:
            operand_output_aliases[5] = 2  # gelu_input <-> pre_gelu_out

        return jax.ffi.ffi_lowering(
            GemmPrimitive.name,
            operand_output_aliases=operand_output_aliases,
        )(ctx, *args, **kwargs)

    @staticmethod
    def impl(
        lhs,
        lhs_scale_inv,
        rhs,
        rhs_scale_inv,
        bias,
        gelu_input,
637
638
        alpha,
        beta,
Alp Dener's avatar
Alp Dener committed
639
640
641
642
643
644
645
        out_dtype,
        contracting_dims,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
Phuong Nguyen's avatar
Phuong Nguyen committed
646
647
648
649
        transpose_batch_sequence,
        sequence_dim,
        is_outer,
        collective_op,
Alp Dener's avatar
Alp Dener committed
650
    ):
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
        if scaling_mode.is_1d_block_scaling():
            lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
            lhs_transposed, rhs_transposed = _get_gemm_layout(
                (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
            )
            lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims)
            rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1

            lhs_scale_inv = apply_padding_to_scale_inv(
                lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis
            )
            rhs_scale_inv = apply_padding_to_scale_inv(
                rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis
            )
            lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
            rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
Alp Dener's avatar
Alp Dener committed
667

Phuong Nguyen's avatar
Phuong Nguyen committed
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
        # Alter lhs blocks so that CGEMM RS outputs correctly
        if (
            collective_op.is_reduce_scatter
            and not transpose_batch_sequence
            and not is_outer
            and not lhs.shape[0] == 1
        ):
            assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
            original_shape = lhs.shape
            assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, (
                f"Original_shape[0]={original_shape[0]} is not divisible by"
                f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}"
            )
            assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, (
                f"Original_shape[1]={original_shape[1]} is not divisible by"
                f" tpsp_axis_size()={tpsp_axis_size()}"
            )
            reshaped = lhs.reshape(
                dp_or_fsdp_axis_size(),
                int(original_shape[0] / dp_or_fsdp_axis_size()),
                tpsp_axis_size(),
                int(original_shape[1] / tpsp_axis_size()),
                *original_shape[2:],
            )
            reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim))
            lhs = reordered.reshape(original_shape)

        (output, bias_grad, pre_gelu_out, _) = GemmPrimitive.inner_primitive.bind(
Alp Dener's avatar
Alp Dener committed
696
697
698
699
700
701
            lhs,
            lhs_scale_inv,
            rhs,
            rhs_scale_inv,
            bias,
            gelu_input,
702
703
            alpha,
            beta,
Alp Dener's avatar
Alp Dener committed
704
705
706
707
708
709
710
            out_dtype=out_dtype,
            contracting_dims=contracting_dims,
            scaling_mode=scaling_mode,
            fuse_bias=fuse_bias,
            fuse_gelu=fuse_gelu,
            grad=grad,
            use_split_accumulator=use_split_accumulator,
Phuong Nguyen's avatar
Phuong Nguyen committed
711
712
713
714
            collective_op=collective_op,
            transpose_batch_sequence=transpose_batch_sequence,
            sequence_dim=sequence_dim,
            is_outer=is_outer,
Alp Dener's avatar
Alp Dener committed
715
        )
Phuong Nguyen's avatar
Phuong Nguyen committed
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
        # Alter output blocks for CGEMM AG
        if (
            collective_op.is_all_gather
            and not transpose_batch_sequence
            and not is_outer
            and not output.shape[0] == 1
        ):
            assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
            original_shape = output.shape
            assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, (
                f"Original_shape[0]={original_shape[0]} is not divisible by"
                f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}"
            )
            assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, (
                f"Original_shape[1]={original_shape[1]} is not divisible by"
                f" tpsp_axis_size()={tpsp_axis_size()}"
            )
            reshaped = output.reshape(
                tpsp_axis_size(),
                dp_or_fsdp_axis_size(),
                int(original_shape[0] / dp_or_fsdp_axis_size()),
                int(original_shape[1] / tpsp_axis_size()),
                *original_shape[2:],
            )
            reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim))
            output = reordered.reshape(original_shape)

        return [output, bias_grad, pre_gelu_out]
744
745
746
747
748
749
750
751
752

    @staticmethod
    def outer_impl(
        lhs,
        lhs_scale_inv,
        rhs,
        rhs_scale_inv,
        bias,
        gelu_input,
        alpha,
        beta,
        out_dtype,
        contracting_dims,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        transpose_batch_sequence,
        sequence_dim,
        is_outer,
        collective_op,
    ):
        """Implementation rule of the outer GEMM primitive.

        This is a thin pass-through: all eight array operands followed by all static
        parameters are handed to ``GemmPrimitive.impl`` positionally, in exactly the order
        they were received.

        Returns:
            The result of ``GemmPrimitive.impl`` for these arguments.
        """
        array_operands = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta)
        static_params = (
            out_dtype,
            contracting_dims,
            scaling_mode,
            fuse_bias,
            fuse_gelu,
            grad,
            use_split_accumulator,
            transpose_batch_sequence,
            sequence_dim,
            is_outer,
            collective_op,
        )
        return GemmPrimitive.impl(*array_operands, *static_params)
Alp Dener's avatar
Alp Dener committed
788
789
790
791

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        out_dtype,
        contracting_dims,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        collective_op,
        transpose_batch_sequence,
        sequence_dim,
        is_outer,
    ):
        """vmap rule for GemmPrimitive.

        Batching over the GEMM operands is not supported: both ``lhs`` and ``rhs`` must arrive
        with batch dim ``None``. The outer primitive is re-bound on the unchanged arguments with
        every static parameter forwarded, and all outputs are reported as unbatched.

        Args:
            batched_args: Operands of the primitive, forwarded unchanged.
            batch_dims: Per-operand batch dims; entries 0 (lhs) and 2 (rhs) must be ``None``.
            Remaining args: Static parameters of GemmPrimitive, forwarded to ``bind``.

        Returns:
            ``(outputs, (out_bdims, bias_bdims, pre_gelu_bdims))``.

        Raises:
            AssertionError: If ``lhs`` or ``rhs`` carries a batch dimension.
        """
        # NOTE: these collective-GEMM parameters are forwarded to `bind` below, so they must not
        # be `del`-ed here. Deleting them raised UnboundLocalError on every vmap.
        assert GemmPrimitive.outer_primitive is not None
        lhs_bdims, _, rhs_bdims, *_ = batch_dims

        # Batched GEMM is not supported
        assert (
            lhs_bdims is None and rhs_bdims is None
        ), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})"
        out_bdims = (None,)

        # Bias gradient is never batched
        bias_bdims = (None,)

        # Pre-GeLU output, if exists, is batched like GEMM output
        pre_gelu_bdims = (None,)
        if fuse_gelu and not grad:
            pre_gelu_bdims = out_bdims

        return (
            GemmPrimitive.outer_primitive.bind(
                *batched_args,
                out_dtype=out_dtype,
                contracting_dims=contracting_dims,
                scaling_mode=scaling_mode,
                fuse_bias=fuse_bias,
                fuse_gelu=fuse_gelu,
                grad=grad,
                use_split_accumulator=use_split_accumulator,
                collective_op=collective_op,
                transpose_batch_sequence=transpose_batch_sequence,
                sequence_dim=sequence_dim,
                is_outer=is_outer,
            ),
            (out_bdims, bias_bdims, pre_gelu_bdims),
        )

    @staticmethod
    def _parse_operand_output_specs(
        arg_infos,
        contracting_dims,
        transpose_batch_sequence,
        collective_op,
    ):
        """Derive operand and output partition specs for a (possibly collective) GEMM.

        Args:
            arg_infos: Sharding-carrying infos of the primitive operands. Only entries 0 (lhs)
                and 2 (rhs) are used to derive specs.
            contracting_dims: Pair of contracting-dimension sequences for ``(lhs, rhs)``.
            transpose_batch_sequence: For collective GEMM, selects whether the sequence dim is
                dim 0 (``True``) or dim 1 (``False``).
            collective_op: Collective overlapped with the GEMM (none, all-gather or
                reduce-scatter).

        Returns:
            ``((lhs_specs, rhs_specs, bias_specs, gelu_specs), (out_specs, bias_specs,
            gelu_specs), reduce_spec, sequence_dim)``. ``reduce_spec`` is the mesh axis that
            shards a contracting dim of both operands (``None`` if none). ``sequence_dim`` is
            the sequence dimension for collective GEMM (``None`` when no collective is used).

        Raises:
            ValueError: For all-gather collective GEMM when ``tpsp_resource`` is not in the
                lhs specs.
        """
        lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)

        gsr = global_mesh_resource()

        # Tensor sequence parallelism is expected to be configured via tpsp_resource; warn when
        # the plain TP axis shards the lhs instead.
        if gsr.tp_resource is not None and gsr.tp_resource in lhs_specs:
            warnings.warn(
                f"Tensor sequence parallelism is detected as tp_resource='{gsr.tp_resource}'"
                f" appears in lhs_specs: {lhs_specs}. Please set MeshResource.tpsp_resource"
                " for tensor sequence parallelism to avoid potential issues."
            )

        lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
        lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims)
        lhs_non_cdims, rhs_non_cdims = map(
            lambda ndim, cdims: tuple(i for i in range(ndim) if i not in cdims),
            (lhs_ndim, rhs_ndim),
            (lhs_cdims, rhs_cdims),
        )
        lhs_non_cspecs, lhs_cspecs, rhs_non_cspecs, rhs_cspecs = map(
            lambda specs, dims: tuple(specs[i] for i in dims),
            (lhs_specs, lhs_specs, rhs_specs, rhs_specs),
            (lhs_non_cdims, lhs_cdims, rhs_non_cdims, rhs_cdims),
        )

        # A mesh axis sharding a contracting dim of BOTH operands produces partial sums that
        # have to be reduced over that axis.
        reduce_spec = None
        for l in lhs_cspecs:
            for r in rhs_cspecs:
                if l is not None and l == r:
                    assert reduce_spec is None, "Multiple reduce dimension is detected!"
                    reduce_spec = l

        sequence_dim = None

        # Find sequence dimension in lhs_specs if tensor sequence parallel is enabled.
        # CollectiveGemm AG is only done on x or dY, which are always the LHS and carry the
        # sequence dim.
        if collective_op.is_all_gather:
            try:
                tpsp_idx = lhs_specs.index(gsr.tpsp_resource)
            except ValueError as exc:
                raise ValueError(
                    f"tpsp_resource '{gsr.tpsp_resource}' is not found in lhs_specs: {lhs_specs}."
                    " Please check your sharding configuration."
                ) from exc
            sequence_dim = tpsp_idx
            assert (sequence_dim == 1) ^ transpose_batch_sequence, (
                "CollectiveGEMM supports only (sequence_dim=1 and transpose_batch_sequence=False)"
                " or (sequence_dim=0 and transpose_batch_sequence=True). Received:"
                f" sequence_dim={sequence_dim},"
                f" transpose_batch_sequence={transpose_batch_sequence}."
            )

        elif collective_op.is_reduce_scatter:
            assert reduce_spec == gsr.tpsp_resource, (
                "Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got"
                f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}"
            )
            sequence_dim = int(not transpose_batch_sequence)

        if reduce_spec is not None:
            # Other non-reduce cdims (if exists) need to be unsharded
            lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs)
            # With AG overlap the rhs may also keep its TPSP-sharded contracting dim.
            if collective_op.is_all_gather:
                rhs_cspecs = tuple(
                    s if s in (reduce_spec, gsr.tpsp_resource) else None for s in rhs_cspecs
                )
            else:
                rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs)

            # Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden
            # No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim.
            # In `rhs_specs`, the batch dim appears only in Wgrad GEMM under `rhs_cspecs`.
            rhs_non_cspecs = tuple(
                None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs
            )

        else:
            # Otherwise, require contracting dims of both operands to be unsharded
            lhs_cspecs = (None,) * len(lhs_cspecs)
            rhs_cspecs = (None,) * len(rhs_cspecs)

            # Non-contracting dims of RHS always needs to be gathered along the FSDP axis
            rhs_non_cspecs = tuple(
                None if spec is not None and spec == gsr.fsdp_resource else spec
                for spec in rhs_non_cspecs
            )

        # Only do AG Sequence dim if not Overlap
        if not collective_op.is_all_gather:
            # Non-contracting dims of LHS to be gathered along the SP axis.
            # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for
            # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet.
            lhs_non_cspecs = tuple(
                None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs
            )

        out_specs = lhs_non_cspecs + rhs_non_cspecs

        # sequence_dim addresses the output through its lhs non-contracting prefix, so it must
        # lie strictly inside that prefix. An index equal to its length would hit the rhs part.
        if collective_op.is_all_gather:
            assert sequence_dim < len(
                lhs_non_cspecs
            ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}"
            out_specs = out_specs[:sequence_dim] + (None,) + out_specs[sequence_dim + 1 :]
        elif collective_op.is_reduce_scatter:
            assert sequence_dim < len(
                lhs_non_cspecs
            ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}"
            out_specs = (
                out_specs[:sequence_dim] + (gsr.tpsp_resource,) + out_specs[sequence_dim + 1 :]
            )

        # specs = merge(cspecs, non_cspecs)
        lhs_specs, rhs_specs = map(
            lambda cdims, cspecs, non_cspecs: (
                cspecs + non_cspecs if cdims[0] == 0 else non_cspecs + cspecs
            ),
            (lhs_cdims, rhs_cdims),
            (lhs_cspecs, rhs_cspecs),
            (lhs_non_cspecs, rhs_non_cspecs),
        )

        # Bias and Pre-GeLU sharding is based on GEMM output before any scatter
        bias_specs = tuple(rhs_non_cspecs)
        gelu_specs = tuple(out_specs)

        if not collective_op.is_none:
            assert sequence_dim >= 0, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"

        return (
            (lhs_specs, rhs_specs, bias_specs, gelu_specs),
            (out_specs, bias_specs, gelu_specs),
            reduce_spec,
            sequence_dim,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        contracting_dims,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        transpose_batch_sequence,
        sequence_dim,
        is_outer,
        collective_op,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Infer output shardings of GemmPrimitive from its operand shardings.

        Returns:
            ``[out_sharding, dbias_sharding, pre_gelu_sharding]``. The dbias sharding is
            replicated unless both ``fuse_bias`` and ``grad`` are set. The pre-GeLU sharding
            is replicated unless ``fuse_gelu`` is set.
        """
        del out_dtype, scaling_mode, use_split_accumulator, result_infos, is_outer, sequence_dim

        def _as_sharding(specs):
            return NamedSharding(mesh, PartitionSpec(*specs))

        _, output_specs, *_ = GemmPrimitive._parse_operand_output_specs(
            arg_infos, contracting_dims, transpose_batch_sequence, collective_op
        )
        out_specs, dbias_specs, pre_gelu_specs = output_specs

        # Outputs belonging to a disabled fusion fall back to a fully-replicated 1-D spec.
        dbias_is_fused = fuse_bias and grad
        if not dbias_is_fused:
            dbias_specs = (None,)
        if not fuse_gelu:
            pre_gelu_specs = (None,)

        return [
            _as_sharding(out_specs),
            _as_sharding(dbias_specs),
            _as_sharding(pre_gelu_specs),
        ]

    @staticmethod
    def partition(
        out_dtype,
        contracting_dims,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        transpose_batch_sequence,
        sequence_dim,
        is_outer,
        collective_op,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Partitioning rule for GemmPrimitive.

        Derives operand/output shardings via ``_parse_operand_output_specs`` and returns a
        per-shard implementation. When the operands share a sharded contracting axis
        (``reduce_spec``), each shard computes a partial product. That product is then summed
        with ``psum`` (unless the reduce-scatter collective GEMM is used), and any forward bias
        is added once, after the reduction.

        Returns:
            ``(mesh, sharded_impl, out_shardings, arg_shardings)``.
        """
        del result_infos, is_outer, sequence_dim

        (
            (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
            (out_specs, dbias_specs, pre_gelu_specs),
            reduce_spec,
            inferred_sequence_dim,
        ) = GemmPrimitive._parse_operand_output_specs(
            arg_infos,
            contracting_dims,
            transpose_batch_sequence,
            collective_op,
        )

        def _as_sharding(specs):
            return NamedSharding(mesh, PartitionSpec(*specs))

        none_sharding = _as_sharding((None,))
        lhs_sharding = _as_sharding(lhs_specs)
        rhs_sharding = _as_sharding(rhs_specs)

        # Block scale inverses match their operands, but tensor scale inverses are unsharded.
        block_scaled = scaling_mode.is_1d_block_scaling()
        arg_shardings = (
            lhs_sharding,
            lhs_sharding if block_scaled else none_sharding,
            rhs_sharding,
            rhs_sharding if block_scaled else none_sharding,
            # Bias / pre-GeLU inputs are only sharded when the corresponding fusion is on.
            _as_sharding(bias_input_specs if fuse_bias else (None,)),
            _as_sharding(gelu_input_specs if fuse_gelu else (None,)),
            none_sharding,  # alpha
            none_sharding,  # beta
        )

        out_shardings = [
            _as_sharding(out_specs),
            # dbias is only produced with bias + grad fusion; pre-GeLU only with GeLU fusion.
            _as_sharding(dbias_specs if (fuse_bias and grad) else (None,)),
            _as_sharding(pre_gelu_specs if fuse_gelu else (None,)),
        ]

        def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta):
            # We should not fuse bias in the output reduction case: it would be added once per
            # partial product and then summed by the reduction.
            sharded_fuse_bias = fuse_bias and reduce_spec is None
            outputs = GemmPrimitive.impl(
                lhs,
                lhs_scale_inv,
                rhs,
                rhs_scale_inv,
                bias,
                gelu_input,
                alpha,
                beta,
                out_dtype=out_dtype,
                contracting_dims=contracting_dims,
                scaling_mode=scaling_mode,
                fuse_bias=sharded_fuse_bias,
                fuse_gelu=fuse_gelu,
                grad=grad,
                use_split_accumulator=use_split_accumulator,
                transpose_batch_sequence=transpose_batch_sequence,
                sequence_dim=inferred_sequence_dim,
                is_outer=False,
                collective_op=collective_op,
            )

            if reduce_spec is not None:
                # No explicit psum for the reduce-scatter collective GEMM (presumably the
                # collective kernel already yields the reduced, scattered result).
                if not collective_op.is_reduce_scatter:
                    partial_out = outputs[0]
                    if is_all_reduce_in_float32():  # For unittest only
                        outputs[0] = jax.lax.psum(
                            partial_out.astype(jnp.float32), reduce_spec
                        ).astype(out_dtype)
                    else:
                        outputs[0] = jax.lax.psum(partial_out, reduce_spec)

                # With `grad` set, `_te_gemm` passes a zero-length placeholder bias (the bias slot
                # then refers to the dbias output), so adding it would be a broadcast error or,
                # for a width-1 output, silently produce an empty result. Only add it in the
                # forward configuration.
                # NOTE(review): dbias is not computed on this reduction path; confirm intended.
                if fuse_bias and not grad:  # TODO(Phuong): rename fuse_bias to has_bias
                    outputs[0] += bias

            return outputs

        return mesh, _sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        contracting_dims,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
        transpose_batch_sequence,
        sequence_dim,
        is_outer,
        collective_op,
        mesh,
        operand_types,
        result_types,
    ):
        """Build the Shardy (SDY) sharding rule for GemmPrimitive.

        Every lhs/rhs dimension gets a factor name. Contracting dims are named ``Gemm_k{i}``
        and shared between the two operands. Non-contracting dims are named
        ``Gemm_{lhs|rhs}_l{i}``. The output mapping is the lhs non-contracting factors followed
        by the rhs ones. Operands/results unused in the current configuration get their own
        placeholder mappings (``("…N",)``). Each mapping is a tuple of factor names.

        Raises:
            NotImplementedError: If a collective GEMM is requested.
        """
        del out_dtype, use_split_accumulator
        del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer

        if not collective_op.is_none:
            raise NotImplementedError(
                "CollectiveGEMM with Shardy propagation is not supported yet! Please turn off"
                " Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false"
            )

        prefix = "Gemm_"

        def _generate_operand_rules(name, ndim, cdims):
            # The i-th contracting dim maps to the same "k{i}" factor on both operands.
            specs = []
            ldims = tuple(i for i in range(ndim) if i not in cdims)
            for i in range(ndim):
                if i in cdims:
                    dim_name = f"k{cdims.index(i)}"
                else:
                    dim_name = f"{name}_l{ldims.index(i)}"
                specs.append(prefix + dim_name)
            return specs

        lhs, _, rhs, *_ = operand_types
        operand_ndims = (len(lhs.shape), len(rhs.shape))
        (lhs_cdims, rhs_cdims) = map(sanitize_dims, operand_ndims, contracting_dims)
        lhs_specs, rhs_specs = map(
            _generate_operand_rules,
            ("lhs", "rhs"),
            operand_ndims,
            (lhs_cdims, rhs_cdims),
        )
        lhs_scale_specs = ("…1",)
        rhs_scale_specs = ("…2",)
        if scaling_mode.is_1d_block_scaling():
            lhs_scale_specs = lhs_specs
            rhs_scale_specs = rhs_specs

        lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims)
        rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
        out_spec = (*lhs_non_cspec, *rhs_non_cspec)
        bias_spec = rhs_non_cspec if fuse_bias else ("…4",)
        gelu_spec = out_spec if fuse_gelu else ("…5",)
        alpha_spec = ("_6",)
        beta_spec = ("_7",)
        # Must be a 1-tuple like the other placeholders; `("…8")` is just the string "…8",
        # which would be iterated as the two factors "…" and "8".
        dbias_spec = bias_spec if grad else ("…8",)

        return SdyShardingRule(
            operand_mappings=(
                lhs_specs,
                lhs_scale_specs,
                rhs_specs,
                rhs_scale_specs,
                bias_spec,
                gelu_spec,
                alpha_spec,
                beta_spec,
            ),
            result_mappings=(
                out_spec,
                dbias_spec,
                gelu_spec,
            ),
        )


# Register GemmPrimitive through the shared `register_primitive` helper imported from `.base`.
register_primitive(GemmPrimitive)


def gemm_uses_jax_dot() -> bool:
    """Report whether GEMMs fall back to native JAX ``dot`` instead of the TE cuBLAS call.

    Returns:
        ``True`` when ``GemmPrimitive`` is disabled, ``False`` when it is enabled.
    """
    te_custom_call_enabled = GemmPrimitive.enabled()
    return not te_custom_call_enabled


def _te_gemm(
    lhs: Union[jax.Array, ScaledTensor],
    rhs: Union[jax.Array, ScaledTensor],
    bias: jax.Array = None,
    gelu_input: jax.Array = None,
    lhs_quantizer: Quantizer = None,
    rhs_quantizer: Quantizer = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
    fuse_bias: bool = False,
    fuse_gelu: bool = False,
    grad: bool = False,
    use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
    transpose_batch_sequence: bool = False,
    collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]:
    """Compute ``lhs x rhs`` with the TE cuBLAS GEMM primitive.

    The operands are first passed through ``_quantize_gemm_operands``. For each operand that
    ends up as a ``ScaledTensor``, the data, scale-inverse and amax are extracted. When its
    stored layout is transposed (``data_layout == "T"``), its contracting dims are remapped
    accordingly.

    Args:
        lhs: Left operand, a plain array or an already-quantized ``ScaledTensor``.
        rhs: Right operand, a plain array or an already-quantized ``ScaledTensor``.
        bias: Bias input. Only forwarded when ``fuse_bias`` is set and ``grad`` is not;
            otherwise an empty placeholder is passed.
        gelu_input: GeLU input. Only forwarded when both ``fuse_gelu`` and ``grad`` are set;
            otherwise an empty placeholder is passed.
        lhs_quantizer: Optional quantizer for ``lhs``.
        rhs_quantizer: Optional quantizer for ``rhs``.
        contracting_dims: Contracting dimensions of ``(lhs, rhs)``.
        fuse_bias: Fuse the bias (with ``grad``: the bias gradient) into the GEMM.
        fuse_gelu: Fuse GeLU into the GEMM (deprecated).
        grad: Gradient-mode flag for the fused epilogue (deprecated).
        use_split_accumulator: Forwarded to the primitive. NOTE: the default is read from the
            quantize config once, when this module is imported, not on every call.
        transpose_batch_sequence: Forwarded to the primitive for collective GEMM.
        collective_op: Collective to overlap with the GEMM.

    Returns:
        The primitive's outputs: GEMM output, bias gradient and pre-GeLU output.

    Raises:
        AssertionError: If exactly one operand is quantized and no quantizer exists for the
            other, or if the operands' scaling modes are incompatible.
    """
    if fuse_gelu or grad:
        warnings.warn(
            "GEMM + fused grad or fused gelu is not well tested and will be deprecated in the"
            " future",
            DeprecationWarning,
        )

    # Defaults describe a non-quantized GEMM; overwritten below for quantized operands.
    scaling_mode = ScalingMode.NO_SCALING
    lhs_data, rhs_data = lhs, rhs
    lhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
    rhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
    lhs_amax = None
    rhs_amax = None

    operand_ndims = (lhs.ndim, rhs.ndim)
    lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims)
    lhs_cdims, rhs_cdims = map(sanitize_dims, operand_ndims, contracting_dims)

    lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)

    if isinstance(lhs_q, ScaledTensor):
        assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, (
            "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid "
            "`Quantizer` object to quantize the RHS operand."
        )
        if isinstance(lhs_q, ScaledTensor2x):
            # Pick the copy quantized along the contracting dimension(s).
            if lhs_is_transposed:
                lhs_q = lhs_q.get_colwise_tensor()
            else:
                lhs_q = lhs_q.get_rowwise_tensor()
        scaling_mode = lhs_q.scaling_mode
        lhs_data = lhs_q.data
        lhs_scale_inv = lhs_q.scale_inv
        if lhs_q.data_layout == "T":
            lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
        lhs_amax = lhs_q.amax

    if isinstance(rhs_q, ScaledTensor):
        assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
            "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid "
            "`Quantizer` object to quantize the LHS operand."
        )
        if isinstance(rhs_q, ScaledTensor2x):
            # Pick the copy quantized along the contracting dimension(s).
            if rhs_is_transposed:
                rhs_q = rhs_q.get_rowwise_tensor()
            else:
                rhs_q = rhs_q.get_colwise_tensor()
        lhs_mode, rhs_mode = lhs_q.scaling_mode, rhs_q.scaling_mode
        assert rhs_mode == lhs_mode or (rhs_mode.is_nvfp4_scaling and lhs_mode.is_nvfp4_scaling), (
            "cuBLAS GEMM quantized operands have mismatched scaling types, "
            f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}."
        )
        rhs_data = rhs_q.data
        rhs_scale_inv = rhs_q.scale_inv
        if rhs_q.data_layout == "T":
            rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
        rhs_amax = rhs_q.amax

    # alpha/beta default to 1/0; NVFP4 folds both per-tensor scale inverses into alpha.
    alpha = jnp.ones((1,), jnp.float32)
    beta = jnp.zeros((1,), jnp.float32)
    if scaling_mode.is_nvfp4_scaling:
        assert lhs_amax is not None and rhs_amax is not None
        alpha = _get_nvfp4_tensor_scale_inv(lhs_amax) * _get_nvfp4_tensor_scale_inv(rhs_amax)

    out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype

    # Bias is an input only for forward bias fusion; GeLU input only for GeLU-grad fusion.
    # Anything else gets a dummy empty array.
    bias_is_input = fuse_bias and not grad
    gelu_is_input = fuse_gelu and grad
    if bias is None or not bias_is_input:
        bias = jnp.empty(0, dtype=out_dtype)
    if gelu_input is None or not gelu_is_input:
        gelu_input = jnp.empty(0, dtype=out_dtype)

    return GemmPrimitive.outer_primitive.bind(
        lhs_data,
        lhs_scale_inv,
        rhs_data,
        rhs_scale_inv,
        bias,
        gelu_input,
        alpha,
        beta,
        out_dtype=out_dtype,
        contracting_dims=(lhs_cdims, rhs_cdims),
        scaling_mode=scaling_mode,
        fuse_bias=fuse_bias,
        fuse_gelu=fuse_gelu,
        grad=grad,
        use_split_accumulator=use_split_accumulator,
        transpose_batch_sequence=transpose_batch_sequence,
        sequence_dim=-1,  #  Dummy value and will be set in the primitive
        is_outer=True,
        collective_op=collective_op,
    )


1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
class GroupedGemmCopySizesPrimitive(BasePrimitive):
    """
    Primitive for async copying group sizes from device to host.

    The result has the same abstract value as the ``group_sizes`` operand, and the lowering
    aliases the ``group_sizes`` operand buffer to that result. ``num_gemms`` is a static
    attribute.
    """

    name = "te_grouped_gemm_d2h_group_sizes_ffi"
    multiple_results = False
    impl_static_args = (1,)  # num_gemms (positional arg 1 of impl) is static
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        group_sizes_aval,
        *,
        num_gemms,
    ):
        """Abstract evaluation: the output aval is the ``group_sizes`` aval itself."""
        del num_gemms
        return group_sizes_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """Outer abstract evaluation; identical to ``abstract``."""
        return GroupedGemmCopySizesPrimitive.abstract(*args, **kwargs)

    @staticmethod
    def lowering(
        ctx,
        group_sizes,
        num_gemms,
    ):
        """Lower to the TE FFI custom call, passing ``num_gemms`` as an attribute."""
        return jax.ffi.ffi_lowering(
            GroupedGemmCopySizesPrimitive.name,
            # Alias operand 0 (group_sizes) to result 0 so the output reuses the input buffer.
            # num_gemms is an attribute, not an operand.
            operand_output_aliases={0: 0},
        )(
            ctx,
            group_sizes,
            num_gemms=num_gemms,
        )

    @staticmethod
    def impl(
        group_sizes,
        num_gemms,
    ):
        """Bind the inner primitive on ``group_sizes`` with the static ``num_gemms``."""
        assert GroupedGemmCopySizesPrimitive.inner_primitive is not None
        out = GroupedGemmCopySizesPrimitive.inner_primitive.bind(
            group_sizes,
            num_gemms=num_gemms,
        )
        return out


register_primitive(GroupedGemmCopySizesPrimitive)


1403
1404
1405
1406
1407
1408
1409
class GroupedGemmPrimitive(BasePrimitive):
    """
    Primitive for grouped GEMM.

    The inner primitive has two results, the GEMM output and a uint8 workspace
    buffer. The outer primitive exposes only the GEMM output.
    """

    name = "te_grouped_gemm_ffi"
    multiple_results = True
    # Positions of M, N, K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias,
    # is_grouped_dense_wgrad and use_async_d2h_group_sizes among ``impl``'s arguments.
    impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        lhs_data_aval,
        lhs_scale_inv_aval,
        rhs_data_aval,
        rhs_scale_inv_aval,
        bias_aval,
        group_sizes_aval,
        group_offset_aval,
        *,
        M,
        N,
        K,
        lhs_is_trans,
        rhs_is_trans,
        scaling_mode,
        out_dtype,
        has_bias,
        is_grouped_dense_wgrad,
        use_async_d2h_group_sizes,
    ):
        """
        Shape/dtype inference for the grouped GEMM operation.

        Args:
            lhs_data: Left-hand side input matrix data, 1D flattened array
            lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array
            rhs_data: Right-hand side input matrix data, 1D flattened array
            rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array
            bias: Bias matrix of shape (G, N)
            group_sizes: 1D array containing the sizes of each group
            group_offset: 1D array containing offsets for each group (not yet implemented)
            M: Number of rows in the output matrix
            N: Number of columns in the output matrix
            K: Number of columns in the left-hand side matrix
            lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed
            rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed
            scaling_mode: Scaling mode for the GEMM operations (compared against
                          ``ScalingMode.<mode>.value``)
            out_dtype: Data type of the output tensors
            has_bias: Boolean indicating if bias tensors are provided
            is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation
                                    where both lhs and rhs are 2D matrices and output is (G, M, N)
            use_async_d2h_group_sizes: Flag forwarded to the custom call; it does not
                                       affect the inferred avals.

        Returns:
            A tuple ``(out_aval, workspace_aval)``. ``out_aval`` has shape (M, N), or
            (G, M, N) with G = group_sizes.size when ``is_grouped_dense_wgrad``, and
            dtype ``out_dtype``. ``workspace_aval`` is a 1D uint8 buffer.
        """
        # Only the scale_inv avals, group_sizes and the static shape parameters matter here.
        del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
        del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes
        # TODO(Phuong): move some shape checks from Cpp to here

        cublas_ptr_alignment = 256
        tensor_scale_inv_alignment = 16
        mxfp8_scale_inv_padding = 256

        # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
        # necessarily 256 bytes aligned, so extra bytes are reserved for re-alignment.
        workspace_bytes = (
            get_cublas_workspace_size_bytes() * num_cublas_streams + cublas_ptr_alignment
        )
        tensor_scaling_modes = (
            ScalingMode.DELAYED_TENSOR_SCALING.value,
            ScalingMode.CURRENT_TENSOR_SCALING.value,
        )
        if scaling_mode in tensor_scaling_modes:
            # For tensor scaling, each matrix has a single scale value, but it
            # needs to be aligned to 16 bytes for CUDA 12.9.1 and later.
            scale_count = lhs_scale_inv_aval.size + rhs_scale_inv_aval.size
            workspace_bytes += scale_count * tensor_scale_inv_alignment
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            # Room for both swizzled scale_inv buffers, each padded for 256-byte alignment.
            workspace_bytes += (
                lhs_scale_inv_aval.size
                + rhs_scale_inv_aval.size
                + 2 * mxfp8_scale_inv_padding
            )
        workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8)

        if is_grouped_dense_wgrad:
            out_shape = (group_sizes_aval.size, M, N)
        else:
            out_shape = (M, N)
        out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        return out_aval, workspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """Outer aval inference: drop the workspace, keep only the GEMM output."""
        out_aval, _workspace_aval = GroupedGemmPrimitive.abstract(*args, **kwargs)
        return (out_aval,)

    @staticmethod
    def lowering(
        ctx,
        *args,
        M,
        N,
        K,
        lhs_is_trans,
        rhs_is_trans,
        scaling_mode,
        out_dtype,
        has_bias,
        is_grouped_dense_wgrad,
        use_async_d2h_group_sizes,
    ):
        """Lower to the FFI custom call; ``out_dtype`` is not forwarded as an attribute."""
        del out_dtype
        ffi_rule = jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)
        return ffi_rule(
            ctx,
            *args,
            M=M,
            N=N,
            K=K,
            lhs_is_trans=lhs_is_trans,
            rhs_is_trans=rhs_is_trans,
            scaling_mode=scaling_mode.value,
            has_bias=has_bias,
            is_grouped_dense_wgrad=is_grouped_dense_wgrad,
            use_async_d2h_group_sizes=use_async_d2h_group_sizes,
        )

    @staticmethod
    def impl(
        lhs_data,
        lhs_scale_inv,
        rhs_data,
        rhs_scale_inv,
        bias,
        group_sizes,
        group_offset,
        M,
        N,
        K,
        lhs_is_trans,
        rhs_is_trans,
        scaling_mode,
        out_dtype,
        has_bias,
        is_grouped_dense_wgrad,
        use_async_d2h_group_sizes,
    ):
        """Bind the inner primitive and return only the GEMM output (workspace discarded)."""
        inner = GroupedGemmPrimitive.inner_primitive
        assert inner is not None
        out, _workspace = inner.bind(
            lhs_data,
            lhs_scale_inv,
            rhs_data,
            rhs_scale_inv,
            bias,
            group_sizes,
            group_offset,
            M=M,
            N=N,
            K=K,
            lhs_is_trans=lhs_is_trans,
            rhs_is_trans=rhs_is_trans,
            scaling_mode=scaling_mode,
            out_dtype=out_dtype,
            has_bias=has_bias,
            is_grouped_dense_wgrad=is_grouped_dense_wgrad,
            use_async_d2h_group_sizes=use_async_d2h_group_sizes,
        )
        return (out,)


register_primitive(GroupedGemmPrimitive)


def _shape_normalization(x, dimension_numbers, already_transposed: bool = False):
    orig_order = list(range(x.ndim))
    contracting_dims, batch_dims = dimension_numbers
    contracting_order = [d for d in orig_order if d in contracting_dims]
    batch_order = [d for d in orig_order if d in batch_dims]
    non_contracting_order = [
        d for d in orig_order if d not in contracting_dims and d not in batch_dims
    ]
    batch_shape = [x.shape[d] for d in batch_order]
    rows_shape = [x.shape[d] for d in non_contracting_order]
    cols_shape = [x.shape[d] for d in contracting_order]
    new_order = batch_order + non_contracting_order + contracting_order
    rows, cols, batches = (
        reduce(operator.mul, rows_shape, 1),
        reduce(operator.mul, cols_shape, 1),
        reduce(operator.mul, batch_shape, 1),
    )
    # Remove this transpose when non-TN dot is supported
    if not already_transposed:
        t = jnp.transpose(x, new_order)
    else:
        t = x
    return jnp.reshape(t, (batches, rows, cols))


def _calculate_remaining_shape(shape, contracting_dims):
    """Return, in axis order, the sizes of the axes of ``shape`` that are not contracted.

    ``contracting_dims`` is passed through ``sanitize_dims`` with the rank of ``shape``
    before membership is tested.
    """
    sanitized = sanitize_dims(len(shape), contracting_dims)
    remaining = []
    for axis, size in enumerate(shape):
        if axis not in sanitized:
            remaining.append(size)
    return tuple(remaining)
1599

1600

1601
1602
1603
# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(2, 3))
def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
    """Tensor-scaled FP8 GEMM via ``jax.lax.dot_general``.

    Args:
        lhs, rhs: Tensor-scaled quantized operands (``data``, ``scale_inv``, ``data_layout``,
            ``flatten_axis`` and ``dq_dtype`` are read).
        dim_nums: ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))``, static under jit.
        precision: ``jax.lax.Precision`` for the dot, static under jit.

    Returns:
        The dot product of the raw data, multiplied by ``lhs.scale_inv * rhs.scale_inv``
        and cast to ``lhs.dq_dtype``.
    """
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums

    # Operands whose data is in the "T" layout have their contracting dims remapped
    # with ``transpose_dims`` so they index into the stored data.
    if lhs.data_layout == "T":
        lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis)
    if rhs.data_layout == "T":
        rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis)

    stored_dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
    raw_product = jax.lax.dot_general(
        lhs.data,
        rhs.data,
        stored_dim_nums,
        precision=precision,
        preferred_element_type=lhs.dq_dtype,
    )

    # Dequantize with the product of both per-tensor inverse scales.
    combined_scale_inv = lhs.scale_inv * rhs.scale_inv
    return (raw_product * combined_scale_inv).astype(lhs.dq_dtype)
1619
1620


1621
@partial(jax.jit, static_argnums=(2,))
def _jax_scaled_matmul(
    lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
    """
    JAX GEMM for block-scaled operands (MXFP8 or NVFP4) via ``jax.nn.scaled_matmul``.

    Args:
        lhs, rhs: Quantized operands; ``rhs`` must use MXFP8 1D, NVFP4 1D or NVFP4 2D scaling.
        dim_nums: ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))``, static under jit.

    Returns:
        The product reshaped to the non-contracted dims of ``lhs.data`` followed by the
        non-contracted dims of ``rhs.data``.
    """
    assert rhs.scaling_mode in (
        ScalingMode.MXFP8_1D_SCALING,
        ScalingMode.NVFP4_1D_SCALING,
        ScalingMode.NVFP4_2D_SCALING,
    ), f"rhs does not have MXFP8 or NVFP4 scaling mode, got rhs.scaling_mode={rhs.scaling_mode}"

    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums

    # An operand must be column-wise quantized exactly when its last contracting dim is
    # not its last axis.
    expected_lhs_is_colwise = lhs_contract[-1] != lhs.data.ndim - 1
    expected_rhs_is_colwise = rhs_contract[-1] != rhs.data.ndim - 1
    assert lhs.is_colwise is expected_lhs_is_colwise, (
        "LHS with unexpected quantize dimension.\n"
        f"Expect is_colwise={expected_lhs_is_colwise}, got {lhs.is_colwise}"
    )
    assert rhs.is_colwise is expected_rhs_is_colwise, (
        "RHS with unexpected quantize dimension.\n"
        f"Expect is_colwise={expected_rhs_is_colwise}, got {rhs.is_colwise}"
    )

    if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
        # MXFP8 operands are only accepted in the "N" data layout.
        out_dtype = lhs.dq_dtype
        assert (
            lhs.data_layout == "N" and rhs.data_layout == "N"
        ), f"Got lhs.data_layout={lhs.data_layout}, rhs.data_layout={rhs.data_layout}"
    else:
        # Operands stored in the "T" layout have their contracting dims remapped.
        if lhs.data_layout == "T":
            lhs_contract = transpose_dims(
                lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis
            )
        if rhs.data_layout == "T":
            rhs_contract = transpose_dims(
                rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis
            )
        # Accumulate in float32; the NVFP4 path below casts to lhs.dq_dtype.
        out_dtype = jnp.float32

    # Reshape + Transpose (if needed)
    # [..., M, K] -> [1, reduce(..., M), K]
    # [..., K, M] -> [1, reduce(..., M), K]
    lhs_is_T = lhs.data_layout == "T"
    rhs_is_T = rhs.data_layout == "T"
    lhs_dims = (lhs_contract, lhs_batch)
    rhs_dims = (rhs_contract, rhs_batch)
    lhs_3d = _shape_normalization(lhs.data, lhs_dims, lhs_is_T)
    rhs_3d = _shape_normalization(rhs.data, rhs_dims, rhs_is_T)
    lhs_scale_3d = _shape_normalization(lhs.scale_inv, lhs_dims, lhs_is_T)
    rhs_scale_3d = _shape_normalization(rhs.scale_inv, rhs_dims, rhs_is_T)

    # JAX scaled_matmul only supports NT now (TN-gemm)
    # * Expected shape:
    # * lhs_data  (B, M, K)           * rhs_data  (B, N, K)
    # * lhs_scale (B, M, K_block)     * rhs_scale (B, N, K_block)
    out_3d = jax.nn.scaled_matmul(
        lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=out_dtype
    )

    if lhs.scaling_mode.is_nvfp4_scaling:
        # NVFP4 also applies per-tensor scales derived from each operand's amax.
        assert lhs.amax is not None and rhs.amax is not None
        alpha = _get_nvfp4_tensor_scale_inv(lhs.amax) * _get_nvfp4_tensor_scale_inv(rhs.amax)
        out_3d = (out_3d * alpha).astype(lhs.dq_dtype)

    # Reshape [1, reduce(..., M), N] -> [..., M, N]
    lhs_remain_shape = tuple(
        size for axis, size in enumerate(lhs.data.shape) if axis not in lhs_contract
    )
    rhs_remain_shape = tuple(
        size for axis, size in enumerate(rhs.data.shape) if axis not in rhs_contract
    )
    return out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)


def _jax_gemm(
    lhs: Union[jnp.ndarray, ScaledTensor],
    rhs: Union[jnp.ndarray, ScaledTensor],
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    lhs_quantizer: Quantizer = None,
    rhs_quantizer: Quantizer = None,
) -> jnp.ndarray:
    """
    FP8 GEMM via JAX.

    Both operands are first passed through ``_quantize_gemm_operands``. If both results
    are ``ScaledTensor``s, dispatch is on the LHS scaling mode: tensor scaling goes to
    ``_jax_gemm_tensor_scaling_fp8`` and 1D block scaling to ``_jax_scaled_matmul``.
    Plain arrays without quantizers go to ``jax.lax.dot_general``. No batch dims are used.

    Raises:
        NotImplementedError: for an unsupported scaling mode, or a mix of ScaledTensor
            and plain-array inputs.
    """
    dim_nums = (contracting_dims, ((), ()))

    def _dispatch_quantized(lhs_t, rhs_t):
        """Run the quantized GEMM path that matches ``lhs_t.scaling_mode``."""
        if lhs_t.scaling_mode.is_tensor_scaling():
            assert (
                rhs_t.scaling_mode == lhs_t.scaling_mode
            ), f"rhs.scaling_mode={rhs_t.scaling_mode} != lhs.scaling_mode={lhs_t.scaling_mode}"
            # FP8_2X_ACC_FPROP selects HIGHEST precision for the dot.
            if get_quantize_config().FP8_2X_ACC_FPROP:
                precision = jax.lax.Precision.HIGHEST
            else:
                precision = jax.lax.Precision.DEFAULT
            return _jax_gemm_tensor_scaling_fp8(lhs_t, rhs_t, dim_nums, precision)

        if lhs_t.scaling_mode.is_1d_block_scaling:
            return _jax_scaled_matmul(lhs_t, rhs_t, dim_nums)

        raise NotImplementedError(f"Unsupported ScalingMode: {lhs_t.scaling_mode}")

    lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)

    both_quantized = isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor)
    if both_quantized:
        return _dispatch_quantized(lhs_q, rhs_q)

    plain_inputs = isinstance(lhs, jnp.ndarray) and isinstance(rhs, jnp.ndarray)
    no_quantizers = lhs_quantizer is None and rhs_quantizer is None
    if plain_inputs and no_quantizers:
        return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype)

    raise NotImplementedError("Not supporting multiplication of ScaledTensor and jnp.array")


def gemm(
    lhs: Union[jnp.ndarray, AbstractBaseTensor],
    rhs: Union[jnp.ndarray, AbstractBaseTensor],
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
    lhs_quantizer: Quantizer = None,
    rhs_quantizer: Quantizer = None,
    transpose_batch_sequence: bool = False,
    collective_op: CollectiveOp = CollectiveOp.NONE,
    **kwargs,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, ...]]:
    r"""General matrix multiplication with optional quantization.

    Parameters
    ----------
    lhs: Union[jax.Array, ScaledTensor]
        Left-hand side operand in the matrix multiplication. A ``NoScaleTensor`` is
        unwrapped to its ``data``.
    rhs: Union[jax.Array, ScaledTensor]
        Right-hand side operand in the matrix multiplication. A ``NoScaleTensor`` is
        unwrapped to its ``data``.
    lhs_quantizer: Quantizer, default = None
        Object for down-casting the LHS operand for quantized GEMM.
    rhs_quantizer: Quantizer, default = None
        Object for down-casting the RHS operand for quantized GEMM.
    contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, ))
        Tuple of sequences representing the contracting dimensions of the operands.
    transpose_batch_sequence: bool, default = False
        Transpose the batch and sequence dimensions of the input tensor.
    collective_op: CollectiveOp, default = CollectiveOp.NONE
        Collective operation type for collective GEMM. Must be "none" when the custom
        cuBLAS GEMM primitive is disabled.
    **kwargs:
        Forwarded to ``_te_gemm``. The options below are read from here. For backward
        compatibility a ``quantizer_set`` (with ``.x`` and ``.kernel``) is also accepted;
        when either quantizer argument is None it supplies both quantizers.
    bias: jax.Array, default = None
        Optional additive bias term, required for forward GEMM with bias fusion. Only supported
        with TE's custom call to cuBLAS GEMM.
    gelu_input: jax.Array, default = None
        Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only
        supported with TE's custom call to cuBLAS GEMM.
    fuse_bias: bool, default = False
        Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with
        TE's custom call to cuBLAS GEMM.
    fuse_gelu: bool, default = False
        Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. Only supported
        with TE's custom call to cuBLAS GEMM.
    grad: bool, default = False
        Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with
        TE's custom call to cuBLAS GEMM.
    use_split_accumulator: bool, default = True
        Enable promoting some intermediate sums to higher precision when accumulating the result in
        the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed.

    Returns
    -------
    jax.Array:
        Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the
        GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution
        when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and
        `grad=False`.
    Optional[jax.Array]:
        Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call
        to cuBLAS GEMM.
    Optional[jax.Array]:
        Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input
        to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to
        compute the GeLU contribution to the gradient. Only supported with TE's custom call to
        cuBLAS GEMM.

    When neither optional output is produced, the result array is returned on its own
    rather than as a 1-tuple.
    """
    if isinstance(lhs, NoScaleTensor):
        lhs = lhs.data
    if isinstance(rhs, NoScaleTensor):
        rhs = rhs.data

    # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility
    if lhs_quantizer is None or rhs_quantizer is None:
        quantizer_set = kwargs.get("quantizer_set", None)
        if quantizer_set is not None:
            lhs_quantizer = quantizer_set.x
            rhs_quantizer = quantizer_set.kernel

    # TODO(Phuong): fuse_bias -> has_bias and has_bias = bias is not None
    fuse_bias = kwargs.get("fuse_bias", False)
    fuse_gelu = kwargs.get("fuse_gelu", False)

    # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled
    if not GemmPrimitive.enabled():
        # Each assertion checks the options named in its own message: bias-related options
        # with the bias message, GeLU-related options with the GeLU message.
        assert kwargs.get("bias", None) is None and not fuse_bias, (
            "TE GEMM was invoked with bias fusion options that are not supported by the "
            "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
            "GEMM primitive is disabled."
        )
        assert kwargs.get("gelu_input", None) is None and not fuse_gelu, (
            "TE GEMM was invoked with GeLU fusion options that are not supported by the "
            "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
            "GEMM primitive is disabled."
        )
        assert collective_op.is_none, "JAX GEMM does not support collective GEMM"
        return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)

    outputs = _te_gemm(
        lhs,
        rhs,
        lhs_quantizer=lhs_quantizer,
        rhs_quantizer=rhs_quantizer,
        contracting_dims=contracting_dims,
        transpose_batch_sequence=transpose_batch_sequence,
        collective_op=collective_op,
        **kwargs,
    )

    # Discard empty outputs
    grad = kwargs.get("grad", False)
    clean_outputs = outputs[0]  # first output is the final result and is never empty
    if (fuse_bias and grad) or (fuse_gelu and not grad):
        clean_outputs = (outputs[0],)
        if fuse_bias and grad:  # only return bias gradient if it exists
            clean_outputs += (outputs[1],)
        if fuse_gelu and not grad:  # only return pre-GeLU output if it exists
            clean_outputs += (outputs[2],)
    return clean_outputs
1860
1861


1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
def grouped_gemm_copy_group_sizes(
    group_sizes: jnp.ndarray,
    num_gemms: int,
) -> jnp.ndarray:
    """
    Async copy group sizes from device to host.

    Args:
        group_sizes: 1D array containing the sizes of each group
        num_gemms: number of grouped gemm calls to be made

    Returns:
        The output of ``GroupedGemmCopySizesPrimitive``, which has the same shape and
        dtype as ``group_sizes``.
    """
    copy_primitive = GroupedGemmCopySizesPrimitive.outer_primitive
    return copy_primitive.bind(group_sizes, num_gemms=num_gemms)


1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
def grouped_gemm(
    lhs: Union[jnp.ndarray, GroupedScaledTensor1x],
    rhs: Union[jnp.ndarray, GroupedScaledTensor1x],
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)),
    bias: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.ndarray = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
    use_async_d2h_group_sizes: bool = False,
) -> jnp.ndarray:
    """
    Grouped GEMM operation.

    Args:
        lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
        rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
        group_sizes: 1D array containing the sizes of each group
        contracting_dims: Tuple of two sequences representing the contracting dimensions
        bias: Bias tensor of shape (G, N)
        precision: JAX precision for the GEMM operation (currently ignored)
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array of size 1 containing the group offset (not yet implemented);
            defaults to a single int32 zero
        quantizer_set: Set of quantizers for FP8 quantization of the input and output.
            Note: the ``q_layout`` of ``quantizer_set.x`` and ``quantizer_set.kernel`` is
            overwritten when unquantized inputs are quantized here.
        use_async_d2h_group_sizes: Flag forwarded unchanged to the grouped GEMM custom call.
            Presumably it tells the kernel to use group sizes previously copied to host via
            ``grouped_gemm_copy_group_sizes`` (NOTE(review): confirm against the C++ side).

    Returns:
        A jnp.ndarray containing the result of the grouped GEMM operation

    Note:
        Tested shapes:
        lhs: [M, K] or [K, N]
        rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K]
    """
    # TODO(Phuong): implement the group_offset
    # Compare against None explicitly: ``group_offset or ...`` calls bool() on an array,
    # which raises for traced values under jit and for arrays with more than one element,
    # and silently replaces a concrete zero offset.
    if group_offset is None:
        group_offset = jnp.zeros((1,), jnp.int32)

    # TODO(Phuong): implement the precision
    del precision

    if isinstance(lhs, jnp.ndarray):
        assert isinstance(rhs, jnp.ndarray)
        out_dtype = lhs.dtype
        lhs_shape = lhs.shape
        rhs_shape = rhs.shape
        lhs_data = lhs
        rhs_data = rhs
        lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32)
        scaling_mode = ScalingMode.NO_SCALING
    elif isinstance(lhs, GroupedScaledTensor1x):
        assert isinstance(rhs, GroupedScaledTensor1x)
        out_dtype = lhs.dq_dtype
        lhs_shape = lhs.original_shape
        rhs_shape = rhs.original_shape
        lhs_data = lhs.data
        rhs_data = rhs.data
        lhs_scale_inv = lhs.scale_inv
        rhs_scale_inv = rhs.scale_inv
        assert lhs.scaling_mode == rhs.scaling_mode
        scaling_mode = lhs.scaling_mode
    else:
        raise TypeError("Unsupported lhs type object!")

    out_dtype = preferred_element_type or out_dtype

    lhs_contract_dim, rhs_contract_dim = contracting_dims

    lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1
    lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1)

    # rhs_shape [G, K, N]
    rhs_is_trans = rhs_contract_dim[0] != 1
    rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim)

    is_grouped_dense_wgrad = False
    if len(rhs_shape) == 2:
        rhs_is_trans = rhs_contract_dim[0] != 0
        is_grouped_dense_wgrad = True

    # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this?
    if (
        is_grouped_dense_wgrad
        and not isinstance(lhs, ScaledTensor)
        and not isinstance(rhs, ScaledTensor)
    ):
        lhs_is_trans = True
        rhs_is_trans = False
        lhs_flatten_axis = 1
        rhs_flatten_axis = 1

    if (
        not isinstance(lhs, ScaledTensor)
        and not isinstance(rhs, ScaledTensor)
        and quantizer_set != noop_quantizer_set
    ):
        assert isinstance(quantizer_set.x, GroupedQuantizer)
        assert type(quantizer_set.x) is type(quantizer_set.kernel)
        scaling_mode = quantizer_set.x.scaling_mode
        if (
            quantizer_set.x.scaling_mode.is_tensor_scaling()
            and is_fp8_gemm_with_all_layouts_supported()
        ):
            lhs_is_rowwise = rhs_is_rowwise = True
        else:
            lhs_is_rowwise = not lhs_is_trans
            rhs_is_rowwise = rhs_is_trans
        quantizer_set.x.q_layout = (
            QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
        )
        quantizer_set.kernel.q_layout = (
            QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE
        )
        lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis)
        rhs_q = grouped_quantize(
            rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis
        )
        lhs_data = lhs_q.data
        rhs_data = rhs_q.data
        lhs_scale_inv = lhs_q.scale_inv
        rhs_scale_inv = rhs_q.scale_inv
        lhs_shape = lhs_q.original_shape
        rhs_shape = rhs_q.original_shape

    assert not (
        lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
    ), "FP8 GEMM does not support E5M2 * E5M2"

    # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs
    # thus additional transpose is required
    if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported():
        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
            lhs_layout_is_T = lhs.data_layout == "T"
            rhs_layout_is_T = rhs.data_layout == "T"
        else:
            lhs_layout_is_T = lhs_q.data_layout == "T"
            rhs_layout_is_T = rhs_q.data_layout == "T"
        # we can't apply _shape_normalization on the grouped input
        # thus we need to ensure that lhs is in N and rhs is in T
        assert (
            lhs_is_trans == lhs_layout_is_T
        ), "lhs input must be transposed before calling grouped_gemm"
        assert not (
            rhs_is_trans == rhs_layout_is_T
        ), "rhs input must be transposed before calling grouped_gemm"
        lhs_is_trans = False
        rhs_is_trans = True
        lhs_ndim = len(lhs_shape)
        rhs_ndim = len(rhs_shape)
        if lhs_layout_is_T:
            lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
        if rhs_layout_is_T:
            # For rhs [G, K, N], need to exclude the G dim from contract_dim
            if group_sizes.size == rhs_shape[0]:
                rhs_contract_dim = tuple(
                    (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim
                )
            else:
                rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)

    # Calling GroupedGEMM Custom Call
    K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim)
    K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim)
    assert K_lhs == K_rhs
    M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim))
    N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:])  # Exclude G

    if is_grouped_dense_wgrad:
        N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim))
    else:
        assert group_sizes.size == rhs_shape[0]

    assert group_offset.size == 1

    has_bias = bias is not None
    assert not has_bias or bias.shape == (group_sizes.size, N)
    bias = jnp.empty((), jnp.float32) if bias is None else bias

    (out,) = GroupedGemmPrimitive.outer_primitive.bind(
        lhs_data,
        lhs_scale_inv,
        rhs_data,
        rhs_scale_inv,
        bias,
        group_sizes,
        group_offset,
        M=M,
        N=N,
        K=K_lhs,
        lhs_is_trans=lhs_is_trans,
        rhs_is_trans=rhs_is_trans,
        scaling_mode=scaling_mode.value,
        out_dtype=out_dtype,
        has_bias=has_bias,
        is_grouped_dense_wgrad=is_grouped_dense_wgrad,
        use_async_d2h_group_sizes=use_async_d2h_group_sizes,
    )
    return out