machete_mainloop.cuh 63.7 KB
Newer Older
raojy's avatar
raojy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
//
// Based off of:
//   cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
// Specifically:
//   https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
// Referred to as "upstream" in the comments.
//
// The main optimization machete implements compared to upstream is to prepack
// the weight matrix to more closely match the shape of the wgmma instructions
// allowing for wider (ideally 128bit) shared memory loads. For subbyte types
// this is done by packing values from multiple wgmma loads (for a single
// thread) into a single 128bit load. This is very similar to the layout used
// in Marlin, although specific to the wgmma instructions.
//
// Since the wgmma instructions only support sourcing from registers for the A
// operand, and we want to upconvert/decompress the weight values/elements
// before feeding them into the tensor cores in registers, we need the weight
// matrix to be A. To achieve this we compute the transpose of Y = XW^t as
// Y^t = W^tX^t. This is mostly done outside of this file in
// csrc/quantization/machete/machete_mm_kernel.cuh, but this is why A is the
// quantized/narrow type and has the prepacked layout despite the API being:
//   B_prepacked = machete_prepack_B(B)
//   Y = machete_mm(A, B_prepacked)
//
#pragma once

// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/detail/dependent_false.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/detail/layout.hpp"

#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"
#include "cutlass/trace.h"

#include "cutlass/detail/collective.hpp"
// clang-format on

#include "cutlass_extensions/cute_utils.cuh"

namespace machete {

using namespace cute;
using namespace cutlass;
using namespace cutlass::gemm;
using namespace cutlass::gemm::collective;
using namespace cutlass::gemm::collective::detail;

template <class ElementATuple_, class GmemLayoutA, int AlignmentA,
          class ElementB_, class GmemLayoutB, int AlignmentB,
          class ElementAccumulator_, class TileShape_MNK,
          class ClusterShape_MNK, class StageCountType,
          class KernelScheduleType>
struct MacheteCollectiveMma {
  using Schedule = KernelScheduleType;
  // Only the SM90 TMA warp-specialized schedules are supported. Each schedule
  // is listed exactly once; the cooperative schedule additionally changes the
  // MMA atom layout (see `AtomLayoutMNK`).
  static_assert(
      cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
          cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
          cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
      "KernelSchedule must be one of the warp specialized policies");

 public:
  // A (the quantized operand in the transposed problem) always uses the
  // prepacked layout described by `GmemLayoutA`.
  static constexpr bool ALayoutIsPrepacked = true;

  // Prepacked block shape (N is M in the transposed problem)
  using PPBlockShape_MK = typename GmemLayoutA::PPBlockShape_NK;
  // Prepacked blocks per dim for a single MMA tile
  using PPBlocksPerTile_MK = decltype(make_shape(
      size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}),
      size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{})));

  // Re-exported from GmemLayoutA (interleaved block layout, per the name).
  using IlvdBlkLayout = typename GmemLayoutA::IlvdBlkLayout;

  // The CTA tile must be made of whole prepacked blocks in both M and K.
  static_assert(size<0>(TileShape_MNK{}) % size<0>(PPBlockShape_MK{}) == 0,
                "M in PPBlockShape_MK must evenly divide M TileShape_MNK");
  static_assert(size<2>(TileShape_MNK{}) % size<1>(PPBlockShape_MK{}) == 0,
                "K in PPBlockShape_MK must evenly divide K TileShape_MNK");

  using ArchTag = arch::Sm90;
  using TileShape = TileShape_MNK;
  using ClusterShape = ClusterShape_MNK;
  // First entry of the A element spec is the narrow (quantized) storage type.
  using ElementA = deduce_mixed_width_dtype_t<0, ElementATuple_>;
  // A is nominally row-major (K-major); its actual global addressing comes
  // from GmemLayoutA's prepacked layout (see to_underlying_arguments).
  using StrideA = TagToStrideA_t<layout::RowMajor>;
  using ElementB = ElementB_;
  using StrideB = TagToStrideB_t<GmemLayoutB>;
  using ElementAccumulator = ElementAccumulator_;
  // A is upconverted to B's type, so both MMA operands use ElementB.
  using ElementMma = ElementB;
  // Normalize the A element spec to a tuple (bare ElementA -> tuple<ElementA>).
  using ElementATuple =
      cute::conditional_t<!cute::is_tuple<ElementATuple_>::value,
                          cute::tuple<ElementA>, ElementATuple_>;

  static constexpr cute::GMMA::Major GmmaMajorA =
      gmma_rs_tag_to_major_A<layout::RowMajor>();
  static constexpr cute::GMMA::Major GmmaMajorB =
      gmma_rs_tag_to_major_B<GmemLayoutB>();

  // For coop schedules we have two warp groups cooperatively issuing wgmma
  // instructions so we use 2 atoms along the M dim (one for each warpgroup)
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  // Register-sourced-A ("RS") wgmma: A is K-major in registers, B's major-ness
  // follows GmemLayoutB.
  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
                                 TileShape_MNK, GMMA::Major::K, GmmaMajorB>(),
      AtomLayoutMNK{}));

 private:
  //
  // The setup section (until "section setup end") contains a combination of
  // modified code from (used as a starting point):
  //   `cutlass/gemm/collective/builders/sm90_gmma_builder.inl`
  //   `cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp`
  //   (upstream)
  //
  // However, in order to simplify the code we combine a lot of the logic from
  // `CollectiveMma` and `CollectiveBuilder` into this class; this also makes
  // sense given that we have flexibility on layouts here. We also simplify the
  // code by only supporting scales and zeros for A (in the transposed problem,
  // B from an API perspective). Also, since we force A to be the narrow type
  // (i.e. the type to be upconverted) we can remove all the `SwapAB` logic in
  // the upstream, further simplifying the code. This section includes new
  // logic (compared to upstream) for handling the prepacked-A layouts (in the
  // transposed problem, B from an API perspective).
  //
  // Scale / zero-point element types are entries 1 and 2 of the A element
  // spec. They are presumably void when absent; get_conversion_mode() checks
  // for void to select the conversion mode.
  using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>;
  using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>;

  // A must be strictly narrower than B: it is the operand that is loaded into
  // registers and upconverted before the MMA.
  static constexpr bool IsANarrow = cutlass::sizeof_bits<ElementA>::value <
                                    cutlass::sizeof_bits<ElementB>::value;
  static_assert(IsANarrow,
                "A must be the narrow one since its the one that flows through "
                "registers.");

 public:
  // Number of mainloop pipeline stages: StageCountType either overrides it or
  // it is derived from the SM90 smem capacity and the A/B/scale/zero element
  // sizes (per the helper's name; confirm against its definition).
  static constexpr int PipelineStages =
      compute_stage_count_or_override_single_affine_transformed_input<
          sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale,
          ElementZero, TileShape_MNK>(StageCountType{});

  // Local stand-in for a CUTLASS dispatch policy, exposing the stage count,
  // cluster shape and schedule used throughout this class.
  struct DispatchPolicy {
    constexpr static int Stages = PipelineStages;
    using ClusterShape = ClusterShape_MNK;
    using Schedule = KernelScheduleType;
  };

  // TMA copy atoms for the global->shared loads. A is multicast along the
  // cluster's N mode (shape<1>) and B along its M mode (shape<0>). The helper
  // presumably yields SM90_TMA_LOAD or SM90_TMA_LOAD_MULTICAST depending on
  // that extent; those are the only atoms the asserts further down accept.
  using GmemTiledCopyA =
      decltype(sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB =
      decltype(sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  // ((T, V), (BlocksM, BlocksK), pipe) -> offset
  using SmemLayoutA = decltype(GmemLayoutA::TVbNbKL_to_offset(
      make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
                 Int<DispatchPolicy::Stages>{})));

  // Same (M-tile, K-tile, stages) extents as SmemLayoutA, built with the
  // `_copy` variant of the prepacked layout; this is the smem layout used as
  // the A TMA destination (see make_tma_copy_A).
  using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy(
      make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
                 Int<DispatchPolicy::Stages>{})));

  // Standard RS smem atom for A; in this section it is only used to size the
  // M extent of the scale atom below (A's own smem layout is SmemLayoutA).
  using SmemLayoutAtomARowMajor =
      decltype(rs_smem_selector<GmmaMajorA, ElementA,
                                decltype(cute::get<0>(TileShape_MNK{})),
                                decltype(cute::get<2>(TileShape_MNK{}))>());

  // (M-atom, 1): one column of scales per atom.
  using SmemLayoutAtomScale = Layout<
      Shape<decltype(cute::shape<0>(SmemLayoutAtomARowMajor{})), cute::Int<1>>>;

  using SmemLayoutAtomB =
      decltype(rs_smem_selector<GmmaMajorB, ElementB,
                                decltype(cute::get<1>(TileShape_MNK{})),
                                decltype(cute::get<2>(TileShape_MNK{}))>());

  // B is read by wgmma straight from smem descriptors, so it needs no
  // smem->register copy atom.
  using SmemCopyAtomA = Copy_Atom<cute::DefaultCopy, ElementA>;
  using SmemCopyAtomB = void;

  //
  //  Validity checks
  //
  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
  static_assert(is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
                           tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>,
                "Unsupported Toolkit for SM90 Collective Builder\n");
