fbgemm_grouped_gemm.py 45.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
# Copy from https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import functools
import inspect
import sys
import warnings
from typing import Optional

import torch
import triton  # @manual
import triton.language as tl  # @manual
from triton.runtime import driver  # @manual


def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
    """
    Translate a torch dtype into the corresponding Triton dtype.

    Args:
        dtype (torch.dtype): dtype to translate.

    Returns:
        tl.dtype: the matching Triton dtype.

    Raises:
        ValueError: if ``dtype`` has no supported Triton counterpart. On HIP
            builds (``torch.version.hip`` is set) ``torch.float8_e4m3fn`` is
            rejected as well.
    """
    direct_matches = (
        (torch.float16, tl.float16),
        (torch.bfloat16, tl.bfloat16),
        (torch.float32, tl.float32),
        (torch.int32, tl.int32),
    )
    for torch_dtype, triton_dtype in direct_matches:
        if dtype == torch_dtype:
            return triton_dtype
    # float8_e4m3fn is only mapped (to tl.float8e4nv) on non-HIP builds.
    if dtype == torch.float8_e4m3fn and torch.version.hip is None:
        return tl.float8e4nv
    raise ValueError(f"Unsupported dtype {dtype}")


# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

# Announce on stderr, once at import time, which TMA descriptor mode is in use.
print(
    "TMA benchmarks will be running with experimental grid constant TMA descriptor."
    if HAS_TMA_DESC
    else "TMA benchmarks will be running without grid constant TMA descriptor.",
    file=sys.stderr,
)


class TmaAutoTuneHelper:
    """Own the TMA descriptor buffers handed to autotuned Triton kernels.

    The storage strategy is chosen by the module-level ``HAS_TMA_DESC`` flag:

    * ``HAS_TMA_DESC`` true: every descriptor is a ``TMA_SIZE``-element int8
      tensor on the CPU. It is filled in place and passed to the kernel wrapped
      in :class:`KernelParamWrapper`.
    * otherwise: every descriptor is a ``TMA_SIZE``-element int8 tensor on the
      GPU. It is filled through a freshly allocated pinned CPU staging tensor,
      which is then copied over with ``non_blocking=True``.
    """

    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
    class KernelParamWrapper:
        """Wrap a CPU descriptor tensor and expose its host address."""

        def __init__(self, desc):
            self.desc = desc

        def tma_desc_cpu_ptr(self):
            """Return ``data_ptr()`` of the wrapped descriptor tensor."""
            return self.desc.data_ptr()

    # Number of int8 elements reserved for a single TMA descriptor.
    TMA_SIZE = 128

    def __init__(self):
        driver_utils = triton.runtime.driver.active.utils
        self.fill_1d_tma_descriptor_inner = driver_utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver_utils.fill_2d_tma_descriptor
        # Only the store that matches the active strategy is created.
        if not HAS_TMA_DESC:
            self.cuda_descriptors = {}
        else:
            self.descriptors = {}

    def init_tma_descriptor(self, name):
        """Allocate the (uninitialized) descriptor buffer registered as ``name``.

        Call this method outside of the lambda function for grid size.
        """
        if HAS_TMA_DESC:
            store, device = self.descriptors, "cpu"
        else:
            store, device = self.cuda_descriptors, "cuda"
        store[name] = torch.empty(
            TmaAutoTuneHelper.TMA_SIZE, device=device, dtype=torch.int8
        )

    def _fill_descriptor(self, name, fill_inner, *fill_args):
        """Call ``fill_inner(*fill_args, host_address)`` for descriptor ``name``.

        With ``HAS_TMA_DESC`` the CPU descriptor is written directly. Otherwise
        a pinned CPU staging tensor is filled and copied asynchronously into
        the GPU descriptor.
        """
        if HAS_TMA_DESC:
            host_desc = self.descriptors[name]
            assert host_desc.data_ptr() % 64 == 0
            fill_inner(*fill_args, host_desc.data_ptr())
            return
        device_desc = self.cuda_descriptors[name]
        staging = torch.empty_like(device_desc, device="cpu", pin_memory=True)
        fill_inner(*fill_args, staging.data_ptr())
        device_desc.copy_(staging, non_blocking=True)

    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        """Fill the 1-D TMA descriptor ``name``.

        Call this method inside the lambda function for grid size.
        """
        self._fill_descriptor(
            name,
            self.fill_1d_tma_descriptor_inner,
            ptr,
            dim,
            block_dim,
            element_size,
        )

    def fill_2d_tma_descriptor(
        self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
    ):
        """Fill the 2-D TMA descriptor ``name``.

        Call this method inside the lambda function for grid size.
        """
        self._fill_descriptor(
            name,
            self.fill_2d_tma_descriptor_inner,
            ptr,
            dim1,
            dim0,
            block_dim1,
            block_dim0,
            element_size,
        )

    def get_tma_descriptor_kernel_param(self, name):
        """Return the kernel argument for descriptor ``name``.

        This is a :class:`KernelParamWrapper` around the CPU tensor when
        ``HAS_TMA_DESC`` is set, and the GPU tensor itself otherwise.
        """
        if not HAS_TMA_DESC:
            device_desc = self.cuda_descriptors[name]
            assert device_desc is not None
            return device_desc
        host_desc = self.descriptors[name]
        assert host_desc is not None
        return self.KernelParamWrapper(host_desc)


# Autotuning search space for the non-warp-specialized kernels on NVIDIA GPUs.
# Configs are generated in the same order as the nested loops read top-down.
_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": bm,
            "BLOCK_SIZE_N": bn,
            "BLOCK_SIZE_K": bk,
            "NUM_CONSUMER_GROUPS": 1,
        },
        num_stages=stages,
        num_warps=warps,
        num_ctas=1,
    )
    for bm in (64, 128)
    for bn in (64, 128, 256)
    for bk in (64, 128, 256)
    for stages in (3, 4)
    for warps in (4, 8)
]