#endif

 private:
  // How A is converted before being fed to the MMA. Chosen at compile time by
  // get_conversion_mode() from which of ElementScale / ElementZero are void.
  enum class ConversionMode {
    DirectConvert,           // no scales: type conversion only
    ConvertAndScale,         // scales, no zero-points
    ConvertAndScaleWithZero  // scales and zero-points
  };

 public:
  //
  // Type Aliases
  //
  using KernelSchedule = KernelScheduleType;

  // For cases where we can't have a void type, we can use this to allow the
  // code to compile when the scale / zero is void. `float` is only a
  // placeholder; in those modes the corresponding smem arrays have 0 elements.
  using NonVoidElementScale =
      cute::conditional_t<cute::is_void_v<ElementScale>, float, ElementScale>;
  using NonVoidElementZero =
      cute::conditional_t<cute::is_void_v<ElementZero>, float, ElementZero>;

  // These are always MN major: strides for an (M, scale_k, L) scale tensor
  // with unit stride along M.
  using StrideScale = cute::Stride<cute::Int<1>, int64_t, int64_t>;
  // For cases where we can't have a void scale, we can use this to allow the
  // code to compile when the scale is void. (StrideScale is never void here,
  // so this always resolves to StrideScale.)
  using NonVoidStrideScale =
      cute::conditional_t<cute::is_void_v<StrideScale>,
                          cute::Stride<_1, int64_t, int64_t>, StrideScale>;

  // Layout requirements on the operands, the scales, and the accumulator.
  static_assert((cutlass::gemm::detail::is_k_major<StrideA>()),
                "The transformed matrix (A) must be K-major.");

  static_assert((sizeof(ElementB) == 2) ||
                    (cutlass::gemm::detail::is_k_major<StrideA>() &&
                     cutlass::gemm::detail::is_k_major<StrideB>()),
                "The unscaled element (matrix B) must be 2 bytes OR both "
                "inputs must be K-major");

  static_assert(cutlass::gemm::detail::is_mn_major<NonVoidStrideScale>(),
                "Scale must be MN major [Col Major if A is scaled, Row Major "
                "if B is scaled].");

  static_assert(std::is_same_v<typename TiledMma::ValTypeC, ElementAccumulator>,
                "TiledMma::ValTypeC must be the same as ElementAccumulator.");

  // Scales / zeros are never multicast: each CTA issues its own TMA load.
  using GmemTiledCopyScale = cute::SM90_TMA_LOAD;

  using SmemCopyAtomScale = Copy_Atom<cute::DefaultCopy, NonVoidElementScale>;

  // TMA converts f32 input to tf32 when copying from GMEM to SMEM
  // For all other types, cast to size equivalent uint type to avoid any
  // rounding by TMA.
  static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
  static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
  using InternalElementA =
      cute::conditional_t<ConvertF32toTF32A, tfloat32_t,
                          uint_bit_t<sizeof_bits_v<ElementA>>>;
  using InternalElementB =
      cute::conditional_t<ConvertF32toTF32B, tfloat32_t,
                          uint_bit_t<sizeof_bits_v<ElementB>>>;

  using TransformA = cute::identity;
  using TransformB = cute::identity;

  // True when A's elements are narrower than a byte. Such elements are handed
  // to TMA as raw bytes (uint8_t) via TmaElementA. Declared `bool` since it is
  // a predicate; it still promotes to 0/1 in any integer context.
  static constexpr bool IsSubbyteA = cute::sizeof_bits_v<InternalElementA> < 8;
  using TmaElementA =
      cute::conditional_t<IsSubbyteA, uint8_t, InternalElementA>;

  using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
  using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;

  using PipelineParams = typename MainloopPipeline::Params;

  // One thread per CTA is a producer (1 for the operand tile).
  static constexpr int NumProducerThreadEvents = 1;

  using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
                                             shape<1>(SmemLayoutAtomScale{})));

  static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
                "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
                "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
                "SmemLayoutAtom must evenly divide tile shape.");

  static_assert(rank(SmemLayoutAtomScale{}) == 2,
                "SmemLayoutAtomScale must be rank 2");
  static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0,
                "SmemLayoutAtomScale must equal the tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
                "SmemLayoutAtomScale must evenly divide tile k shape.");

  // Tile along modes in a way that maximizes the TMA box size
  // (N-tile, K-tile, stages); the tiling order depends on whether B is major
  // in mode 0 (N).
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{},
      make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
                 Int<DispatchPolicy::Stages>{}),
      conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
                    Step<_2, _1, _3>, Step<_1, _2, _3>>{}));

  // It is assumed that the scales and zero-points share the same smem layout
  // (M-tile, 1, stages).
  using SmemLayoutScale = decltype(tile_to_shape(
      SmemLayoutAtomScale{},
      make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}),
                 Int<PipelineStages>{})));

  // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major
  // only (e.g. tf32, fp32, fp8, int8).
  // NOTE(review): StrideA is fixed to RowMajor above, so this is presumably
  // always false here; kept for parity with upstream.
  static constexpr bool IsLayoutAmnBmn =
      cute::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>,
                      layout::ColumnMajor> &&
      cute::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>,
                      layout::RowMajor>;

  // Pipelining needs at least two stages.
  static_assert(DispatchPolicy::Stages >= 2,
                "Specialization requires Stages set to value 2 or more.");
  // RS wgmma: A must be a register fragment (not a smem descriptor) while B
  // must be a smem descriptor.
  static_assert(not cute::is_base_of<cute::GMMA::DescriptorIterator,
                                     typename TiledMma::FrgTypeA>::value &&
                    cute::is_base_of<cute::GMMA::DescriptorIterator,
                                     typename TiledMma::FrgTypeB>::value,
                "MMA atom must source A from rmem and B operand from smem_desc "
                "for this mainloop.");
  // Only plain or multicast SM90 TMA loads are supported for A and B.
  static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
                    cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
                "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
  static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
                    cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
                "GmemTiledCopy - invalid SM90 TMA copy atom specified.");

  using GmmaSmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{},
      make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
                 Int<DispatchPolicy::Stages>{}),
      conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
                    Step<_2, _1, _3>, Step<_1, _2, _3>>{}));

  // These two restrictions are related, so we place the assertions together.
  // To relax them, we need to handle loading more than 1 row of scales for
  // every main loop iteration. We must also handle updating the pipeline
  // transaction bytes on the fly. NOTE: Deleting this assertion without
  // required changes will cause the code to hang.
  // (With a K extent of 1, ScaleTileShape is (tile M, 1) per stage.)
  static_assert(size<1>(SmemLayoutAtomScale{}) == 1,
                "size<1>(SmemLayoutAtomScale) must be 1.");

 private:
  // Derives the A-conversion mode from which quantization parameters exist:
  // no scale type -> plain conversion; a scale but no zero-point -> scaled
  // conversion; both -> scaled conversion with zero-points.
  static constexpr ConversionMode get_conversion_mode() {
    constexpr bool has_scale = !cute::is_void_v<ElementScale>;
    constexpr bool has_zero = !cute::is_void_v<ElementZero>;
    return !has_scale  ? ConversionMode::DirectConvert
           : !has_zero ? ConversionMode::ConvertAndScale
                       : ConversionMode::ConvertAndScaleWithZero;
  }

  // Conversion mode for this instantiation, and whether it loads scales
  // (true for both scaled modes).
  static constexpr ConversionMode KernelConversionMode = get_conversion_mode();
  static constexpr bool ModeHasScales =
      KernelConversionMode == ConversionMode::ConvertAndScale ||
      KernelConversionMode == ConversionMode::ConvertAndScaleWithZero;

  // Same as upstream, should be kept the same when possible
  //
  // Number of scale elements reserved in SharedStorage: 0 when the mode has
  // no scales, otherwise the cosize of SmemLayoutScale (all pipeline stages).
  static constexpr auto elements_per_smem_scale() {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      return 0;
    } else if constexpr (ModeHasScales) {
      return cute::cosize_v<SmemLayoutScale>;
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                    "Type not handled in scale smem allocation.");
    }
  }

  // Same as upstream, should be kept the same when possible
  //
  // Number of zero-point elements reserved in SharedStorage: 0 unless the mode
  // uses zero-points, in which case zeros share the scale smem layout.
  static constexpr auto elements_per_smem_zero() {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
                  KernelConversionMode == ConversionMode::ConvertAndScale) {
      return 0;
    } else if constexpr (KernelConversionMode ==
                         ConversionMode::ConvertAndScaleWithZero) {
      return cute::cosize_v<SmemLayoutScale>;
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                    "Type not handled in scale smem allocation.");
    }
  }

  // Same as upstream, should be kept the same when possible, not formatted for
  // easier comparison
  // clang-format off
  // These methods use some of the public members of the class. For that reason, we define them after the public section.
  //
  // Bytes TMA writes into smem per pipeline stage for the "MK" side: one stage of the A tile, plus
  // one stage of scales and (if used) zero-points. Presumably used as the expected transaction
  // count for the pipeline barrier (see the Params::tma_transaction_bytes* fields) -- confirm.
  static constexpr uint32_t
  compute_tma_transaction_bytes_mk() {
    constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(cute::sizeof_bits_v<InternalElementA>));

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      return baseline_bytes;
    }
    else if constexpr (ModeHasScales) {
      constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast<uint32_t>(cute::sizeof_bits_v<ElementScale>));
      static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA
      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return baseline_bytes + scale_tx_bytes;
      }
      else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        // Scale and zero share smem layout
        constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast<uint32_t>(cute::sizeof_bits_v<ElementZero>));
        static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA
        return baseline_bytes + scale_tx_bytes + zero_tx_bytes;
      }
      else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
      }
    }
    else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
    }
  }

  // Bytes TMA writes into smem per pipeline stage for the "NK" side: one stage of the B tile.
  static constexpr uint32_t
  compute_tma_transaction_bytes_nk() {
    return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(cute::sizeof_bits_v<InternalElementB>));
  }
  // clang-format on

  // Global-memory tensor types, used to name the TMA descriptor types via the
  // default arguments of the make_tma_copy_* helpers below. The int32_t(0)
  // extents only fix the (dynamic) shape types; the real extents are supplied
  // in to_underlying_arguments.
  //
  // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
  using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy(
      make_shape(int32_t(0), int32_t(0), int32_t(0)))));

  // Prepacked A, viewed with the shape/stride of GmemLayoutA's copy layout.
  using ATensor = decltype(make_tensor(
      get_logical_ptr(static_cast<InternalElementA const*>(nullptr)),
      shape(GmemLayoutA::TVbNbKL_to_offset_copy(
          make_shape(int32_t(0), int32_t(0), int32_t(0)))),
      PrepackedStrideA{}));

  using BTensor = decltype(make_tensor(
      get_logical_ptr(static_cast<InternalElementB const*>(nullptr)),
      repeat_like(StrideB{}, int32_t(0)), StrideB{}));
  using ScaleTensor = decltype(make_tensor(
      get_logical_ptr(static_cast<NonVoidElementScale const*>(nullptr)),
      repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}));

  // Zero-points use the same stride type as the scales.
  using ZeroTensor = decltype(make_tensor(
      get_logical_ptr(static_cast<NonVoidElementZero const*>(nullptr)),
      repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}));

  // Builds the TMA load for one pipeline stage of the prepacked A tile. The
  // element type handed to TMA is TmaElementA (uint8_t for sub-byte A), and
  // the load is multicast along the cluster's N mode, if any.
  static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) {
    auto const smem_stage = SmemLayoutACopy{}(_, _, cute::Int<0>{});
    auto const mcast_along_n = size<1>(ClusterShape{});
    return make_tma_copy<TmaElementA>(GmemTiledCopyA{}, tensor_a, smem_stage,
                                      shape(smem_stage), mcast_along_n);
  }

  // Builds the TMA load for one pipeline stage of scales (a ScaleTileShape
  // box). Cluster size 1: scales are not multicast.
  static constexpr auto make_tma_copy_scale(
      ScaleTensor tensor_scale = ScaleTensor{}) {
    auto const smem_stage = SmemLayoutScale{}(_, _, cute::Int<0>{});
    auto const no_mcast = _1{};
    return make_tma_copy(GmemTiledCopyScale{}, tensor_scale, smem_stage,
                         ScaleTileShape{}, no_mcast);
  }

  // Builds the TMA load for one pipeline stage of zero-points. Zeros reuse the
  // scale smem layout and tile shape, and are not multicast (cluster size 1).
  static constexpr auto make_tma_copy_zero(
      ZeroTensor tensor_zero = ZeroTensor{}) {
    auto const smem_stage = SmemLayoutScale{}(_, _, cute::Int<0>{});
    auto const no_mcast = _1{};
    return make_tma_copy(GmemTiledCopyScale{}, tensor_zero, smem_stage,
                         ScaleTileShape{}, no_mcast);
  }

  // Builds the TMA load for one pipeline stage of the B tile (the N x K part
  // of the CTA tile), multicast along the cluster's M mode, if any.
  static constexpr auto make_tma_copy_B(BTensor tensor_b = BTensor{}) {
    auto const smem_stage = SmemLayoutB{}(_, _, cute::Int<0>{});
    auto const cta_tile_nk =
        make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}));
    auto const mcast_along_m = size<0>(ClusterShape{});
    return make_tma_copy(GmemTiledCopyB{}, tensor_b, smem_stage, cta_tile_nk,
                         mcast_along_m);
  }

 public:
  // Same as upstream, should be kept the same when possible, not formatted for
  // easier comparison
  //  with `RealInternalElementA` -> `ElementA` since we don't support `SwapAB`
  //  logic
  // clang-format off
  // Swizzle-derived smem alignment requirements for the A and B tiles.
  static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); 

  static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{});

  // Just pick the max alignment of A and B since it is required to be at least 128B
  static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB);

  static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment");

  // Mainloop shared memory: A/B tiles, scales and zero-points for all pipeline
  // stages, plus the pipeline's own shared storage.
  struct SharedStorage
  {
    // 0 when the conversion mode does not use scales / zero-points.
    static constexpr int scale_elements = elements_per_smem_scale();
    static constexpr int zero_elements = elements_per_smem_zero();
    // Aligned to the stricter of A's and B's swizzle alignment.
    struct TensorStorage : cute::aligned_struct<cute::max(SmemAlignmentA, SmemAlignmentB)> {
      cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> smem_A;
      cute::ArrayEngine<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
      cute::ArrayEngine<NonVoidElementScale, scale_elements> smem_scale;
      cute::ArrayEngine<NonVoidElementZero, zero_elements> smem_zero;
    } tensors;

    using PipelineStorage = typename MainloopPipeline::SharedStorage;
    PipelineStorage pipeline;
  };
  using TensorStorage = typename SharedStorage::TensorStorage;
  using PipelineStorage = typename SharedStorage::PipelineStorage;

  // Host side kernel arguments
  struct Arguments {
    // Prepacked A (the quantized operand; `B_prepacked` at the API level).
    ElementA const* ptr_A = nullptr;
    // Not used by to_underlying_arguments: A is addressed through GmemLayoutA's prepacked layout.
    StrideA dA{};
    ElementB const* ptr_B = nullptr;
    StrideB dB{};
    // Scales, viewed as an (M, ceil(K / group_size), L) tensor with stride dS.
    ElementScale const* ptr_S = nullptr;
    NonVoidStrideScale dS{};
    // Number of K elements per scale / zero-point group.
    int group_size = 0;
    // Zero-points; same shape and stride (dS) as the scales.
    ElementZero const* ptr_Z = nullptr;
    // Not read by to_underlying_arguments.
    uint32_t mma_promotion_interval = 4;
  };
  // clang-format on

  //
  //  section setup end
  //

  // Similar (but not identical) to upstream, should be kept the same when
  // possible
  //  compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to
  //  define the TMA types
  // Device side kernel params
  struct Params {
   public:
    // Assumption: StrideA is congruent with Problem_MK
    using TMA_A = decltype(make_tma_copy_A());
    using TMA_Scale = decltype(make_tma_copy_scale());
    using TMA_Zero = decltype(make_tma_copy_zero());
    using TMA_B = decltype(make_tma_copy_B());

    // required by outer loop: i.e.
    //   cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp
    TMA_A tma_load_a;
    TMA_B tma_load_b;
    // Default-constructed when the conversion mode does not use them.
    TMA_Scale tma_load_scale;
    TMA_Zero tma_load_zero;
    // Number of scale/zero groups along K; 0 when the mode has no scales.
    int64_t scale_k;
    // 0 when the mode has no scales.
    int group_size;
    // Expected per-stage TMA byte counts. The TmaTransactionBytes* constants
    // are not defined in this part of the file; presumably they come from
    // compute_tma_transaction_bytes_mk/nk -- confirm.
    uint32_t tma_transaction_bytes = TmaTransactionBytes;
    uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
    uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
  };

  //
  // Methods
  //

  // Similar (but not identical) to upstream, should be kept the same when
  // possible
  //  compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here
  //  to handle the prepacked layout
  //
  // Converts host-side `Arguments` into device-side `Params` by building the
  // TMA descriptors for A (prepacked), B and, depending on the conversion
  // mode, the scales and zero-points. `workspace` is unused.
  template <class ProblemShape>
  static constexpr Params to_underlying_arguments(
      ProblemShape const& problem_shape, Arguments const& args,
      void* workspace) {
    (void)workspace;

    // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
    // only rank-3 (MNK)
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M, N, K, L] = problem_shape_MNKL;

    // View the operands through their TMA-internal element types (tf32 or a
    // size-matched unsigned integer; see InternalElementA/B).
    auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
    auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);

    auto make_logical_tensor = [&](auto ptr, auto shape, auto stride) {
      return make_tensor(get_logical_ptr(ptr), make_layout(shape, stride));
    };

    // Scale/zero descriptors stay default-constructed when the conversion
    // mode does not use them.
    typename Params::TMA_A tma_load_a;
    typename Params::TMA_B tma_load_b;
    typename Params::TMA_Scale tma_load_scale;
    typename Params::TMA_Zero tma_load_zero;

    // A is addressed through the prepacked layout; args.dA is not used.
    auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
    tma_load_a = make_tma_copy_A(
        make_logical_tensor(ptr_A, shape(layout), stride(layout)));

    tma_load_b = make_tma_copy_B(
        make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB));

    // One scale/zero group per `group_size` K elements, rounded up.
    // NOTE(review): this divides by args.group_size, so it must be > 0 when
    // the mode has scales; presumably validated by the caller -- confirm.
    int32_t scale_k =
        (ModeHasScales) ? (K + args.group_size - 1) / args.group_size : 0;
    int32_t group_size = (ModeHasScales) ? args.group_size : 0;

    if constexpr (ModeHasScales) {
      tma_load_scale = make_tma_copy_scale(
          make_logical_tensor(args.ptr_S, make_shape(M, scale_k, L), args.dS));
    }

    // Zero-points reuse the scale stride (args.dS).
    if constexpr (KernelConversionMode ==
                  ConversionMode::ConvertAndScaleWithZero) {
      tma_load_zero = make_tma_copy_zero(
          make_logical_tensor(args.ptr_Z, make_shape(M, scale_k, L), args.dS));
    }

    // The remaining Params fields (transaction byte counts) take their
    // default member initializers.
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
                  KernelConversionMode == ConversionMode::ConvertAndScale ||
                  KernelConversionMode ==
                      ConversionMode::ConvertAndScaleWithZero) {
      return {tma_load_a,    tma_load_b, tma_load_scale,
              tma_load_zero, scale_k,    group_size};
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                    "Conversion mode not handled in to_underlying_arguments.");
    }
  }

  // Same as upstream, should be kept the same when possible, not formatted for
  // easier comparison
  //   with `SwapAB ? N : M -> M` since we don't support SwapAB
  // clang-format off
  template<class ProblemShape>
  static bool
  can_implement(
      ProblemShape const& problem_shape,
      [[maybe_unused]] Arguments const& args) {
    constexpr int tma_alignment_bits = 128;
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;
    
    bool implementable = true;
    constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
    constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      implementable = implementable && (args.ptr_S == nullptr);
      implementable = implementable && (args.ptr_Z == nullptr);
    } 
    else if constexpr (ModeHasScales) {
      const int scale_mn = M;
      const int scale_k = (K + args.group_size - 1) / args.group_size;
      constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
      implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
      implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
      implementable = implementable && args.group_size != 0;
      implementable = implementable && (args.ptr_S != nullptr);

      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        implementable = implementable && (args.ptr_Z == nullptr);
      }
      else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits<ElementZero>::value;
        implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
        implementable = implementable && (args.ptr_Z != nullptr);
      } 
      else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
      }
    }
    else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
    }

    if (!implementable) {
      CUTLASS_TRACE_HOST("  CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
    }
    return implementable;
  }

  // Pipeline depth, taken from the dispatch policy's stage count.
  static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
  // TMA transaction byte counts for the M/K-indexed and N/K-indexed loads;
  // presumably MK covers A (+ scales/zeros) and NK covers B -- see the
  // compute_tma_transaction_bytes_* helpers to confirm.
  static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk();
  static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk();
  // Total bytes: sum of the two parts above.
  static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;

  /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& mainloop_params) {
    constexpr bool is_direct     = KernelConversionMode == ConversionMode::DirectConvert;
    constexpr bool is_scale      = KernelConversionMode == ConversionMode::ConvertAndScale;
    constexpr bool is_scale_zero = KernelConversionMode == ConversionMode::ConvertAndScaleWithZero;

    if constexpr (!(is_direct || is_scale || is_scale_zero)) {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA prefetch.");
    }

    // A and B descriptors are used by every conversion mode.
    cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
    cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());

    // Scales are loaded by both scaled modes; zeros only by the zero-point mode.
    if constexpr (is_scale || is_scale_zero) {
      cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor());
    }
    if constexpr (is_scale_zero) {
      cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor());
    }
  }
  // clang-format on

  // Modified from upstream, should be kept close to that when possible
  //  the main difference is special handling for the prepacked A layout
  //
  // Set up the data needed by this collective for load and mma.
  // Returns a tuple of tensors. The collective and the kernel layer have the
  // contract that the returned tuple must contain at least two elements, the
  // first two being:
  //   gA_mkl - the TMA tensor for A after a local tile, with shape
  //            (TILE_V,TILE_B,m,k,l)
  //   gB_nkl - the TMA tensor for B after a local tile, with shape
  //            (TILE_N,TILE_K,n,k,l)
  // The remaining tensors can be specified as needed by this collective.
  // NOTE: TILE_B is the prepacked block index within a tile. TILE_V is the
  // values within a prepacked block.
  template <class ProblemShape_MNKL>
  CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL,
                                Params const& mainloop_params) const {
    using X = Underscore;
    auto M = get<0>(problem_shape_MNKL);
    auto N = get<1>(problem_shape_MNKL);
    auto K = get<2>(problem_shape_MNKL);
    auto L = get<3>(problem_shape_MNKL);

    // A (prepacked): ((athrid, val), (BlocksM, BlockK), L) -> storage_idx,
    // tiled into (TILE_V,TILE_B,m,k,l).
    auto layout_A = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
    Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout_A));
    auto gA_mkl = local_tile(mA_mkl,
                             make_shape(size<0>(layout_A), PPBlocksPerTile_MK{}),
                             make_coord(0, make_coord(_, _)));

    // B: tiled into (TILE_N,TILE_K,n,k,l); the M mode of TileShape is skipped.
    Tensor mB_nkl =
        mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L));
    auto gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
                             Step<X, _1, _1>{});

    // Scales / zeros share one (M, scale_k, L) shape, tiled into
    // (TILE_M,TILE_Scale_K,m,scale_k,l).
    auto tile_scale_like = [&](auto const& tma_load) {
      Tensor mX_mkl = tma_load.get_tma_tensor(
          make_shape(M, mainloop_params.scale_k, L));
      return local_tile(mX_mkl, ScaleTileShape{}, make_coord(_, _));
    };

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      return cute::make_tuple(gA_mkl, gB_nkl);
    } else if constexpr (KernelConversionMode ==
                         ConversionMode::ConvertAndScale) {
      return cute::make_tuple(gA_mkl, gB_nkl,
                              tile_scale_like(mainloop_params.tma_load_scale));
    } else if constexpr (KernelConversionMode ==
                         ConversionMode::ConvertAndScaleWithZero) {
      return cute::make_tuple(gA_mkl, gB_nkl,
                              tile_scale_like(mainloop_params.tma_load_scale),
                              tile_scale_like(mainloop_params.tma_load_zero));
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                    "Conversion mode not handled in load_init.");
    }
  }

  // Similar to upstream, should be kept close to that when possible
  //  the main difference is in the layout comments
  // clang-format off
  /// Perform a collective-scoped matrix multiply-accumulate
  /// Producer Perspective
  /// Handles every conversion mode; scales (and zeros) are only loaded when
  /// the mode uses them.
  // Producer mainloop. For each of the k_tile_count k-tiles, acquires the
  // next pipeline stage and issues TMA copies of A and B (plus scales, and
  // zeros, when the conversion mode has them) into that stage's smem buffers.
  // Only the elected lane of the calling warp issues the copies.
  //   load_inputs: the tuple returned by load_init() (2, 3 or 4 tensors).
  //   blk_coord:   (m, n, k, l) tile coordinate; k is not used here.
  //   k_tile_iter: yields the k-tile index for each iteration.
  template <
    class... Ts,
    class KTileIterator, class BlockCoord
  >
  CUTLASS_DEVICE void
  load(
      Params const& mainloop_params,
      MainloopPipeline pipeline, 
      PipelineState smem_pipe_write,
      cute::tuple<Ts...> const& load_inputs,
      BlockCoord const& blk_coord,
      KTileIterator k_tile_iter, int k_tile_count,
      int thread_idx,
      uint32_t block_rank_in_cluster,
      TensorStorage& shared_tensors) {
    // The number of load_inputs must match what load_init() returns for the mode.
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs");
    } 
    else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
      static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs");
    } 
    else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
      static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs");
    } 
    else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
    }

    // Every lane must reach elect_one_sync(); only the elected lane proceeds.
    int lane_predicate = cute::elect_one_sync();

    if (lane_predicate) {
      Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{});      // (BLK_M,BLK_K,PIPE)
      Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{});      // (BLK_N,BLK_K,PIPE)
      Tensor sA  = as_position_independent_swizzle_tensor(sA_);                                   // (BLK_M,BLK_K,PIPE)
      Tensor sB  = as_position_independent_swizzle_tensor(sB_);                                   // (BLK_N,BLK_K,PIPE)

      //
      // Prepare the TMA loads for A, B and Scales
      //
      
      // (x, y) position of this block inside its cluster, derived from its rank.
      constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
      uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

      Tensor gA_mkl = get<0>(load_inputs);
      Tensor gB_nkl = get<1>(load_inputs);

      // A is sliced by the cluster y index, B by the cluster x index.
      auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
      auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);

      // Partition the inputs based on the current block coordinates.
      auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
      Tensor gA = gA_mkl(_,_,m_coord,_,l_coord);                                                     // (TILE_V,TILE_B,k)
      Tensor gB = gB_nkl(_,_,n_coord,_,l_coord);                                                     // (TILE_N,TILE_K,k)

      // Applies the mapping from block_tma_a
      Tensor tAgA = block_tma_a.partition_S(gA);                                                 // (TMA,TMA_M,TMA_K,k)
      Tensor tAsA = block_tma_a.partition_D(sA);                                              // (TMA,TMA_M,TMA_K,PIPE)

      Tensor tBgB = block_tma_b.partition_S(gB);                                                 // (TMA,TMA_N,TMA_K,k)
      Tensor tBsB = block_tma_b.partition_D(sB);                                              // (TMA,TMA_N,TMA_K,PIPE)

      // Multicast masks (one bit per block id in the cluster). Scales/zeros are
      // never multicast (mask stays 0).
      uint16_t mcast_mask_a = 0;
      uint16_t mcast_mask_b = 0;
      uint16_t mcast_mask_s = 0;

      // Issue TmaLoads
      // Maps the tile -> block, value
      // A: multicast to every block sharing this block's cluster x index.
      if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
        auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};                       // (m,n) -> block_id
        for (int n = 0; n < size<1>(block_layout); ++n) {
          mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
        }
      }

      // B: multicast to every block sharing this block's cluster y index.
      if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
        auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};                       // (m,n) -> block_id
        for (int m = 0; m < size<0>(block_layout); ++m) {
          mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
        }
      }

      // Empty tuple for DirectConvert; (tSgS, tSsS[, tZgZ, tZsZ]) otherwise.
      auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord);

      // Mainloop
      CUTLASS_PRAGMA_NO_UNROLL
      for ( ; k_tile_count > 0; --k_tile_count) {
        // LOCK smem_pipe_write for _writing_
        pipeline.producer_acquire(smem_pipe_write);

        //
        // Copy gmem to smem for *k_tile_iter
        //

        // All copies of this stage signal the same producer barrier.
        using BarrierType = typename MainloopPipeline::ProducerBarrierType;
        BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

        int write_stage = smem_pipe_write.index();
        copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
        copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));

        if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
          // Nothing extra to do.
        }
        else if constexpr (ModeHasScales) {
          auto tSgS = get<0>(extra_input_partitions);
          auto tSsS = get<1>(extra_input_partitions);

          // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes
          // on the fly.
          // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K
          // is a multiple of the threadblock tile K
          // ReloadFactor = number of consecutive k-tiles that share one scale tile.
          const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{});
          const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K.
          copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage));

          if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
            // Nothing extra to do
          } 
          else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
            // Zeros use the same scale_load_k index as the scales.
            auto tZgZ = get<2>(extra_input_partitions);
            auto tZsZ = get<3>(extra_input_partitions);
            copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage));
          }
          else {
            static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for TMA copy op.");
          } 
        } 
        else {
          static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for TMA copy op.");
        }

        ++k_tile_iter;

        // Advance smem_pipe_write
        ++smem_pipe_write;
      }
    }
  }
  // clang-format on

  // Same as upstream, should be kept the same when possible, not formatted for
  // easier comparison
  // clang-format off
  // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  CUTLASS_DEVICE void
  load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
    // The whole warp takes part in the election; only the winner waits.
    // producer_tail blocks until every stage has either been released by the
    // consumers or was never used (its phase is still the inverted start
    // phase). This keeps a block from exiting while others in its cluster may
    // still target its smem.
    if (cute::elect_one_sync()) {
      pipeline.producer_tail(smem_pipe_write);
    }
  }
  // clang-format on

  // Modified from upstream, should be kept close to that when possible
  //  the main differences are handling the prepacked A layout, and separating
  //  the loading of A from upconverting A
  //
  // Perform a collective-scoped matrix multiply-accumulate
  // Consumer Perspective
  template <class FrgTensorC>
  CUTLASS_DEVICE void mma(MainloopPipeline pipeline,
                          PipelineState smem_pipe_read, FrgTensorC& accum,
                          int k_tile_count, int thread_idx,
                          TensorStorage& shared_tensors,
                          Params const& mainloop_params) {
    // Consumer mainloop. For each k-tile, A is copied smem -> tCrA_load, then
    // converted one k_block at a time into tCrA_mma (transform_A_kblock, fed
    // with the scale/zero fragments) while GMMAs for earlier k_blocks are in
    // flight. B is consumed directly from smem.
    //
    // Stage release: the stage of every k-tile except the last is released
    // here, so exactly one stage is left for mma_tail() to release.
    static_assert(is_rmem<FrgTensorC>::value,
                  "C tensor must be rmem resident.");
    static_assert(cute::rank(SmemLayoutB{}) == 3,
                  "Smem layout must be rank 3.");
    static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
                  "SmemLayoutAtomB must be rank 2.");
    static_assert(!cute::is_void_v<SmemCopyAtomA>,
                  "SM90 GMMA mainloops must specify a non-void copy atom for "
                  "RF sourced instructions.");
    static_assert(cute::is_void_v<SmemCopyAtomB>,
                  "SM90 GMMA mainloops cannot have a non-void copy atom for "
                  "smem sourced instructions.");

    // Thread index within its 128-thread warpgroup; slices the scale/zero
    // smem -> register tiled copy.
    [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128;

    // ((T, (FrgV,(RestM, RestK)), (BlocksM, BlocksK), pipe) -> offset
    auto constexpr smem_A = SmemLayoutA{};

    // convert:
    //   ((T, (MMA,(MMA_M, MMA_K)), (BlocksM, BlocksK), pipe) -> offset
    // to:
    //   (T, MMA, ((MMA_M, BlocksM), (MMA_K, BlocksK)), pipe) -> offset
    // which can be thought of as:
    //   (T, MMA, (MMA_M, MMA_K), pipe) -> offset
    auto constexpr smem_A_mma_ =
        make_layout(get<0, 0>(smem_A), get<0, 1, 0>(smem_A),
                    zip(get<0, 1, 1>(smem_A), get<1>(smem_A)), get<2>(smem_A));
    // flatten to:
    //   (T, MMA, MMA_M, MMA_K, pipe) -> offset
    auto constexpr smem_A_mma = smem_A_mma_(_, _, make_coord(_, _), _);

    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()),
                            smem_A_mma);  // (T, MMA, MMA_M, MMA_K, pipe)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()),
                            SmemLayoutB{});  // (BLK_N,BLK_K,PIPE)

    //
    // Define C accumulators and A/B partitioning
    //

    TiledMma tiled_mma;
    auto thread_mma = tiled_mma.get_thread_slice(thread_idx);

    Tensor tCsA = sA(thread_idx, _, _, _, _);  // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCsB = thread_mma.partition_B(sB);  // (MMA,MMA_N,MMA_K,PIPE)

    // Allocate fragments and descriptors
    Tensor tCrA_load = make_tensor<ElementA>(
        tCsA(_, _, _, Int<0>{}).shape());  // (MMA,MMA_M,MMA_K)
    Tensor tCrA_mma = make_fragment_like<ElementMma>(tCrA_load);

    Tensor tCrB = thread_mma.make_fragment_B(tCsB);  // (MMA,MMA_N,MMA_K,PIPE)

    static constexpr int A_CPY_VEC =
        decltype(max_common_vector(tCsA, tCrA_load)){};

    static constexpr int CONVERSION_WIDTH =
        std::min(A_CPY_VEC, int(size<0>(tCrA_mma)));

    // Copies one whole k-tile of this thread's A from smem stage `read_stage`.
    auto load_A_to_registers = [&](int read_stage) {
      copy(create_auto_vectorizing_copy<ElementA, decltype(A_CPY_VEC)>(),
           tCsA(_, _, _, read_stage), tCrA_load(_, _, _));
    };

    // Partition of thread -> shared and thread -> RF
    auto partitioned_extra_info =
        partition_extra_mma_info(thread_mma, shared_tensors);
    auto copy_partitions_extra_info = retile_extra_mma_info(
        tiled_mma, partitioned_extra_info, warp_group_thread_idx);
    CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum));      // N
    CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));       // K
    CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB));       // PIPE
    CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB));  // PIPE

    //
    // PIPELINED MAIN LOOP
    //

    // Converts k_block of the already-loaded A tile. `read_stage` must be the
    // stage the current k-tile occupies: scale/zero info is read from it when
    // k_block == 0.
    auto convert_A = [&, a_vec = Int<CONVERSION_WIDTH>{}](int k_block,
                                                          int read_stage) {
      load_extra_info_to_registers(partitioned_extra_info,
                                   copy_partitions_extra_info, k_block,
                                   read_stage);
      transform_A_kblock(tCrA_load, a_vec, tCrA_mma, partitioned_extra_info,
                         k_block);
    };

    // We release buffers to producer warps(dma load) with some mmas in flight
    PipelineState smem_pipe_release = smem_pipe_read;

    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

    warpgroup_fence_operand(accum);

    constexpr int K_BLOCK_MAX = size<2>(tCrA_load);

    ConsumerToken barrier_token = {BarrierStatus::WaitAgain};
    // first k tile
    {
      barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
      pipeline.consumer_wait(smem_pipe_read, barrier_token);

      int read_stage = smem_pipe_read.index();
      ++smem_pipe_read;
      barrier_token = pipeline.consumer_try_wait(smem_pipe_read);

      // copy smem->rmem for A operand
      load_A_to_registers(read_stage);
      convert_A(0, read_stage);

      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
        if (k_block < K_BLOCK_MAX - 1) {
          // Still converting the first tile, so use its stage (read_stage).
          // smem_pipe_read has already been advanced to the next stage, which
          // has not been waited on yet.
          convert_A(k_block + 1, read_stage);
        }
        warpgroup_arrive();
        // (V,M) x (V,N) => (V,M,N)
        cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
                   tCrB(_, _, k_block, read_stage), accum);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        warpgroup_commit_batch();
      }

      --k_tile_count;
      if (k_tile_count > 0) {
        // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to
        // overwrite the A registers for the first mma.
        warpgroup_wait<K_BLOCK_MAX - 1>();
        pipeline.consumer_wait(smem_pipe_read, barrier_token);
        load_A_to_registers(smem_pipe_read.index());
        convert_A(0, smem_pipe_read.index());
      }
    }

    if (k_tile_count == 0) {
      return;
    }

    warpgroup_fence_operand(accum);
    // Mainloop GMMAs
    CUTLASS_PRAGMA_NO_UNROLL
    for (; k_tile_count > 1; --k_tile_count) {
      //
      // Compute on k_tile
      //

      int read_stage = smem_pipe_read.index();
      ++smem_pipe_read;

      warpgroup_fence_operand(accum);
      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
        warpgroup_arrive();
        // (V,M) x (V,N) => (V,M,N)
        cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
                   tCrB(_, _, k_block, read_stage), accum);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        warpgroup_commit_batch();

        warpgroup_wait<K_BLOCK_MAX - 1>();
        if (k_block == K_BLOCK_MAX - 1) {
          // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage,
          // so we can release prior barrier
          pipeline.consumer_release(
              smem_pipe_release);  // UNLOCK smem_pipe_release, done _computing_
                                   // on it
          ++smem_pipe_release;
        }

        if (k_block == 0) {
          barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
        }

        if (k_block == K_BLOCK_MAX - 1) {
          // Prefetch and convert the first k_block of the next tile.
          pipeline.consumer_wait(smem_pipe_read, barrier_token);
          load_A_to_registers(smem_pipe_read.index());
          convert_A(0, smem_pipe_read.index());
        } else {
          convert_A(k_block + 1, read_stage);
        }
      }
      warpgroup_fence_operand(accum);
    }

    warpgroup_fence_operand(accum);

    {
      //
      // Compute on k_tile (last tile: nothing further to prefetch)
      //

      int read_stage = smem_pipe_read.index();

      warpgroup_fence_operand(accum);

      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
        warpgroup_arrive();
        // (V,M) x (V,N) => (V,M,N)
        cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
                   tCrB(_, _, k_block, read_stage), accum);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        warpgroup_commit_batch();
        warpgroup_wait<K_BLOCK_MAX - 1>();
        if (k_block == K_BLOCK_MAX - 1) {
          // release prior barrier
          pipeline.consumer_release(
              smem_pipe_release);  // UNLOCK smem_pipe_release, done _computing_
                                   // on it
          ++smem_pipe_release;
        }

        if (k_block < K_BLOCK_MAX - 1) {
          convert_A(k_block + 1, read_stage);
        }
      }
    }

    warpgroup_fence_operand(accum);
  }

  // Perform a Consumer Epilogue to release all buffers
  CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
                               PipelineState smem_pipe_release,
                               int k_tile_count) {
    // Number of trailing stages released here. mma() releases all but one of
    // the stages it consumes, leaving exactly one for this tail.
    constexpr int tail_release_count = 1;

    // Skip the stages mma() already released. This assumes the caller passes
    // the same starting state and k_tile_count that it gave to mma() --
    // confirm at the call site.
    smem_pipe_release.advance(k_tile_count - tail_release_count);

    // Every outstanding warpgroup MMA must retire before its smem is freed.
    warpgroup_wait<0>();

    for (int released = 0; released < tail_release_count; ++released) {
      // UNLOCK smem_pipe_release, done _computing_ on it
      pipeline.consumer_release(smem_pipe_release);
      ++smem_pipe_release;
    }
  }

 private:
  // Same as upstream, should be kept the same when possible, not formatted for
  // easier comparison
  // clang-format off
  /// Utilities for any additional inputs inside of the TMA load
  template <class... Ts>
  CUTLASS_DEVICE
  auto partition_extra_tma_inputs(
    Params const& mainloop_params,
    cute::tuple<Ts...> const& load_inputs,
    TensorStorage& shared_tensors,
    uint2 const& cluster_local_block_id,
    int const m_coord, 
    int const l_coord) {

    // Partitions one scale-shaped operand (scales or zeros) for this block's
    // TMA. Returns (gmem source (TMA,TMA_M,TMA_K,k), smem dest (TMA,TMA_M,TMA_K,PIPE)).
    auto partition_scale_like = [&](auto const& tma_load, auto gX_mkl, auto smem_begin) {
      Tensor sX = make_tensor(make_smem_ptr(smem_begin), SmemLayoutScale{});       // (BLK_M,BLK_K,PIPE)
      auto block_tma = tma_load.get_slice(cluster_local_block_id.y);
      Tensor gX = gX_mkl(_,_,m_coord,_,l_coord);                                    // (BLK_M,BLK_K,k)
      return cute::make_tuple(block_tma.partition_S(gX), block_tma.partition_D(sX));
    };

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      // No extra operands to load.
      return cute::make_tuple();
    } 
    else if constexpr (ModeHasScales) {
      auto scale_parts = partition_scale_like(mainloop_params.tma_load_scale,
                                              cute::get<2>(load_inputs),
                                              shared_tensors.smem_scale.begin());
      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return scale_parts;
      } 
      else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        auto zero_parts = partition_scale_like(mainloop_params.tma_load_zero,
                                               cute::get<3>(load_inputs),
                                               shared_tensors.smem_zero.begin());
        return cute::make_tuple(cute::get<0>(scale_parts), cute::get<1>(scale_parts),
                                cute::get<0>(zero_parts),  cute::get<1>(zero_parts));
      }
      else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");      
      }
    }
    else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");      
    }
  }
  // clang-format on

  // Same as upstream, should be kept the same when possible, not formatted for
  // easier comparison
  // clang-format off
  /// Utilities for partitioning extra inputs for loading from smem in the mainloop.
  template <class ThreadMma>
  CUTLASS_DEVICE 
  auto partition_extra_mma_info(
    ThreadMma const& mma_thread_slice,
    TensorStorage& shared_tensors) {

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      // Direct conversion has no extra operands to partition.
      return cute::make_tuple();
    }
    else if constexpr (ModeHasScales) {
      // smem view of the scales: (BLK_M,BLK_SCALE_K,PIPE)
      auto scale_smem = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});
      // This thread's A-operand partition of the scales across all pipe stages.
      auto scale_smem_part = mma_thread_slice.partition_A(scale_smem);
      // Owning fragment shaped like a single stage (stage 0) of that partition.
      auto scale_frag = make_tensor<ElementScale>(
          mma_thread_slice.partition_fragment_A(scale_smem(_,_,Int<0>{})).shape());

      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return cute::make_tuple(scale_smem_part, scale_frag);
      }
      else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        // Zeros mirror the scales' smem layout and partitioning.
        auto zero_smem = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});
        auto zero_smem_part = mma_thread_slice.partition_A(zero_smem);
        auto zero_frag = make_tensor<ElementZero>(
            mma_thread_slice.partition_fragment_A(zero_smem(_,_,Int<0>{})).shape());
        return cute::make_tuple(scale_smem_part, scale_frag, zero_smem_part, zero_frag);
      }
      else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
      }
    } 
    else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
    }
  }
  // clang-format on

  // Same as upstream, should be kept the same when possible, not formatted for
  // easier comparison
  // clang-format off
  /// Returns the tiled copy and copy views for the extra inputs.
  template <class TiledMma, class... Ts>
  CUTLASS_DEVICE
  auto retile_extra_mma_info(
    TiledMma const& tiled_mma,
    cute::tuple<Ts...>& partitioned_extra_info,
    int const warp_group_thread_idx) {

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      // Direct conversion has no extra operands to copy.
      return cute::make_tuple();
    }
    else if constexpr (ModeHasScales) {
      // A single smem -> register tiled copy, built from the scale copy atom
      // and shaped like tiled_mma's A operand, serves scales and zeros alike.
      auto s2r_copy     = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma);
      auto s2r_thr_copy = s2r_copy.get_thread_slice(warp_group_thread_idx);
      // Copy-destination view of the scale fragment (element 1 of the tuple).
      auto scale_dst_view = s2r_thr_copy.retile_D(cute::get<1>(partitioned_extra_info));   // (CPY,CPY_M,CPY_K)

      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return cute::make_tuple(s2r_copy, scale_dst_view);
      } 
      else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        // Copy-destination view of the zero fragment (element 3 of the tuple).
        auto zero_dst_view = s2r_thr_copy.retile_D(cute::get<3>(partitioned_extra_info));  // (CPY,CPY_M,CPY_K)
        return cute::make_tuple(s2r_copy, scale_dst_view, zero_dst_view);
      } 
      else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
      }
    } 
    else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
    }
  }
  // clang-format on

  // Counterpart of upstream's `copy_A_and_extra_info`; keep the two in sync
  // where possible. Unlike upstream, this only loads the extra info (scales
  // and, when present, zeros) into registers and does not touch A, since more
  // of A is preloaded in the main pipeline here.
  //
  // The smem -> register copy is only issued on the first k-block of a k-tile
  // (k_block == 0); every other call does nothing.
  //   partitioned_mma_extra_info: the partitioned smem sources. get<0> is the
  //     scale source and get<2> the zero source (zero modes only). Both are
  //     indexed (_, _, k_block, read_stage).
  //   tiled_copy_and_views: (tiled copy, scale register view[, zero register
  //     view]). The single tiled copy is used for both scales and zeros.
  //   read_stage: index into the last mode of the smem sources.
  template <class... Ts, class... Us>
  CUTLASS_DEVICE void load_extra_info_to_registers(
      cute::tuple<Ts...> const& partitioned_mma_extra_info,
      cute::tuple<Us...> const& tiled_copy_and_views, int k_block,
      int read_stage) {
    // Extra info is refreshed once per k-tile, on its first k-block.
    if (k_block != 0) {
      return;
    }

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      // No scales or zeros to load.
    } else if constexpr (ModeHasScales) {
      // Copy by value: the register views are non-owning, so writes through
      // these copies land in the fragments they were retiled from.
      auto s2r_copy = cute::get<0>(tiled_copy_and_views);
      auto scale_smem = cute::get<0>(partitioned_mma_extra_info);
      auto scale_regs = cute::get<1>(tiled_copy_and_views);
      copy(s2r_copy, scale_smem(_, _, k_block, read_stage),
           scale_regs(_, _, k_block));

      if constexpr (KernelConversionMode ==
                    ConversionMode::ConvertAndScaleWithZero) {
        auto zero_smem = cute::get<2>(partitioned_mma_extra_info);
        auto zero_regs = cute::get<2>(tiled_copy_and_views);
        copy(s2r_copy, zero_smem(_, _, k_block, read_stage),
             zero_regs(_, _, k_block));
      } else if constexpr (KernelConversionMode !=
                           ConversionMode::ConvertAndScale) {
        // ConvertAndScale needs nothing beyond the scales; anything else is
        // unsupported.
        static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                      "Conversion mode not handled in A -> RF path.");
      }
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                    "Conversion mode not handled in A -> RF path.");
    }
  }

  // Similar to upstream, should be kept the same when possible.
  //   the main differences are that `convert_tensor` supports interleaved
  //   layouts and bfloat16 has been optimized. `transform_internal_A` has also
  //   been inlined for code simplicity. In addition, the scratch register
  //   fragments are bound to named locals before slicing (see the lifetime
  //   note below).
  // Utilities to transform A.
  //
  // Converts k-block `k_block` of the loaded A fragment `tCrA_load` into the
  // MMA operand type and writes it into the same k-block of `tCrA_mma`:
  //   - DirectConvert: plain type conversion with vector width `vec_A`.
  //   - ConvertAndScale: upcast to ElementScale (vector width `vec_A`),
  //     multiply element-wise by the scales in get<1>(partitioned_extra_info),
  //     then convert to the MMA element type.
  //   - ConvertAndScaleWithZero: as ConvertAndScale, but after scaling add the
  //     zeros in get<3>(partitioned_extra_info), converted to ElementScale.
  // Only slice 0 of the scale/zero register fragments is read.
  template <class TCrA_load, int VectorWidthA, class TCrA_mma, class... Ts>
  CUTLASS_DEVICE void transform_A_kblock(
      TCrA_load const& tCrA_load, cute::Int<VectorWidthA> vec_A,
      TCrA_mma& tCrA_mma, cute::tuple<Ts...> const& partitioned_extra_info,
      int const k_block) {
    auto in = tCrA_load(_, _, k_block);
    auto out = tCrA_mma(_, _, k_block);

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      convert_tensor<IlvdBlkLayout>(in, out, vec_A);
    } else if constexpr (ModeHasScales) {
      auto tCrS = cute::get<1>(partitioned_extra_info);

      // Lifetime note: slicing a CuTe tensor with `_` returns a non-owning
      // view into the sliced tensor's storage. Slicing the temporary returned
      // by make_fragment_like directly would leave the view pointing into a
      // fragment destroyed at the end of that full-expression. The owning
      // fragment is therefore kept in a named local for the whole function.
      // It is still sliced rather than rebuilt from the slice's shape, so the
      // sliced layout type (which convert_tensor requires to match `in`/`out`)
      // is unchanged.
      auto converted_inputs_frag = make_fragment_like<ElementScale>(tCrA_mma);
      auto converted_inputs = converted_inputs_frag(_, _, k_block);
      auto scales = tCrS(_, _, 0);

      // First, we upcast the inputs to the scale type
      convert_tensor<IlvdBlkLayout>(in, converted_inputs, vec_A);
      // Apply scales and broadcast across inputs, store in converted_inputs

      // We need to cast to nv_bfloat16 for the multiply since
      // `cutlass::bfloat16_t` has an overloaded operator* that upconverts to
      // float, which nvcc will not optimize to using vectorized fma
      // instructions (i.e. hfma.bf16_v2)
      if constexpr (std::is_same_v<ElementScale, cutlass::bfloat16_t>) {
        cute::transform(
            recast<nv_bfloat16>(converted_inputs), recast<nv_bfloat16>(scales),
            recast<nv_bfloat16>(converted_inputs), cute::multiplies{});
      } else {
        cute::transform(converted_inputs, scales, converted_inputs,
                        cute::multiplies{});
      }

      // Apply zeros if required
      if constexpr (KernelConversionMode ==
                    ConversionMode::ConvertAndScaleWithZero) {
        auto tCrZ = cute::get<3>(partitioned_extra_info);
        // Same lifetime rule as converted_inputs: keep the owning fragment
        // alive while the slice is in use.
        auto converted_zeros_frag = make_fragment_like<ElementScale>(tCrZ);
        auto converted_zeros = converted_zeros_frag(_, _, 0);

        convert_tensor<void>(tCrZ(_, _, 0), converted_zeros);
        if constexpr (std::is_same_v<ElementScale, cutlass::bfloat16_t>) {
          cute::transform(recast<nv_bfloat16>(converted_inputs),
                          recast<nv_bfloat16>(converted_zeros),
                          recast<nv_bfloat16>(converted_inputs), cute::plus{});
        } else {
          cute::transform(converted_inputs, converted_zeros, converted_inputs,
                          cute::plus{});
        }
      }

      // Finally, we convert the scaled inputs to the mma type.
      convert_tensor<void>(converted_inputs, out);
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                    "No A data is loaded.");
    }
  }

  // Modified from upstream, should be kept the same when possible
  //   the main difference is that this version supports interleaved converts
  //   (via `cutlass::InterleavedNumericArrayConverter`).
  // Utilities for transforming the A operand prior to issuing tensorcore math.
  //
  // Element-wise converts `in` into `out`, `ConversionVectorWidth` elements
  // at a time. Both tensors share the same static, register-backed layout;
  // this is enforced by the template signature and the static_asserts below.
  //   IlvdBlkLayout: forwarded as the first template argument of
  //     InterleavedNumericArrayConverter. Callers in this file pass `void` for
  //     the scale/zero conversions. NOTE(review): presumably `void` means
  //     "no interleaving"; confirm against the converter's definition.
  //   width: conversion vector width. It defaults to the full cosize of the
  //     layout and must divide it evenly.
  template <typename IlvdBlkLayout, class EngineIn, class EngineOut,
            class TensorLayout,
            int ConversionVectorWidth = cosize_v<TensorLayout>>
  CUTLASS_DEVICE void convert_tensor(
      Tensor<EngineIn, TensorLayout> const& in,
      Tensor<EngineOut, TensorLayout>& out,
      cute::Int<ConversionVectorWidth> width = {}) {
    // This is an element-wise conversion where we expect both tensors to have
    // the same layout. As a result, we can cast as a cutlass array to use the
    // fast numeric converters without worrying about indexing into the layout.
    constexpr int N = cosize_v<TensorLayout>;

    // The inputs must be backed by registers & be statically sized.
    static_assert(is_rmem<EngineIn>::value,
                  "Input tensor for A conversion must come from registers");
    static_assert(is_rmem<EngineOut>::value,
                  "Output tensor for A conversion must come from registers");
    static_assert(is_static_v<TensorLayout>,
                  "Tensor layout for the conversion must be static");
    static_assert(cosize_v<TensorLayout> == size(TensorLayout{}),
                  "Cosize and size of the layout must be equal.");
    static_assert(
        N % ConversionVectorWidth == 0,
        "Conversion vector width must divide cosize of the tensor layout.");

    using SrcType = typename EngineIn::value_type;
    using DstType = typename EngineOut::value_type;

    using SrcArray = cutlass::Array<SrcType, ConversionVectorWidth>;
    using DstArray = cutlass::Array<DstType, ConversionVectorWidth>;

    constexpr cutlass::FloatRoundStyle RoundStyle =
        cutlass::FloatRoundStyle::round_to_nearest;

    using Converter = cutlass::InterleavedNumericArrayConverter<
        IlvdBlkLayout, DstType, SrcType, ConversionVectorWidth, RoundStyle>;

    constexpr int NumIterations = N / ConversionVectorWidth;

    // Unroll fully. Both tensors live in registers, and an array indexed by a
    // value unknown at compile time is placed in local memory. Unrolling makes
    // every `ii` a constant instead of relying on nvcc's unroll heuristic.
    CUTLASS_PRAGMA_UNROLL
    for (int ii = 0; ii < NumIterations; ++ii) {
      SrcArray const* src_array_ptr =
          reinterpret_cast<SrcArray const*>(raw_pointer_cast(in.data())) + ii;
      DstArray* dst_array_ptr =
          reinterpret_cast<DstArray*>(raw_pointer_cast(out.data())) + ii;
      *dst_array_ptr = Converter::convert(*src_array_ptr);
    }
  }
};

}  // namespace machete