# Cached result of _check_ws_support(); stays None until _set_ws_support() runs.
_HAS_WS_SUPPORT: Optional[bool] = None


def _check_ws_support():
    """Return True if this Triton build can run the warp-specialized kernel.

    Three things are required: ``tl.async_task`` must exist, ``triton.Config``
    must accept both ``num_consumer_groups`` and ``num_buffers_warp_spec``,
    and grid-constant TMA descriptors must be available (``HAS_TMA_DESC``).
    """
    if not hasattr(tl, "async_task"):
        return False
    config_params = inspect.signature(triton.Config).parameters
    has_ws_config_args = all(
        arg in config_params
        for arg in ("num_consumer_groups", "num_buffers_warp_spec")
    )
    return has_ws_config_args and bool(HAS_TMA_DESC)


def _set_ws_support():
    """Populate the module-level ``_HAS_WS_SUPPORT`` cache on the first call."""
    global _HAS_WS_SUPPORT
    if _HAS_WS_SUPPORT is not None:
        return
    _HAS_WS_SUPPORT = _check_ws_support()


# Resolve warp-specialization support once at import time; the result decides
# which config list _NV_WS_CONFIGS is bound to just below.
_set_ws_support()

if not _HAS_WS_SUPPORT:
    # Without warp-specialization support, the WS kernel reuses the plain configs.
    _NV_WS_CONFIGS = _NV_CONFIGS
else:
    # num_consumer_groups == 0 yields a config that early_config_prune treats as
    # non-warp-specialized; the NUM_CONSUMER_GROUPS kernel constant is clamped
    # to at least 1 either way.
    _NV_WS_CONFIGS = [
        triton.Config(
            {
                "BLOCK_SIZE_M": bm,
                "BLOCK_SIZE_N": bn,
                "BLOCK_SIZE_K": bk,
                "NUM_CONSUMER_GROUPS": max(1, groups),
                "USE_TMA_LOAD_ON_SCALES": tma_scales,
                # TODO(shikaili): Resolve compatibility with ws.
                "USE_TMA_STORE": False,
            },
            num_stages=stages,
            num_warps=warps,
            # TODO(shikaili): Resolve LLVM error.
            num_ctas=1,
            num_consumer_groups=groups,
            num_buffers_warp_spec=stages,
        )
        for bm in (64, 128, 256)
        for bn in (64, 128, 256)
        for bk in (64, 128, 256)
        for stages in (2, 3, 4)
        for warps in (4, 8, 16)
        for groups in (0, 2)
        for tma_scales in (True, False)
    ]


# Autotuning search space for AMD (HIP) GPUs. Each num_warps value is paired
# with a fixed waves_per_eu value.
_AMD_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": bm,
            "BLOCK_SIZE_N": bn,
            "BLOCK_SIZE_K": bk,
            "waves_per_eu": waves,
            "matrix_instr_nonkdim": 16,
            "NUM_CONSUMER_GROUPS": 1,
        },
        num_stages=stages,
        num_warps=warps,
    )
    for bm in (32, 64, 128)
    for bn in (32, 64, 128, 256)
    for bk in (128, 256)
    for stages in (1, 2)
    for warps, waves in ((4, 1), (8, 2), (16, 4))
]


def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Drop autotuning configs that cannot fit or are a poor match for the problem.

    Used as the ``early_config_prune`` hook of ``triton.autotune``.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel arguments by name. Reads ``c_ptr``, ``G``,
            ``M_BUCKET`` (used as M), ``N`` and ``K``.
        dtsize: operand element size in bytes. Defaults to
            ``named_args["c_ptr"].element_size()``.
        dtype: operand dtype. Defaults to ``named_args["c_ptr"].dtype``; no
            pruning rule below currently reads it.
        **kwargs: ignored; accepted so the autotuner can pass extra arguments.

    Returns:
        list: the configs that pass every rule, in their original order.
    """
    device = torch.cuda.current_device()
    if dtsize is None:
        dtsize = named_args["c_ptr"].element_size()
    if dtype is None:
        dtype = named_args["c_ptr"].dtype

    G, M, N, K = (
        named_args["G"],
        named_args["M_BUCKET"],
        named_args["N"],
        named_args["K"],
    )

    # Device properties and problem-derived limits do not depend on the config,
    # so compute them once instead of once (or twice) per candidate.
    device_properties = driver.active.utils.get_device_properties(device)
    max_shared_memory = device_properties["max_shared_mem"]
    num_sm = device_properties["multiprocessor_count"]
    M_PER_GROUP = M // G
    MIN_M_TILES = 32 if torch.version.hip else 64
    MIN_N_TILES = 32 if torch.version.hip else 64

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M = kw["BLOCK_SIZE_M"]
        BLOCK_N = kw["BLOCK_SIZE_N"]
        BLOCK_K = kw["BLOCK_SIZE_K"]
        num_stages = config.num_stages
        num_warps = config.num_warps
        # `num_consumer_groups` is only an attribute of triton.Config on Triton
        # builds with warp-specialization support (see _check_ws_support).
        # Treat a missing attribute as 0 (no warp specialization) instead of
        # raising AttributeError while pruning _NV_CONFIGS / _AMD_CONFIGS.
        num_consumer_groups = getattr(config, "num_consumer_groups", 0)
        use_tma_load_on_scales = kw.get("USE_TMA_LOAD_ON_SCALES", False)

        # 1. make sure we have enough smem
        if torch.version.hip:
            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
        else:
            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        use_warp_specialization = num_consumer_groups >= 1

        # 2. make sure we don't load M tiles that are too big
        if (
            not use_warp_specialization
            and BLOCK_M > MIN_M_TILES
            and BLOCK_M > (M_PER_GROUP * 2)
        ):
            continue
        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        N_TILES = N // BLOCK_N
        # 4. make sure we don't load N tiles that are too big
        if (
            not use_warp_specialization
            and BLOCK_N > MIN_N_TILES
            and M * N_TILES < num_sm
        ):
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue

        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        # 7. make sure we can partition for ws
        if use_warp_specialization:
            if num_warps != 4:
                continue

            # "tritongpu-warp-spec-data-partition"
            m_slice = BLOCK_M // num_consumer_groups
            n_slice = BLOCK_N // num_consumer_groups
            if m_slice < 64 and n_slice < 256:
                continue

        # 8. TMA loads on scales are only kept for 1-byte operands.
        if dtsize >= 2 and use_tma_load_on_scales:
            continue
        pruned_configs.append(config)

    return pruned_configs


@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    scatter_add_indices,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
    """Grouped GEMM: for every group g, C_g = A_g @ B_g^T.

    Groups are stacked along rows. Group g owns ``m_sizes[g]`` consecutive
    rows of A and C, starting right after the rows of the preceding groups.
    It also owns rows ``[g * N, (g + 1) * N)`` of B. The pointer arithmetic of
    the non-TMA path reads A as a row-major ``(sum(m_sizes), K)`` matrix and B
    as a row-major ``(G * N, K)`` matrix, and writes C as
    ``(sum(m_sizes), N)``.

    Output tiles of all groups form one flat index space; within a group the
    M-tile index varies fastest. Program ``tidx`` handles tiles
    ``tidx, tidx + NUM_SMS, tidx + 2 * NUM_SMS, ...`` (presumably the launch
    grid has NUM_SMS programs — confirm at the call site).

    Args:
        a_desc_ptr, b_desc_ptr: TMA descriptors for A and B when USE_TMA_LOAD,
            otherwise plain pointers to A and B.
        c_ptr: output. With FUSE_SCATTER_ADD, output row r (global row index)
            is atomically added into row ``scatter_add_indices[r]`` of c_ptr
            instead of being stored at row r.
        workspace: scratch from which each program takes a TMA_SIZE-element
            slot for a device-built C descriptor (only used with USE_TMA_STORE).
        m_sizes: per-group row counts; groups with ``m_sizes[g] <= 0`` are
            skipped and contribute no rows.
        M_BUCKET: not read by the kernel body; used as an autotuning key and
            by early_config_prune.
        USE_FAST_ACCUM: accumulate inside ``tl.dot`` instead of adding its
            result (TMA-load path only; the pointer path always adds).
    """
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    dtype: tl.dtype = c_ptr.dtype.element_ty
    # Per-program descriptor slot size; assumes workspace holds at least
    # NUM_SMS * 128 bytes-sized elements — TODO confirm against the host code.
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    # Loop-carried row offset and tile counter, kept in int64 (presumably to
    # avoid int32 overflow on large problems).
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)
    iterated_tiles = 0
    iterated_tiles = iterated_tiles.to(tl.int64)
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g)

        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            # First row of this group's B slice.
            N_start_offset = g.to(tl.int64) * N
            n_size = N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # Rebuild this program's C descriptor for the current group,
                # bounded to the group's m_size x n_size slice of C.
                # pyre-ignore
                tl.extra.cuda.experimental_device_tensormap_create2d(
                    desc_ptr=c_desc_ptr,
                    global_address=c_ptr + M_start_offset * N,
                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                    global_size=[m_size, n_size],
                    element_ty=c_ptr.dtype.element_ty,
                )
                # pyre-ignore
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            # Process every tile of this group assigned to this program; tidx
            # advances by NUM_SMS until it falls past this group's tiles.
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                # K is a multiple of BLOCK_SIZE_K, so no K-masking is needed below.
                tl.static_assert(K % BLOCK_SIZE_K == 0)
                if USE_TMA_LOAD:
                    # Loads are not masked by m_size here: rows past the group end
                    # may be read, and their results are dropped at store time.
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )
                        if USE_FAST_ACCUM:
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)
                else:
                    # Plain pointer path: rows of A/B beyond the group/N bounds are
                    # masked out on load.
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                if USE_TMA_STORE:
                    # Offsets are relative to the per-group descriptor built above.
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_ptr.dtype.element_ty),
                        [m_offset, n_offset],
                    )
                elif FUSE_SCATTER_ADD:
                    # Destination rows come from scatter_add_indices; atomic adds
                    # because several source rows may target the same output row.
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    mask = offs_am < m_size
                    m_offsets = tl.load(
                        scatter_add_indices + M_start_offset + offs_am,
                        mask=mask,
                        cache_modifier=".ca",
                    )
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    c = accumulator.to(c_ptr.dtype.element_ty)
                    tl.atomic_add(
                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                        c,
                        mask=mask[:, None] and offs_bn[None, :] < n_size,
                        sem="relaxed",
                    )
                else:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    c = accumulator.to(c_ptr.dtype.element_ty)
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None]) * N
                        + offs_bn[None, :],
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )
                tidx += NUM_SMS

            iterated_tiles += num_tiles


# TODO(shikaili): Too much code duplication. Need to refactor.
@triton.autotune(
    configs=_NV_WS_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_ws(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    scatter_add_indices,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
    USE_TMA_LOAD_ON_SCALES: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
) -> None:
    """Warp-specialized variant of the grouped GEMM: C_g = A_g @ B_g^T per group.

    Layout and tile scheduling match the non-warp-specialized grouped GEMM.
    Group g owns ``m_sizes[g]`` consecutive rows of A/C and rows
    ``[g * N, (g + 1) * N)`` of B. Program ``tidx`` handles flat tile
    indices ``tidx, tidx + NUM_SMS, ...`` across all groups.

    Differences from the non-WS kernel, all enforced in this body:
      * TMA loads are mandatory (USE_TMA_LOAD) and USE_TMA_LOAD_ON_SCALES must
        be False.
      * N must be a multiple of BLOCK_SIZE_N, so stores need no column mask.
      * Work is split with ``tl.async_task``. Task 0 issues the descriptor
        loads (and builds the C descriptor). Task ids
        ``[1, NUM_CONSUMER_GROUPS]`` run the dot products and the epilogue.
    """
    tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialization!")
    tl.static_assert(not USE_TMA_LOAD_ON_SCALES, "Not supported!")
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    dtype: tl.dtype = c_ptr.dtype.element_ty
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)
    iterated_tiles = 0
    iterated_tiles = iterated_tiles.to(tl.int64)
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g, cache_modifier=".ca")

        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            N_start_offset = g.to(tl.int64) * N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            tl.static_assert(N % BLOCK_SIZE_N == 0)
            NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
            num_tiles = num_m_tiles * NUM_N_TILES

            if USE_TMA_STORE:
                with tl.async_task([0]):
                    # pyre-ignore
                    tl.extra.cuda.experimental_device_tensormap_create2d(
                        desc_ptr=c_desc_ptr,
                        global_address=c_ptr + M_start_offset * N,
                        load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                        global_size=[m_size, N],
                        element_ty=c_ptr.dtype.element_ty,
                    )
                    # pyre-ignore
                    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            next_iterated_tiles = iterated_tiles + num_tiles
            if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
                # `i` walks this program's tiles in the group; tidx is advanced in
                # lock-step so that afterwards it points past this group's tiles.
                for i in range(tidx, next_iterated_tiles, NUM_SMS):
                    gidx = i - iterated_tiles
                    # Split M first and N second.
                    tile_m_idx = gidx % num_m_tiles
                    tile_n_idx = gidx // num_m_tiles

                    accumulator = tl.zeros(
                        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
                    )
                    tl.static_assert(K % BLOCK_SIZE_K == 0)

                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        with tl.async_task([0]):
                            a = tl._experimental_descriptor_load(
                                a_desc_ptr,
                                [m_offset, k_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                                dtype,
                            )
                            b = tl._experimental_descriptor_load(
                                b_desc_ptr,
                                [n_offset, k_offset],
                                [BLOCK_SIZE_N, BLOCK_SIZE_K],
                                dtype,
                            )
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            if USE_FAST_ACCUM:
                                accumulator = tl.dot(a, b.T, accumulator)
                            else:
                                accumulator += tl.dot(a, b.T)

                    if USE_TMA_STORE:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                            n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                            tl._experimental_descriptor_store(
                                c_desc_ptr,
                                accumulator.to(c_ptr.dtype.element_ty),
                                [m_offset, n_offset],
                            )
                    elif FUSE_SCATTER_ADD:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            mask = offs_am < m_size
                            m_offsets = tl.load(
                                scatter_add_indices + M_start_offset + offs_am,
                                mask=mask,
                                cache_modifier=".ca",
                            )
                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
                                0, BLOCK_SIZE_N
                            )
                            c = accumulator.to(c_ptr.dtype.element_ty)
                            # Row mask only: N % BLOCK_SIZE_N == 0 is asserted above.
                            tl.atomic_add(
                                c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                                c,
                                mask=mask[:, None],
                                sem="relaxed",
                            )
                    else:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
                                0, BLOCK_SIZE_N
                            )
                            c = accumulator.to(c_ptr.dtype.element_ty)
                            # Row mask only: N % BLOCK_SIZE_N == 0 is asserted above.
                            tl.store(
                                c_ptr
                                + (M_start_offset + offs_am[:, None]) * N
                                + offs_bn[None, :],
                                c,
                                mask=offs_am[:, None] < m_size,
                                cache_modifier=".cs",
                            )
                    tidx += NUM_SMS

            iterated_tiles += num_tiles


# FP8 element type used by the FP8 kernels: tl.float8e4nv on non-HIP builds,
# tl.float8e4b8 when torch.version.hip is set.
TT_FP8_DTYPE = tl.float8e4nv if not torch.version.hip else tl.float8e4b8


# TODO(shikaili): clean up redundant 'b_scale_desc_ptr' argument.
@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
    # Scatter-add accumulates into c_ptr, so the autotuner must restore its
    # contents after benchmarking each config.
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_fp8_rowwise(
    a_desc_ptr,  # A: TMA descriptor if USE_TMA_LOAD, else a raw pointer
    a_scale_ptr,  # one scale per row of A
    b_desc_ptr,  # B: TMA descriptor if USE_TMA_LOAD, else a raw pointer
    b_scale_ptr,  # one scale per row of B (= per output column of a group)
    b_scale_desc_ptr,  # not referenced by this kernel
    c_ptr,
    workspace,  # per-program tensor-map slots; only used if USE_TMA_STORE
    scatter_add_indices,  # destination C row per A row (FUSE_SCATTER_ADD)
    m_sizes,  # number of A rows in each of the G groups
    # problem sizes
    G: tl.constexpr,
    M_BUCKET,  # not referenced in the body; used as an autotune key
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # Not referenced in the body. NOTE(review): presumably kept so configs and
    # the host-side grid can treat WS and non-WS kernels uniformly — confirm.
    NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
    """Persistent grouped GEMM on FP8 inputs with rowwise scales.

    For every group ``g`` with ``m_size = m_sizes[g] > 0`` this computes

        C[rows_g] = (A[rows_g] @ B[g*N:(g+1)*N].T)
                    * a_scale[rows_g, None] * b_scale[None, g*N:(g+1)*N]

    where ``rows_g`` are the ``m_size`` rows that immediately follow the rows
    of the preceding groups (groups with ``m_size <= 0`` contribute no rows).
    Accumulation and scaling happen in float32.

    Work distribution: the tiles of all groups form one flat index space
    (groups in order, M-fastest within a group). Program ``program_id(0)``
    handles tiles ``program_id(0), program_id(0) + NUM_SMS, ...``.

    Output modes (FUSE_SCATTER_ADD and USE_TMA_STORE are mutually exclusive):
      * USE_TMA_STORE: store through a per-program tensor map, written into
        ``workspace`` and re-created for every group.
      * FUSE_SCATTER_ADD: atomically add row ``i`` of the group into C row
        ``scatter_add_indices[M_start_offset + i]``.
      * otherwise: masked plain store into C.

    ``USE_FAST_ACCUM`` is only consulted on the TMA-load path.
    """
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    dtype = TT_FP8_DTYPE
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        # Each program owns one TMA_SIZE-sized slot of `workspace` for its C
        # tensor map.
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    # Running row offset into A/C and running count of tiles in earlier
    # groups. Promoted to int64, presumably so products such as
    # `M_start_offset * N` cannot overflow int32.
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)
    iterated_tiles = 0
    iterated_tiles = iterated_tiles.to(tl.int64)
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g)

        # Groups with m_size <= 0 are skipped and do not advance the offsets.
        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            # Group g uses rows [g * N, (g + 1) * N) of B.
            N_start_offset = g.to(tl.int64) * N
            n_size = N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # Re-point this program's C tensor map at the current group:
                # base row M_start_offset, global extent m_size x n_size.
                # pyre-ignore
                tl.extra.cuda.experimental_device_tensormap_create2d(
                    desc_ptr=c_desc_ptr,
                    global_address=c_ptr + M_start_offset * N,
                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                    global_size=[m_size, n_size],
                    element_ty=c_ptr.dtype.element_ty,
                )
                # pyre-ignore
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            # Handle this program's tiles that fall inside this group; tidx
            # keeps its value across groups.
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                # The K loops below have no tail handling.
                tl.static_assert(K % BLOCK_SIZE_K == 0)
                if USE_TMA_LOAD:
                    # Global row offsets into A and B. The last M tile of a
                    # group may read rows beyond the group's end; those rows
                    # are discarded by the output stage below.
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )
                        if USE_FAST_ACCUM:
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)
                else:
                    # Pointer path: A and B are addressed as row-major with row
                    # stride K. Masked-off elements (no `other=`) only reach
                    # output rows/columns that the epilogue masks out.
                    # USE_FAST_ACCUM is not consulted on this path.
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                # Epilogue: apply per-row A scales and per-column B scales in
                # float32.
                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                a_scale = tl.load(
                    a_scale_ptr + M_start_offset + offs_am[:, None],
                    mask=offs_am[:, None] < m_size,
                )
                b_scale = tl.load(
                    b_scale_ptr + N_start_offset + offs_bn[None, :],
                    mask=offs_bn[None, :] < n_size,
                )
                c = accumulator.to(tl.float32) * a_scale * b_scale

                if USE_TMA_STORE:
                    # Offsets are relative to the group: the C tensor map's
                    # base is c_ptr + M_start_offset * N.
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        c.to(c_ptr.dtype.element_ty),
                        [m_offset, n_offset],
                    )
                elif FUSE_SCATTER_ADD:
                    # Row i of this group is accumulated into C row
                    # scatter_add_indices[M_start_offset + i].
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    mask = offs_am < m_size
                    m_offsets = tl.load(
                        scatter_add_indices + M_start_offset + offs_am,
                        mask=mask,
                        cache_modifier=".ca",
                    )
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    tl.atomic_add(
                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                        c.to(c_ptr.dtype.element_ty),
                        mask=mask[:, None] and offs_bn[None, :] < n_size,
                        sem="relaxed",
                    )
                else:
                    # `c` is float32 here; tl.store converts it to c_ptr's
                    # element type.
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None]) * N
                        + offs_bn[None, :],
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )
                # Next tile owned by this program.
                tidx += NUM_SMS

            iterated_tiles += num_tiles


# TODO(shikaili): Too much code duplication. Need to refactor.
@triton.autotune(
    configs=_NV_WS_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
    # Scatter-add accumulates into c_ptr, so the autotuner must restore its
    # contents after benchmarking each config.
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_fp8_rowwise_ws(
    a_desc_ptr,  # TMA descriptor for A (TMA loads are mandatory here)
    a_scale_ptr,  # one scale per row of A
    b_desc_ptr,  # TMA descriptor for B
    b_scale_ptr,  # one scale per row of B; used unless USE_TMA_LOAD_ON_SCALES
    b_scale_desc_ptr,  # 1-D TMA descriptor for B scales (USE_TMA_LOAD_ON_SCALES)
    c_ptr,
    workspace,  # per-program tensor-map slots; only used if USE_TMA_STORE
    scatter_add_indices,  # destination C row per A row (FUSE_SCATTER_ADD)
    m_sizes,  # number of A rows in each of the G groups
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # not referenced in the body; autotune key
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
    USE_TMA_LOAD_ON_SCALES: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
) -> None:
    """Warp-specialized variant of ``_fbgemm_grouped_gemm_fp8_rowwise``.

    Computes the same per-group product

        C[rows_g] = (A[rows_g] @ B[g*N:(g+1)*N].T)
                    * a_scale[rows_g, None] * b_scale[None, g*N:(g+1)*N]

    but partitions the work between async tasks. Statements under
    ``tl.async_task([0])`` issue the TMA loads and create the C tensor map.
    Statements under ``tl.async_task([1, NUM_CONSUMER_GROUPS])`` run the MMAs
    and the epilogue. NOTE(review): the exact task / warp-group mapping is
    defined by the warp-specialization Triton fork; confirm there.

    Differences from the non-WS kernel visible in this body:
      * TMA loads are mandatory (``USE_TMA_LOAD`` is static_asserted).
      * ``N`` must be a multiple of ``BLOCK_SIZE_N`` (static_asserted), so
        nothing is masked along N.
      * With ``USE_TMA_LOAD_ON_SCALES`` the B scales are loaded through
        ``b_scale_desc_ptr`` instead of ``b_scale_ptr``.
      * Each group's tiles are walked with a ``for`` loop, and ``tidx`` is
        advanced alongside it so it carries over to the next group.
    """
    tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialziation!")
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    dtype = TT_FP8_DTYPE
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        # Each program owns one TMA_SIZE-sized slot of `workspace` for its C
        # tensor map.
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    # Running row offset into A/C and running count of tiles in earlier
    # groups. Promoted to int64, presumably to keep address arithmetic from
    # overflowing int32.
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)
    iterated_tiles = 0
    iterated_tiles = iterated_tiles.to(tl.int64)
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g, cache_modifier=".ca")

        # Groups with m_size <= 0 are skipped and do not advance the offsets.
        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            # Group g uses rows [g * N, (g + 1) * N) of B.
            N_start_offset = g.to(tl.int64) * N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            # Exact division: there are no partial N tiles in this kernel.
            tl.static_assert(N % BLOCK_SIZE_N == 0)
            NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
            num_tiles = num_m_tiles * NUM_N_TILES

            if USE_TMA_STORE:
                # The producer task re-points this program's C tensor map at
                # the current group (base row M_start_offset, m_size x N).
                with tl.async_task([0]):
                    # pyre-ignore
                    tl.extra.cuda.experimental_device_tensormap_create2d(
                        desc_ptr=c_desc_ptr,
                        global_address=c_ptr + M_start_offset * N,
                        load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                        global_size=[m_size, N],
                        element_ty=c_ptr.dtype.element_ty,
                    )
                    # pyre-ignore
                    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            # This program owns tiles tidx, tidx + NUM_SMS, ... of the flat tile
            # space; visit the ones that fall inside this group.
            next_iterated_tiles = iterated_tiles + num_tiles
            if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
                for i in range(tidx, next_iterated_tiles, NUM_SMS):
                    gidx = i - iterated_tiles
                    # Split M first and N second.
                    tile_m_idx = gidx % num_m_tiles
                    tile_n_idx = gidx // num_m_tiles

                    accumulator = tl.zeros(
                        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
                    )
                    # The K loop below has no tail handling.
                    tl.static_assert(K % BLOCK_SIZE_K == 0)

                    # Global row offsets into A and B. The last M tile of a
                    # group may read rows beyond the group's end; those rows are
                    # discarded by the output stage.
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        # Producer: TMA loads of the A and B tiles.
                        with tl.async_task([0]):
                            a = tl._experimental_descriptor_load(
                                a_desc_ptr,
                                [m_offset, k_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                                dtype,
                            )
                            b = tl._experimental_descriptor_load(
                                b_desc_ptr,
                                [n_offset, k_offset],
                                [BLOCK_SIZE_N, BLOCK_SIZE_K],
                                dtype,
                            )
                        # Consumers: MMA into the float32 accumulator.
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            if USE_FAST_ACCUM:
                                accumulator = tl.dot(a, b.T, accumulator)
                            else:
                                accumulator += tl.dot(a, b.T)

                    if USE_TMA_LOAD_ON_SCALES:
                        # B scales arrive as a 1-D [BLOCK_SIZE_N] TMA tile; it is
                        # broadcast across rows with b_scale[None, :] below.
                        with tl.async_task([0]):
                            b_scale = tl._experimental_descriptor_load(
                                b_scale_desc_ptr,
                                [n_offset],
                                [BLOCK_SIZE_N],
                                tl.float32,
                            )

                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            a_scale = tl.load(
                                a_scale_ptr + M_start_offset + offs_am[:, None],
                                mask=offs_am[:, None] < m_size,
                                cache_modifier=".ca",
                            )
                            c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
                    else:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
                                0, BLOCK_SIZE_N
                            )
                            a_scale = tl.load(
                                a_scale_ptr + M_start_offset + offs_am[:, None],
                                mask=offs_am[:, None] < m_size,
                                cache_modifier=".ca",
                            )
                            # Unmasked: N % BLOCK_SIZE_N == 0 keeps offs_bn in
                            # range.
                            b_scale = tl.load(
                                b_scale_ptr + N_start_offset + offs_bn[None, :],
                                cache_modifier=".ca",
                            )
                            c = accumulator.to(tl.float32) * a_scale * b_scale

                    if USE_TMA_STORE:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            # Offsets are relative to the group: the C tensor
                            # map's base is c_ptr + M_start_offset * N.
                            m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                            n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                            tl._experimental_descriptor_store(
                                c_desc_ptr,
                                c.to(c_ptr.dtype.element_ty),
                                [m_offset, n_offset],
                            )
                    elif FUSE_SCATTER_ADD:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            # Row i of this group is accumulated into C row
                            # scatter_add_indices[M_start_offset + i].
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            mask = offs_am < m_size
                            m_offsets = tl.load(
                                scatter_add_indices + M_start_offset + offs_am,
                                mask=mask,
                                cache_modifier=".ca",
                            )
                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
                                0, BLOCK_SIZE_N
                            )
                            # NOTE(review): unlike the non-WS kernel, `c`
                            # (float32) is passed without an explicit
                            # `.to(c_ptr.dtype.element_ty)`; this relies on
                            # tl.atomic_add converting to the pointer's element
                            # type. Confirm this is intended.
                            tl.atomic_add(
                                c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                                c,
                                mask=mask[:, None],
                                sem="relaxed",
                            )
                    else:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
                                0, BLOCK_SIZE_N
                            )
                            tl.store(
                                c_ptr
                                + (M_start_offset + offs_am[:, None]) * N
                                + offs_bn[None, :],
                                c,
                                mask=offs_am[:, None] < m_size,
                                cache_modifier=".cs",
                            )
                    # Keep tidx in step with i: after the loop it is this
                    # program's first tile index >= next_iterated_tiles, which
                    # the following groups test against.
                    tidx += NUM_SMS

            iterated_tiles += num_tiles


# Report each distinct warning raised from this module (e.g. the "... is
# disabled" fallbacks in the launcher) only once. The filter is scoped to this
# module, so importing it does not change how other libraries' warnings are
# reported. It also does not override user-supplied filters such as
# ``-W error`` for those warnings, as a global ``simplefilter`` would.
warnings.filterwarnings("once", module=__name__)


def _grouped_gemm(
    *,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: Optional[torch.Tensor],
    w_scale: Optional[torch.Tensor],
    use_fast_accum: bool,
    use_warp_specialization: bool,
    output_tensor: Optional[torch.Tensor],
    scatter_add_indices: Optional[torch.Tensor],
) -> torch.Tensor:
    """Validate inputs, pick a grouped-GEMM kernel and launch it.

    Rows of ``x`` (shape ``(M, K)``) are split into ``G = m_sizes.shape[0]``
    consecutive groups. ``w`` (shape ``(G * N, K)``) stacks one ``(N, K)``
    weight matrix per group, and group ``g`` computes ``x_g @ w_g.T``.

    Args:
        x: Contiguous ``(M, K)`` input.
        w: Contiguous ``(G * N, K)`` stacked weights.
        m_sizes: Contiguous ``(G,)`` tensor of per-group row counts.
        x_scale: Rowwise scales for ``x`` (FP8 path). Pass it together with
            ``w_scale``, or leave both ``None``.
        w_scale: Rowwise scales for ``w`` (FP8 path).
        use_fast_accum: Forwarded to the kernel as ``USE_FAST_ACCUM``.
        use_warp_specialization: Request the warp-specialized kernels. The
            request is dropped, with a warning, on ROCm, when the installed
            Triton lacks WS support, or when TMA descriptors are unavailable.
        output_tensor: If given, results are atomically added into rows of
            this tensor (selected by ``scatter_add_indices``), and this tensor
            is returned instead of a fresh output.
        scatter_add_indices: Contiguous ``(M,)`` destination row for each row
            of ``x``. Required exactly when ``output_tensor`` is given.

    Returns:
        A new ``(M, N)`` bfloat16 tensor, or ``output_tensor``.

    Raises:
        AssertionError: On layout or shape mismatches: non-contiguous inputs,
            empty ``m_sizes``, ``w.shape[0]`` not divisible by ``G``, a ``K``
            mismatch, or inconsistent scale / scatter-add arguments.
    """
    USE_TMA_LOAD = not torch.version.hip
    USE_TMA_STORE = False

    if USE_TMA_LOAD and not HAS_TMA_DESC:
        USE_TMA_LOAD = False
        warnings.warn("TMA load is disabled as there is no TMA descriptor support!")

    # TODO(shikaili): Check the readiness of WS on ROCm side in Meta's Triton.
    if use_warp_specialization and torch.version.hip:
        warnings.warn("Warp specialization is disabled as it is not supported on ROCm.")
        use_warp_specialization = False

    if use_warp_specialization and not _HAS_WS_SUPPORT:
        warnings.warn(
            "Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs."
        )
        use_warp_specialization = False

    if use_warp_specialization and not USE_TMA_LOAD:
        # The WS kernels require TMA loads (the FP8 WS kernel static_asserts
        # USE_TMA_LOAD), so fall back to the non-WS kernels instead of failing.
        warnings.warn(
            "Warp specialization is disabled as there is no TMA descriptor support!"
        )
        use_warp_specialization = False

    if use_warp_specialization:
        # Reached only when USE_TMA_LOAD is True, which implies HAS_TMA_DESC,
        # so TMA store is available here as well.
        USE_TMA_STORE = True  # Tuning decision

    G = m_sizes.shape[0]
    assert G > 0, "m_sizes must describe at least one group"

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M, K = x.shape
    assert (
        w.shape[0] % G == 0
    ), f"w has {w.shape[0]} rows, which is not divisible by the number of groups G={G}"
    N = w.shape[0] // G
    assert K == w.shape[1]

    if output_tensor is None:
        FUSE_SCATTER_ADD = False
        assert scatter_add_indices is None
        y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
    else:
        FUSE_SCATTER_ADD = True
        assert scatter_add_indices is not None
        assert scatter_add_indices.is_contiguous()
        assert scatter_add_indices.shape == (M,)
        y = output_tensor
    if M == 0 or N == 0:
        return y

    # Query the device the inputs live on, not whichever CUDA device is current.
    NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count

    desc_helper = None
    desc_x = x
    desc_w = w
    desc_ws = w_scale
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = TmaAutoTuneHelper()
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
        if use_warp_specialization and w_scale is not None:
            desc_helper.init_tma_descriptor("ws")
            desc_ws = desc_helper.get_tma_descriptor_kernel_param("ws")

    if USE_TMA_STORE:
        # One tensor-map slot per program; the kernels index it by program id.
        workspace = torch.empty(
            NUM_SMS * TmaAutoTuneHelper.TMA_SIZE,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        # The TMA descriptors' block shapes depend on the config being
        # launched, so they are (re)filled here for every config.
        if USE_TMA_LOAD:
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M,
                K,
                # NOTE(review): the A box height is divided across consumer
                # groups; confirm against the WS kernels.
                META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N * G,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )

            if META.get("USE_TMA_LOAD_ON_SCALES", False):
                desc_helper.fill_1d_tma_descriptor(
                    "ws",
                    w_scale.data_ptr(),
                    N * G,
                    META["BLOCK_SIZE_N"],
                    w_scale.element_size(),
                )

        # Persistent launch: one program per SM.
        return (NUM_SMS,)

    M_BUCKET_CAP = 16384
    M_BUCKET = min(triton.next_power_of_2(M), M_BUCKET_CAP)
    if x_scale is not None and w_scale is not None:
        assert x_scale.is_contiguous()
        assert w_scale.is_contiguous()
        fn = (
            _fbgemm_grouped_gemm_fp8_rowwise_ws
            if use_warp_specialization
            else _fbgemm_grouped_gemm_fp8_rowwise
        )
        args = (
            desc_x,
            x_scale,
            desc_w,
            w_scale,
            desc_ws,
            y,
            workspace,
            scatter_add_indices,
            m_sizes,
        )
    else:
        assert x_scale is None
        assert w_scale is None
        fn = (
            _fbgemm_grouped_gemm_ws if use_warp_specialization else _fbgemm_grouped_gemm
        )
        args = (
            desc_x,
            desc_w,
            y,
            workspace,
            scatter_add_indices,
            m_sizes,
        )

    # Positional arguments shared by every kernel variant.
    args += (G, M_BUCKET, N, K, NUM_SMS, FUSE_SCATTER_ADD, USE_TMA_LOAD)
    # The WS kernels declare USE_TMA_STORE after the autotuned tile parameters,
    # so it is not passed positionally to them (presumably it comes from the WS
    # autotune configs).
    if use_warp_specialization:
        args += (use_fast_accum,)
    else:
        args += (USE_TMA_STORE, use_fast_accum)
    fn[grid](*args)

    return y


def grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_fast_accum: bool = True,
    *,
    _use_warp_specialization: bool = True,
    _output_tensor: Optional[torch.Tensor] = None,
    _scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Grouped matrix multiply without input scaling.

    The rows of ``x`` (shape ``(M, K)``) are split into consecutive groups
    whose sizes are given by ``m_sizes``. ``w`` stacks one ``(N, K)`` weight
    matrix per group, and each group's rows are multiplied by the transpose
    of that group's weights.

    Args:
        x: Contiguous ``(M, K)`` input.
        w: Contiguous ``(G * N, K)`` stacked weights.
        m_sizes: Contiguous ``(G,)`` per-group row counts.
        use_fast_accum: Forwarded to the kernel's fast-accumulation switch.
        _use_warp_specialization: Prefer the warp-specialized kernel. It falls
            back with a warning when unavailable.
        _output_tensor: Optional tensor to scatter-add the result into.
        _scatter_add_indices: Destination row per row of ``x``. Required
            together with ``_output_tensor``.

    Returns:
        A new ``(M, N)`` bfloat16 tensor, or ``_output_tensor``.
    """
    # No rowwise scales: this dispatches to the unscaled kernels.
    return _grouped_gemm(
        x=x,
        w=w,
        m_sizes=m_sizes,
        output_tensor=_output_tensor,
        scatter_add_indices=_scatter_add_indices,
        use_warp_specialization=_use_warp_specialization,
        use_fast_accum=use_fast_accum,
        x_scale=None,
        w_scale=None,
    )


def grouped_gemm_fp8_rowwise(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: torch.Tensor,
    w_scale: torch.Tensor,
    use_fast_accum: bool = True,
    *,
    _use_warp_specialization: bool = True,
    _output_tensor: Optional[torch.Tensor] = None,
    _scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Grouped FP8 matrix multiply with rowwise scales.

    Group ``g`` covers the next ``m_sizes[g]`` rows of ``x`` and computes
    ``(x_g @ w_g.T) * x_scale_g[:, None] * w_scale_g[None, :]``, where ``w_g``
    is the ``g``-th ``(N, K)`` slice of the stacked ``(G * N, K)`` matrix
    ``w``.

    Args:
        x: Contiguous ``(M, K)`` input.
        w: Contiguous ``(G * N, K)`` stacked weights.
        m_sizes: Contiguous ``(G,)`` per-group row counts.
        x_scale: Contiguous per-row scales of ``x``.
        w_scale: Contiguous per-row scales of ``w``.
        use_fast_accum: Forwarded to the kernel's fast-accumulation switch.
        _use_warp_specialization: Prefer the warp-specialized kernel. It falls
            back with a warning when unavailable.
        _output_tensor: Optional tensor to scatter-add the result into.
        _scatter_add_indices: Destination row per row of ``x``. Required
            together with ``_output_tensor``.

    Returns:
        A new ``(M, N)`` bfloat16 tensor, or ``_output_tensor``.
    """
    # Supplying both scales selects the FP8 rowwise kernels.
    return _grouped_gemm(
        x=x,
        w=w,
        m_sizes=m_sizes,
        x_scale=x_scale,
        w_scale=w_scale,
        output_tensor=_output_tensor,
        scatter_add_indices=_scatter_add_indices,
        use_warp_specialization=_use_warp_specialization,
        use_fast_accum=use_fast_accum,
    )