custom_all_reduce.cuh 60 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
#pragma once
/*
 * Copyright (C) 2024-2025, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "aiter_hip_common.h"
#include "ck_tile/core.hpp"
#include "communication_asm.h"
#include "hip_float8.h"
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>


namespace aiter
{

  // Number of per-block slots in Signal. The sync helpers index these slots
  // with blockIdx.x, so a launch must not use more blocks than this.
  constexpr int kMaxBlocks = 80;
  // note: we don't want to use atomics for signals because peer atomics are not
  // supported on PCIe links
  //
  // Cross-device synchronization state. start_sync()/end_sync() address `start`
  // and `end` as [blockIdx.x][rank]: each rank writes its own column in every
  // peer's Signal and then polls the columns of its own Signal.
  struct Signal
  {
    alignas(128) uint32_t start[kMaxBlocks][8];
    alignas(128) uint32_t end[kMaxBlocks][8];
    alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flag, one per block (indexed by blockIdx.x in the USE_ROCM path)
  };

#ifdef USE_ROCM
  // Table of up to 8 buffer pointers, one per rank; the kernels in this file
  // read ptrs[i] as the input buffer of rank i.
  struct __align__(16) RankData { const void *ptrs[8]; };
#else
  // Same table for the non-ROCm build; the pointers are additionally declared
  // __restrict__.
  struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
#endif

  // Table of up to 8 Signal pointers, indexed by rank. start_sync()/end_sync()
  // write through signals[threadIdx.x] for threadIdx.x < ngpus, and the 2-stage
  // kernels locate each rank's scratch buffer right behind its Signal.
  struct __align__(16) RankSignals
  {
#ifndef USE_ROCM
    // the non-ROCm build accesses the signals through volatile pointers
    volatile
#endif
        Signal *signals[8];
  };

  // like std::array, but aligned
  // The struct is aligned to alignof(T) * sz, i.e. to the full size of the
  // array whenever alignof(T) == sizeof(T). `type` and `size` expose the
  // element type and element count to generic code such as downcast().
  template <typename T, int sz>
  struct __align__(alignof(T) * sz) array_t
  {
    T data[sz];
    using type = T;
    static constexpr int size = sz;
  };

  // use packed type to maximize memory efficiency
  // goal: generate ld.128 and st.128 instructions
  // Both aliases hold 16 / sizeof(T) elements: P is a 16-byte pack of T, and A
  // is the same number of floats used while accumulating.
  template <typename T>
  struct packed_t
  {
    // the (P)acked type for load/store
    using P = array_t<T, 16 / sizeof(T)>;
    // the (A)ccumulator type for reduction
    using A = array_t<float, 16 / sizeof(T)>;
  };

// Shorthand for a device-side function that is always inlined.
#define DINLINE __device__ __forceinline__

  // Scalar conversions between the storage type and float.
  // upcast_s widens a half to float.
  DINLINE float upcast_s(half val)
  {
    const float widened = __half2float(val);
    return widened;
  }

  // downcast_s narrows a float to T. Only explicit specializations are
  // defined; the primary template is declared but has no body.
  template <typename T>
  DINLINE T downcast_s(float val);

  template <>
  DINLINE half downcast_s<half>(float val)
  {
    const half narrowed = __float2half(val);
    return narrowed;
  }

  // Scalar in-place addition: a += b, returning a reference to a.
  // The half overload goes through the __hadd intrinsic instead of operator+,
  // because that operator is disabled for half/bfloat when building with
  // PyTorch.
  DINLINE half &assign_add(half &a, half b)
  {
    const half sum = __hadd(a, b);
    a = sum;
    return a;
  }
  DINLINE float &assign_add(float &a, float b)
  {
    a += b;
    return a;
  }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
  // bfloat16 counterparts of the scalar helpers: widen to float, narrow from
  // float, and in-place add through the __hadd intrinsic.
  DINLINE float upcast_s(__hip_bfloat16 val)
  {
    const float widened = __bfloat162float(val);
    return widened;
  }

  template <>
  DINLINE __hip_bfloat16 downcast_s<__hip_bfloat16>(float val)
  {
    const __hip_bfloat16 narrowed = __float2bfloat16(val);
    return narrowed;
  }

  DINLINE __hip_bfloat16 &assign_add(__hip_bfloat16 &a, __hip_bfloat16 b)
  {
    const __hip_bfloat16 sum = __hadd(a, b);
    a = sum;
    return a;
  }
#endif

  // Element-wise a += b over two packs of N elements; returns a reference to a.
  template <typename T, int N>
  DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b)
  {
#pragma unroll
    for (int k = 0; k < N; ++k)
    {
      T &dst = a.data[k];
      assign_add(dst, b.data[k]);
    }
    return a;
  }

  // Converts a pack of N elements of T into a pack of N floats. A float pack is
  // returned unchanged; any other element type goes through upcast_s per element.
  template <typename T, int N>
  DINLINE array_t<float, N> upcast(array_t<T, N> val)
  {
    if constexpr (!std::is_same<T, float>::value)
    {
      array_t<float, N> widened;
#pragma unroll
      for (int k = 0; k < N; ++k)
      {
        widened.data[k] = upcast_s(val.data[k]);
      }
      return widened;
    }
    else
    {
      return val;
    }
  }

  // Converts a float pack back to the pack type O (an array_t). When O already
  // holds floats the input is returned as is; otherwise each element is
  // narrowed with downcast_s<O::type>.
  template <typename O>
  DINLINE O downcast(array_t<float, O::size> val)
  {
    using Elem = typename O::type;
    if constexpr (std::is_same<Elem, float>::value)
    {
      return val;
    }
    // Disabled alternative kept for reference: bf16 conversion by taking the
    // upper 16 bits of the float.
    //   else if constexpr (std::is_same<typename O::type, __hip_bfloat16>::value)
    //   {
    //     O out;
    // #pragma unroll
    //     for (int i = 0; i < O::size; i++)
    //     {
    //       union fcvt {
    //           uint32_t i32;
    //           float f32;
    //       } u;
    //       u.f32 = val.data[i];
    //       out.data[i] = __builtin_bit_cast(__hip_bfloat16, uint16_t(u.i32 >> 16));
    //     }
    //     return out;
    //   }
    else
    {
      O narrowed;
#pragma unroll
      for (int k = 0; k < O::size; ++k)
      {
        narrowed.data[k] = downcast_s<Elem>(val.data[k]);
      }
      return narrowed;
    }
  }

  // This function is meant to be used as the first synchronization in the all
  // reduce kernel. Thus, it doesn't need to make any visibility guarantees for
  // prior memory accesses. Note: volatile writes will not be reordered against
  // other volatile writes.
  //
  // sg      - Signal pointers of all ranks.
  // self_sg - this rank's own Signal.
  // rank    - index of this rank; selects the column written in every peer.
  //
  // The first ngpus threads of each block each handle one peer: thread t writes
  // this rank's slot in peer t's Signal and then spins on slot t of the local
  // Signal. The __syncthreads() releases the rest of the block afterwards.
  template <int ngpus>
  DINLINE void start_sync(const RankSignals &sg,
#ifndef USE_ROCM
                          volatile
#endif
                          Signal *self_sg,
                          int rank)
  {
#ifdef USE_ROCM
    // ROCm path: signals are never reset; each barrier uses the next value of a
    // per-block counter and waits for peers to reach at least that value.
    uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
    if (threadIdx.x < ngpus)
    {
      // simultaneously write to the corresponding flag of all ranks.
      // Latency = 1 p2p write
      __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
                              flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
      // wait until every rank has published a value >= flag
      while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
                                    __ATOMIC_RELAXED,
                                    __MEMORY_SCOPE_DEVICE) < flag)
        ;
    }
    __syncthreads();
    // use one thread to update flag
    // (done after the barrier, so every thread has already read the old value)
    if (threadIdx.x == 0)
      self_sg->_flag[blockIdx.x] = flag;
#else
    // Non-ROCm path: 0/1 flags; the `end` slot is cleared here so that the next
    // end_sync() starts from zero.
    if (threadIdx.x < ngpus)
    {
      // reset flag for next time
      self_sg->end[blockIdx.x][threadIdx.x] = 0;
      // simultaneously write to the corresponding flag of all ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
      // wait until we got true from all ranks
      while (!self_sg->start[blockIdx.x][threadIdx.x])
        ;
    }
    __syncthreads();
#endif
  }

  // This function is meant to be used as the second or the final synchronization
  // barrier in the all reduce kernel. If it's the final synchronization barrier,
  // we don't need to make any visibility guarantees for prior memory accesses.
  //
  // sg         - Signal pointers of all ranks.
  // self_sg    - this rank's own Signal.
  // rank       - index of this rank; selects the column written in every peer.
  // final_sync - when true, the release/acquire ordering (ROCm) or the system
  //              fence and trailing __syncthreads() (non-ROCm) are skipped.
  template <int ngpus, bool final_sync = false>
  DINLINE void end_sync(const RankSignals &sg,
#ifndef USE_ROCM
                        volatile
#endif
                        Signal *self_sg,
                        int rank)
  {
#ifdef USE_ROCM
    // all threads of the block finish their prior work before signalling
    __syncthreads();
    // eliminate the case that prior writes are not visible after signals become
    // visible. Note that I did not manage to make this happen through a lot of
    // testing. Might be the case that hardware provides stronger guarantee than
    // the memory model.
    uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
    if (threadIdx.x < ngpus)
    {
      // simultaneously write to the corresponding flag of all ranks.
      // Latency = 1 p2p write
      // (release store unless this is the final barrier)
      __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
                              flag,
                              final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
                              __MEMORY_SCOPE_SYSTEM);
      // wait until every rank has published a value >= flag
      // (acquire load unless this is the final barrier)
      while (
          __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
                                 final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
                                 __MEMORY_SCOPE_DEVICE) < flag)
        ;
    }
    __syncthreads();
    // use one thread to update flag
    if (threadIdx.x == 0)
      self_sg->_flag[blockIdx.x] = flag;
#else
    __syncthreads();
    // eliminate the case that prior writes are not visible after signals become
    // visible. Note that I did not manage to make this happen through a lot of
    // testing. Might be the case that hardware provides stronger guarantee than
    // the memory model.
    if constexpr (!final_sync)
      __threadfence_system();
    if (threadIdx.x < ngpus)
    {
      // reset flag for next time
      // (clears the `start` slot so the next start_sync() begins from zero)
      self_sg->start[blockIdx.x][threadIdx.x] = 0;
      // simultaneously write to the corresponding flag of all ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
      // wait until we got true from all ranks
      while (!self_sg->end[blockIdx.x][threadIdx.x])
        ;
    }
    // the block-wide barrier is only needed when more work follows
    if constexpr (!final_sync)
      __syncthreads();
#endif
  }

  // Sums element idx of the ngpus input buffers. Accumulation happens in the
  // float pack type A, always in buffer order 0..ngpus-1, and the result is
  // narrowed back to the pack type P.
  template <typename P, int ngpus, typename A>
  DINLINE P packed_reduce(const P *ptrs[], int idx)
  {
    A acc = upcast(ptrs[0][idx]);
#pragma unroll
    for (int peer = 1; peer < ngpus; ++peer)
    {
      const A contribution = upcast(ptrs[peer][idx]);
      packed_assign_add(acc, contribution);
    }
    return downcast<P>(acc);
  }

  // One-stage all-reduce: after the opening barrier every thread walks the
  // packed elements with a stride of gridDim.x * blockDim.x, sums the same
  // element from all ngpus input buffers and writes it to `result`.
  // `size` is counted in packed elements (P), not in scalars of T.
  template <typename T, int ngpus>
  __global__ void __launch_bounds__(512, 1)
      cross_device_reduce_1stage_naive(RankData *_dp, RankSignals sg,
#ifndef USE_ROCM
                                 volatile
#endif
                                 Signal *self_sg,
                                 T *__restrict__ result, int rank, int size)
  {
    using P = typename packed_t<T>::P;
    using A = typename packed_t<T>::A;
    // note: we don't reorder the address so the accumulation order is the same
    // for all ranks, ensuring bitwise identical results
    auto dp = *_dp;
    start_sync<ngpus>(sg, self_sg, rank);

    const P **inputs = (const P **)&dp.ptrs[0];
    P *packed_out = (P *)result;
    // do the actual reduction
    for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
         idx += gridDim.x * blockDim.x)
    {
      packed_out[idx] = packed_reduce<P, ngpus, A>(inputs, idx);
    }

    // closing barrier; nothing is consumed afterwards, so no visibility
    // guarantee is requested (final_sync = true)
    end_sync<ngpus, true>(sg, self_sg, rank);
    // // Step-2 consumes data written by peers in step-1, so we need
    // // visibility guarantees from this barrier.
    // end_sync<ngpus>(sg, self_sg, rank);
  }

  // Returns the scratch buffer that sits directly behind a Signal object,
  // viewed as an array of P. In the non-ROCm build the parameter is a pointer
  // to volatile Signal; the qualifier is cast away before the pointer
  // arithmetic.
  template <typename P>
  DINLINE P *get_tmp_buf(
#ifndef USE_ROCM
      volatile
#endif
      Signal *sg)
  {
    Signal *past_signal = ((Signal *)sg) + 1;
    return (P *)past_signal;
  }

  // Two-stage all-reduce (reduce-scatter followed by all-gather).
  // `size` is counted in packed elements (P). The index range is split into
  // ngpus parts of size / ngpus elements; the last rank also takes the
  // remainder. Each rank reduces its own part into the scratch buffer behind
  // its Signal, then all ranks copy every part from the peers' scratch
  // buffers into `result`.
  template <typename T, int ngpus>
  __global__ void __launch_bounds__(512, 1)
      cross_device_reduce_2stage_naive(RankData *_dp, RankSignals sg,
#ifndef USE_ROCM
                                 volatile
#endif
                                 Signal *self_sg,
                                 T *__restrict__ result, int rank, int size)
  {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    using P = typename packed_t<T>::P;
    using A = typename packed_t<T>::A;
    int part = size / ngpus;
    int start = rank * part;
    // the last rank also owns the size % ngpus leftover elements
    int end = rank == ngpus - 1 ? size : start + part;
    int largest_part = part + size % ngpus;
    const P *ptrs[ngpus];
    P *tmps[ngpus];
    // Rotate the rank order so that entry 0 always refers to this rank:
    // ptrs[i] / tmps[i] belong to rank (rank + i) % ngpus.
#pragma unroll
    for (int i = 0; i < ngpus; i++)
    {
      int target = (rank + i) % ngpus;
      ptrs[i] = (const P *)_dp->ptrs[target];
      tmps[i] = get_tmp_buf<P>(sg.signals[target]);
    }
    // this rank's own scratch buffer
    auto tmp_out = tmps[0];
    start_sync<ngpus>(sg, self_sg, rank);
    // stage 1: reduce scatter
    for (int idx = start + tid; idx < end; idx += stride)
    {
      tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
    }
    // non-final barrier: stage 2 reads what the peers wrote in stage 1
    end_sync<ngpus>(sg, self_sg, rank);

    // stage 2: allgather. Note: it's important to match the tid between
    // the two stages, because visibility across devices is only guaranteed
    // between threads that have the same tid. If thread i computes the sum of
    // start + i in the first stage, then thread i also gathers start + i from all
    // ranks.
    for (int idx = tid; idx < largest_part; idx += stride)
    {
#pragma unroll
      for (int i = 0; i < ngpus; i++)
      {
        int gather_from_rank = ((rank + i) % ngpus);
        // only the last rank produced the elements in [part, largest_part)
        if (gather_from_rank == ngpus - 1 || idx < part)
        {
          int dst_idx = gather_from_rank * part + idx;
          ((P *)result)[dst_idx] = tmps[i][idx];
        }
      }
    }
  }

// Thread count per block that the kernels below assume: they divide it by
// ngpus to get the number of threads assigned to each rank (tnum_gpu).
#define THREAD_NUM 512

// Toggle whether fused allreduce+rmsnorm keeps per-element rms input in float
// before the final cast to output dtype.
// Defaults to 1 unless the build defines it beforehand.
#ifndef AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32
#define AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32 1
#endif

  // One-stage all-reduce that spreads the peer loads over the whole block.
  //
  // The block is split into ngpus groups of tnum_gpu = THREAD_NUM / ngpus
  // threads. Group g loads one packed element per thread from rank g's input
  // into shared memory; group 0 then sums the ngpus staged copies in float, in
  // rank order 0..ngpus-1, and writes the packed result.
  //
  // _dp     - table of the ngpus input buffers.
  // sg      - Signal pointers of all ranks (used by the opening barrier).
  // self_sg - this rank's own Signal.
  // result  - output buffer, addressed as an array of packs (P).
  // rank    - index of this rank.
  // size    - number of packed elements (P), not scalars of T.
  //
  // Changes relative to the previous version:
  //  * The main loop now iterates over a block-uniform base index. Before, the
  //    trip count depended on lane_id, so when `size` was not a multiple of
  //    tnum_gpu some threads left the loop while others still executed
  //    __syncthreads(), i.e. a barrier in divergent control flow.
  //  * Threads whose group index is >= ngpus (possible when THREAD_NUM is not
  //    a multiple of ngpus) no longer load from a pointer slot that is not one
  //    of the ngpus inputs, nor store past the end of tmp_smem; they only take
  //    part in the barriers.
  //
  // NOTE(review): there is no closing cross-device barrier, so a peer may
  // still be reading this rank's input when the kernel returns — confirm the
  // caller does not reuse the input buffer too early.
  template <typename T, int ngpus>
  __global__ void __launch_bounds__(512, 1)
      cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
#ifndef USE_ROCM
                                 volatile
#endif
                                 Signal *self_sg,
                                 T *__restrict__ result, int rank, int size)
  {
    using P = typename packed_t<T>::P;
    using A = typename packed_t<T>::A;
    constexpr int pack_size = packed_t<T>::P::size;
    constexpr int tnum_gpu = THREAD_NUM / ngpus;
    __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size];
    // note: we don't reorder the address so the accumulation order is the same
    // for all ranks, ensuring bitwise identical results
    auto dp = *_dp;

    // each group of tnum_gpu threads loads the data of one gpu
    int warp_id = threadIdx.x / tnum_gpu;
    int lane_id = threadIdx.x % tnum_gpu;
    // false only for surplus threads when THREAD_NUM % ngpus != 0
    const bool has_peer = warp_id < ngpus;
    start_sync<ngpus>(sg, self_sg, rank);
    // do the actual reduction; `base` is identical for every thread of the
    // block so that all threads reach both barriers in every iteration
    for (int base = blockIdx.x * tnum_gpu; base < size;
         base += gridDim.x * tnum_gpu)
    {
      const int idx = base + lane_id;
      const bool in_range = idx < size;
      if (has_peer && in_range)
      {
        *(reinterpret_cast<P*>(&tmp_smem[0]) + threadIdx.x) = ((const P**)&dp.ptrs[0])[warp_id][idx];
      }
      __syncthreads();
      if (warp_id == 0 && in_range)
      {
        A add_reg;
#pragma unroll
        for (int i = 0; i < pack_size; ++i)
        {
          add_reg.data[i] = ck_tile::type_convert<float>(tmp_smem[threadIdx.x * pack_size + i]);
        }
        // distance in tmp_smem between the staged data of consecutive ranks
        constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size;
#pragma unroll
        for (int i = 1; i < ngpus; ++i)
        {
#pragma unroll
          for (int j = 0; j < pack_size; ++j)
          {
            add_reg.data[j] += ck_tile::type_convert<float>(tmp_smem[smem_gpu_loop_stride * i + threadIdx.x * pack_size + j]);
          }
        }
        P write_reg;
#pragma unroll
        for (int i = 0; i < pack_size; ++i)
        {
          write_reg.data[i] = ck_tile::type_convert<T>(add_reg.data[i]);
        }
        ((P *)result)[idx] = write_reg;
      }
      // keep tmp_smem intact until group 0 has finished reading it
      __syncthreads();
    }
    // maybe do not need device sync
    // end_sync<ngpus, true>(sg, self_sg, rank);
  }

  // Two-stage all-reduce (reduce-scatter followed by all-gather) that spreads
  // the peer accesses over the whole block.
  //
  // The block is split into ngpus groups of tnum_gpu = THREAD_NUM / ngpus
  // threads. The index range is split into ngpus parts of size / ngpus packed
  // elements; the last rank also takes the size % ngpus remainder.
  //  stage 1: group g stages one element per thread from input g (rotated so
  //           that group 0 reads this rank's own buffer) in shared memory;
  //           group 0 sums the copies in float and stores the result in the
  //           scratch buffer behind this rank's Signal.
  //  stage 2: group g copies the part owned by rank (rank + g) % ngpus from
  //           that rank's scratch buffer into `result`.
  //
  // _dp     - table of the ngpus input buffers.
  // sg      - Signal pointers of all ranks; the scratch buffers follow them.
  // self_sg - this rank's own Signal.
  // result  - output buffer, addressed as an array of packs (P).
  // rank    - index of this rank.
  // size    - number of packed elements (P), not scalars of T.
  //
  // Changes relative to the previous version:
  //  * Stage 2 now copies [0, part) from every rank except the last one, and
  //    [0, largest_part) only from the last rank — the same rule as in
  //    cross_device_reduce_2stage_naive. Previously all ranks were copied up to
  //    largest_part, so with size % ngpus != 0 elements a peer never wrote in
  //    this call were copied into the start of the next rank's slice of
  //    `result`, racing with the correct copy.
  //  * The stage-1 loop iterates over a block-uniform base index, so every
  //    thread reaches both __syncthreads() calls in every iteration even when
  //    the part size is not a multiple of tnum_gpu.
  //  * Threads whose group index is >= ngpus (possible when THREAD_NUM is not
  //    a multiple of ngpus) no longer index past ptrs[]/tmps[]/tmp_smem.
  template <typename T, int ngpus>
  __global__ void __launch_bounds__(512, 1)
      cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
#ifndef USE_ROCM
                                 volatile
#endif
                                 Signal *self_sg,
                                 T *__restrict__ result, int rank, int size)
  {
    constexpr int pack_size = packed_t<T>::P::size;
    constexpr int tnum_gpu = THREAD_NUM / ngpus;
    using P = typename packed_t<T>::P;
    using A = typename packed_t<T>::A;
    __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size];
    int warp_id = threadIdx.x / tnum_gpu;
    int lane_id = threadIdx.x % tnum_gpu;
    // false only for surplus threads when THREAD_NUM % ngpus != 0
    const bool has_peer = warp_id < ngpus;
    int tid = blockIdx.x * tnum_gpu + lane_id;
    int stride = gridDim.x * tnum_gpu;
    int part = size / ngpus;
    int start = rank * part;
    // the last rank also owns the size % ngpus leftover elements
    int end = rank == ngpus - 1 ? size : start + part;
    int largest_part = part + size % ngpus;
    const P *ptrs[ngpus];
    P *tmps[ngpus];
    // Rotate the rank order so that entry 0 always refers to this rank:
    // ptrs[i] / tmps[i] belong to rank (rank + i) % ngpus.
#pragma unroll
    for (int i = 0; i < ngpus; i++)
    {
      int target = (rank + i) % ngpus;
      ptrs[i] = (const P *)_dp->ptrs[target];
      tmps[i] = get_tmp_buf<P>(sg.signals[target]);
    }
    auto tmp_out = tmps[0];
    start_sync<ngpus>(sg, self_sg, rank);
    // stage 1: reduce scatter. `base` is identical for every thread of the
    // block so that all threads reach both barriers in every iteration.
    for (int base = start + blockIdx.x * tnum_gpu; base < end; base += stride)
    {
      const int idx = base + lane_id;
      const bool in_range = idx < end;
      if (has_peer && in_range)
      {
        *(reinterpret_cast<P*>(&tmp_smem[0]) + threadIdx.x) = ptrs[warp_id][idx];
      }
      __syncthreads();
      // the first thread group does the accumulation
      if (warp_id == 0 && in_range)
      {
        A add_reg;
#pragma unroll
        for (int i = 0; i < pack_size; ++i)
        {
          add_reg.data[i] = ck_tile::type_convert<float>(tmp_smem[pack_size * threadIdx.x + i]);
        }
        // distance in tmp_smem between the staged data of consecutive groups
        constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size;
#pragma unroll
        for (int i = 1; i < ngpus; ++i)
        {
#pragma unroll
          for (int j = 0; j < pack_size; ++j)
          {
            add_reg.data[j] += ck_tile::type_convert<float>(tmp_smem[i * smem_gpu_loop_stride + pack_size * threadIdx.x + j]);
          }
        }
        P write_reg;
#pragma unroll
        for (int i = 0; i < pack_size; ++i)
        {
          write_reg.data[i] = ck_tile::type_convert<T>(add_reg.data[i]);
        }
        tmp_out[idx - start] = write_reg;
      }
      // keep tmp_smem intact until group 0 has finished reading it
      __syncthreads();
    }
    // non-final barrier: stage 2 reads what the peers wrote in stage 1
    end_sync<ngpus>(sg, self_sg, rank);

    // stage 2: allgather. Each thread group copies the part of one rank.
    if (has_peer)
    {
      const int gather_from_rank = (warp_id + rank) % ngpus;
      // only the last rank produced the elements in [part, largest_part)
      const int gather_len = gather_from_rank == ngpus - 1 ? largest_part : part;
      for (int idx = tid; idx < gather_len; idx += stride)
      {
        int dst_idx = gather_from_rank * part + idx;
        ((P *)result)[dst_idx] = tmps[warp_id][idx];
      }
    }
  }

  /*
   * Naive (scalar) all-gather, for element counts that cannot be processed in
   * packed form, e.g. input(1345,).
   *
   * The block is split into ngpus groups of THREAD_NUM / ngpus threads; group g
   * copies rank g's `size` elements to result[g * size .. g * size + size), so
   * the output is the rank-ordered concatenation of all inputs.
   *
   * _dp     - table of the ngpus input buffers.
   * sg      - Signal pointers of all ranks (used by the opening barrier).
   * self_sg - this rank's own Signal.
   * result  - output buffer; written up to index ngpus * size - 1.
   * rank    - index of this rank.
   * size    - number of T elements per rank.
   *
   * Change relative to the previous version: threads whose group index is
   * >= ngpus (possible when THREAD_NUM is not a multiple of ngpus) now skip the
   * copy instead of indexing past ptrs[] and writing beyond the gathered range.
   *
   * NOTE(review): there is no closing cross-device barrier — confirm the caller
   * does not reuse the input buffer while peers may still be reading it.
   * */
  template <typename T, int ngpus>
  __global__ void __launch_bounds__(512, 1) allgather_naive(
      RankData* _dp,
      RankSignals sg,
      Signal* self_sg,
      T* __restrict__ result,
      int rank,
      int size
  )
  {
    constexpr int tnum_gpu = THREAD_NUM / ngpus;
    int warp_id = threadIdx.x / tnum_gpu;
    int lane_id = threadIdx.x % tnum_gpu;
    int tid = blockIdx.x * tnum_gpu + lane_id;
    int stride = gridDim.x * tnum_gpu;
    const T* ptrs[ngpus];

#pragma unroll
    for (int i = 0; i < ngpus; ++i)
    {
      ptrs[i] = (const T*)_dp->ptrs[i];
    }
    start_sync<ngpus>(sg, self_sg, rank);

    // surplus threads (warp_id >= ngpus) have no source rank to copy from
    if (warp_id < ngpus)
    {
      for (int idx = tid; idx < size; idx += stride)
      {
        int write_idx = warp_id * size + idx;
        result[write_idx] = ptrs[warp_id][idx];
      }
    }
  }

  // Vectorized all-gather: same layout as allgather_naive, but every access
  // moves one 16-byte pack (P) instead of one scalar.
  //
  // The block is split into ngpus groups of THREAD_NUM / ngpus threads; group g
  // copies rank g's `size` packs to packs [g * size, g * size + size) of
  // `result`.
  //
  // _dp     - table of the ngpus input buffers.
  // sg      - Signal pointers of all ranks (used by the opening barrier).
  // self_sg - this rank's own Signal.
  // result  - output buffer, addressed as an array of packs (P).
  // rank    - index of this rank.
  // size    - number of packed elements (P) per rank.
  //
  // Change relative to the previous version: threads whose group index is
  // >= ngpus (possible when THREAD_NUM is not a multiple of ngpus) now skip the
  // copy instead of indexing past ptrs[] and writing beyond the gathered range.
  //
  // NOTE(review): there is no closing cross-device barrier — confirm the caller
  // does not reuse the input buffer while peers may still be reading it.
  template <typename T, int ngpus>
  __global__ void __launch_bounds__(512, 1) allgather_vec(
      RankData* _dp,
      RankSignals sg,
      Signal* self_sg,
      T* __restrict__ result,
      int rank,
      int size
  )
  {
    constexpr int tnum_gpu = THREAD_NUM / ngpus;
    using P = typename packed_t<T>::P;
    int warp_id = threadIdx.x / tnum_gpu;
    int lane_id = threadIdx.x % tnum_gpu;
    int tid = blockIdx.x * tnum_gpu + lane_id;
    int stride = gridDim.x * tnum_gpu;
    const P* ptrs[ngpus];

#pragma unroll
    for (int i = 0; i < ngpus; ++i)
    {
      ptrs[i] = (const P*)_dp->ptrs[i];
    }
    start_sync<ngpus>(sg, self_sg, rank);

    // surplus threads (warp_id >= ngpus) have no source rank to copy from
    if (warp_id < ngpus)
    {
      for (int idx = tid; idx < size; idx += stride)
      {
        int write_idx = warp_id * size + idx;
        *(reinterpret_cast<P*>(&result[0]) + write_idx) = ptrs[warp_id][idx];
      }
    }
  }

  // fp8 quant all-reduce code start
  // Compile-time predicate: value is true only for T = half.
  template <typename T>
  struct Fp16Filter
  {
    static constexpr bool value = false;
  };

  template <>
  struct Fp16Filter<half>
  {
    static constexpr bool value = true;
  };

  // Compile-time predicate: value is true only for T = __hip_bfloat16.
  template <typename T>
  struct Bf16Filter
  {
    static constexpr bool value = false;
  };

  template <>
  struct Bf16Filter<__hip_bfloat16>
  {
    static constexpr bool value = true;
  };

  // dtypes only support half and bf16 now
  // Each macro expands to a defaulted, unnamed template parameter built with
  // std::enable_if, which is only well-formed when T is the matching type.
  // It must therefore be used inside a template parameter list in which T is
  // already declared.
#define FP16_FILTER \
  typename std::enable_if<Fp16Filter<T>::value, void>::type* = nullptr

#define BF16_FILTER \
  typename std::enable_if<Bf16Filter<T>::value, void>::type* = nullptr

  // Folds the `size` elements of one pack into a single value with the binary
  // functor, left to right, starting from element 0.
  template <template <typename> class functor, typename T, int size>
  DINLINE T packReduce(array_t<T, size> pack)
  {
    functor<T> combine{};
    T acc = pack.data[0];
#pragma unroll
    for (int k = 1; k < size; ++k)
    {
      acc = combine(acc, pack.data[k]);
    }
    return acc;
  }

  // Applies the binary functor element-wise to two packs and returns the
  // resulting pack.
  template <template<typename> class functor, typename T, int size>
  DINLINE array_t<T, size> packOp(array_t<T, size> a, array_t<T, size> b)
  {
    functor<T> combine{};
    array_t<T, size> combined;
#pragma unroll
    for (int k = 0; k < size; ++k)
    {
      combined.data[k] = combine(a.data[k], b.data[k]);
    }
    return combined;
  }

  // Binary addition functor. The generic version uses operator+; the half and
  // bfloat16 specializations add in float and convert the sum back.
  template <typename T>
  struct AddFunctor
  {
    DINLINE T operator() (T a, T b)
    {
      const T sum = a + b;
      return sum;
    }
  };

  template <>
  struct AddFunctor<half>
  {
    DINLINE half operator() (half a, half b)
    {
      const float lhs = ck_tile::type_convert<float>(a);
      const float rhs = ck_tile::type_convert<float>(b);
      const float sum = lhs + rhs;
      return ck_tile::type_convert<half>(sum);
    }
  };

  template <>
  struct AddFunctor<__hip_bfloat16>
  {
    DINLINE __hip_bfloat16 operator() (__hip_bfloat16 a, __hip_bfloat16 b)
    {
      const float lhs = ck_tile::type_convert<float>(a);
      const float rhs = ck_tile::type_convert<float>(b);
      const float sum = lhs + rhs;
      return ck_tile::type_convert<__hip_bfloat16>(sum);
    }
  };

  // Binary maximum functor; forwards to the max() overload for T.
  template <typename T>
  struct MaxFunctor
  {
    DINLINE T operator() (T a, T b)
    {
      const T larger = max(a, b);
      return larger;
    }
  };

  /*
   * Binary functor returning max(|a|, |b|).
   * The zero used for the sign test is produced with ck_tile::type_convert
   * rather than static_cast (this resolves the earlier todo about
   * static_cast possibly being unsafe).
   * */
  template <typename T>
  struct AbsMaxFunctor
  {
    DINLINE T operator() (T a, T b)
    {
      T zero_t = ck_tile::type_convert<T>(0.0f);
      // absolute value: keep positive inputs, negate everything else
      if (!(a > zero_t))
      {
        a = zero_t - a;
      }
      if (!(b > zero_t))
      {
        b = zero_t - b;
      }
      return max(a, b);
    }
  };

  // Butterfly reduction across `reduce_range` neighbouring lanes: in each step
  // every lane combines its value with the lane whose index differs by
  // `offset` (via __shfl_xor), halving the offset until it reaches zero.
  // Every participating lane ends up with the combined value.
  template <template <typename> class functor, typename T, int reduce_range>
  DINLINE T warpReduce(T val)
  {
    functor<T> combine{};
#pragma unroll
    for (int offset = reduce_range / 2; offset > 0; offset >>= 1)
    {
      const T partner = __shfl_xor(val, offset, reduce_range);
      val = combine(val, partner);
    }
    return val;
  }

  // the following code only supports bf16 and fp16
  // Quantizes one value to fp8: both the input and the scale (parameter
  // `scale_functor`) are converted to float, and the quotient input / scale is
  // converted to hip_fp8.
  template <typename T>
  DINLINE hip_fp8 elementQuant(T input, T scale_functor)
  {
    const float value = ck_tile::type_convert<float>(input);
    const float scale = ck_tile::type_convert<float>(scale_functor);
    return hip_fp8(value / scale);
  }

  // Dequantizes one fp8 value: multiplies it by the scale (parameter
  // `scale_functor`) in float and converts the product to T.
  template <typename T>
  DINLINE T elementDequant(hip_fp8 input, T scale_functor)
  {
    const float scale = ck_tile::type_convert<float>(scale_functor);
    const float restored = float(input) * scale;
    return ck_tile::type_convert<T>(restored);
  }

  // Quantizes every element of a pack to fp8 with one shared scale.
  template <typename T, int pack_size>
  DINLINE array_t<hip_fp8, pack_size> packQuant(array_t<T, pack_size> inp_pack, T scale_functor)
  {
    array_t<hip_fp8, pack_size> quantized;
#pragma unroll
    for (int k = 0; k < pack_size; ++k)
    {
      quantized.data[k] = elementQuant<T>(inp_pack.data[k], scale_functor);
    }
    return quantized;
  }

  // Dequantizes every element of an fp8 pack to T with one shared scale.
  template <typename T, int pack_size>
  DINLINE array_t<T, pack_size> packDequant(array_t<hip_fp8, pack_size> inp_pack, T scale_functor)
  {
    array_t<T, pack_size> restored;
#pragma unroll
    for (int k = 0; k < pack_size; ++k)
    {
      restored.data[k] = elementDequant<T>(inp_pack.data[k], scale_functor);
    }
    return restored;
  }

  // Converts a pack of T to a pack of float, element by element, using
  // ck_tile::type_convert.
  template <typename T, int pack_size>
  DINLINE array_t<float, pack_size> packUpcast(array_t<T, pack_size> inp)
  {
    array_t<float, pack_size> widened;
#pragma unroll
    for (int k = 0; k < pack_size; ++k)
    {
      widened.data[k] = ck_tile::type_convert<float>(inp.data[k]);
    }
    return widened;
  }

  // Converts a pack of float to a pack of T, element by element, using
  // ck_tile::type_convert.
  template <typename T, int pack_size>
  DINLINE array_t<T, pack_size> packDowncast(array_t<float, pack_size> inp)
  {
    array_t<T, pack_size> narrowed;
#pragma unroll
    for (int k = 0; k < pack_size; ++k)
    {
      narrowed.data[k] = ck_tile::type_convert<T>(inp.data[k]);
    }
    return narrowed;
  }


  // Sums pack `index` of the ngpus input buffers. The packs are widened to
  // float, accumulated in buffer order 0..ngpus-1, and the sum is converted
  // back to a pack of T.
  template <typename T, int pack_size, int ngpus>
  DINLINE array_t<T, pack_size> multiGPUPackReduce(const array_t<T, pack_size> *ptrs[ngpus], int index)
  {
    array_t<float, pack_size> acc = packUpcast<T, pack_size>(ptrs[0][index]);
#pragma unroll
    for (int peer = 1; peer < ngpus; ++peer)
    {
      const array_t<float, pack_size> contribution = packUpcast<T, pack_size>(ptrs[peer][index]);
#pragma unroll
      for (int k = 0; k < pack_size; ++k)
      {
        acc.data[k] += contribution.data[k];
      }
    }
    return packDowncast<T, pack_size>(acc);
  }

  // all-reduce kernel that quantizes bf16/fp16 data to fp8 for the all-gather stage
  // TODO: too slow, needs to be optimized
  /**
   * Two-stage all-reduce whose second stage exchanges fp8-quantized data.
   *
   * Stage 1 (reduce-scatter): each rank sums its own slice [start, end) of packs
   * over all peers, writes the sum into `result`, then stores an fp8 copy of it
   * plus one scale per group of `quant_scale / pack_size` threads into its own
   * temporary buffer.
   * Stage 2 (all-gather): each rank reads the fp8 packs and scales from the
   * peers' temporary buffers, dequantizes them and writes them into the
   * matching slice of `result`.
   *
   * `size` counts packs of `pack_size` elements (all buffers are indexed as packs).
   * Temporary buffer layout per rank: `largest_part` fp8 packs, then the scales.
   */
  template <typename T, int quant_scale, int pack_size, int ngpus, FP16_FILTER>
  __global__ __forceinline__ void __launch_bounds__(512, 1) allReduceQuantFp8(RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size)
  {
    float FP8_UPBOUND = ck_tile::type_convert<float>(ck_tile::numeric<ck_tile::fp8_t>::max());
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    using inp_pack = array_t<T, pack_size>;
    using fp8_pack = array_t<hip_fp8, pack_size>;
    int part = size / ngpus;
    int start = rank * part;
    // the last rank also takes the remainder packs
    int end = rank == ngpus - 1 ? size : start + part;
    int largest_part = part + size % ngpus;
    const inp_pack *ptrs[ngpus];
    fp8_pack *tmps[ngpus];
#pragma unroll
    for (int i = 0; i < ngpus; i++)
    {
      // index 0 is this rank, index i is the rank i steps ahead (mod ngpus)
      int target = (rank + i) % ngpus;
      ptrs[i] = (const inp_pack *)_dp->ptrs[target];
      tmps[i] = get_tmp_buf<fp8_pack>(sg.signals[target]);
    }
    auto tmp_out = tmps[0];
    start_sync<ngpus>(sg, self_sg, rank);
    // stage 1: reduce scatter
    for (int idx = start + tid; idx < end; idx += stride)
    {
      inp_pack half8_reg;
      half8_reg = multiGPUPackReduce<T, pack_size, ngpus>(ptrs, idx);
      ((inp_pack *)result)[idx] = half8_reg;
      // quant: scale = (max |x| over the thread group) / (largest fp8 value)
      T thread_max = packReduce<AbsMaxFunctor, T, pack_size>(half8_reg);
      thread_max = warpReduce<MaxFunctor, T, quant_scale / pack_size>(thread_max);
      T scale_factor = ck_tile::type_convert<T>(ck_tile::type_convert<float>(thread_max) / FP8_UPBOUND);
      tmp_out[idx - start] = packQuant<T, pack_size>(half8_reg, scale_factor);
      if (threadIdx.x % (quant_scale / pack_size) == 0)
      {
        // The scales start after `largest_part` packs, not after `part`: the last
        // rank writes `part + size % ngpus` quantized packs, so an offset of
        // `part` would let its tail packs overlap the scale region.
        *(reinterpret_cast<T*>(&tmp_out[largest_part]) + (idx - start) / (quant_scale / pack_size)) = scale_factor;
      }
    }
    end_sync<ngpus>(sg, self_sg, rank);

    // stage 2: all-gather
    for (int idx = tid; idx < largest_part; idx += stride)
    {
#pragma unroll
      for (int i = 1; i < ngpus; i++)
      {
        int gather_from_rank = ((rank + i) % ngpus);
        // only the last rank owns packs in [part, largest_part)
        if (gather_from_rank == ngpus - 1 || idx < part)
        {
          // dequant
          T scale_factor;
          int factor_stride = quant_scale / pack_size;
          if (threadIdx.x % factor_stride == 0)
          {
            // must use the same scale offset as the writer in stage 1
            scale_factor = *(reinterpret_cast<T*>(&tmps[i][largest_part]) + idx / factor_stride);
          }
          // the first thread of each group loads the scale and broadcasts it
          scale_factor = __shfl(scale_factor, (threadIdx.x / factor_stride) * factor_stride);
          inp_pack half8_reg = packDequant<T, pack_size>(tmps[i][idx], scale_factor);
          int dst_idx = gather_from_rank * part + idx;
          ((inp_pack *)result)[dst_idx] = half8_reg;
        }
      }
    }
  }

  // fused allreduce rmsnorm first step
  /**
   * Reduce this rank's slice of the input over all GPUs and store the reduced
   * slice into every GPU's temporary buffer.
   *
   * The block is split into `ngpus` groups of `tnum_gpu` threads; `warp_id` is
   * the group index, not a hardware wavefront id. Group g loads from GPU g's
   * input, group 0 sums the ngpus partial packs through shared memory, and
   * group g then stores the sum into GPU g's temporary buffer.
   *
   * `size` counts elements of T. All threads of a block must run the same
   * number of loop iterations because the loop body contains barriers.
   * NOTE(review): shared memory is indexed by threadIdx.x, so this presumably
   * expects blockDim.x == tnum_gpu * ngpus - confirm against the launch site.
   */
  template <typename T, int ngpus>
  __global__ void __launch_bounds__(512, 1) reduce_scatter_cross_device_store(
      RankData* _dp,
      RankSignals sg,
      Signal* self_sg,
      int rank,
      int size
  )
  {
    constexpr int pack_size = packed_t<T>::P::size;
    constexpr int tnum_gpu = THREAD_NUM / ngpus;
    using P = typename packed_t<T>::P;
    using A = typename packed_t<T>::A;
    __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size];
    int warp_id = threadIdx.x / tnum_gpu;
    int lane_id = threadIdx.x % tnum_gpu;
    int tid = blockIdx.x * tnum_gpu + lane_id;
    const P* ptrs[ngpus];
    P* tmps[ngpus];
#pragma unroll
    for (int i = 0; i < ngpus; ++i)
    {
      ptrs[i] = (const P*)_dp->ptrs[i];
      tmps[i] = get_tmp_buf<P>(sg.signals[i]);
    }
    start_sync<ngpus>(sg, self_sg, rank);

    // the case of fused_allreduce_rmsnorm does not need thread level boundary check
    int part = size / (pack_size * ngpus);
    for (int idx = tid; idx < part; idx += gridDim.x * tnum_gpu)
    {
      // cross device read by all warp
      P input_reg = ptrs[warp_id][rank * part + idx];
      *(reinterpret_cast<P*>(&tmp_smem[0]) + threadIdx.x) = input_reg;
      __syncthreads();
      // calculate and save in first warp
      if (warp_id == 0)
      {
        A add_reg;
#pragma unroll
        for (int i = 0; i < pack_size; ++i)
        {
          add_reg.data[i] = ck_tile::type_convert<float>(tmp_smem[pack_size * threadIdx.x + i]);
        }
#pragma unroll
        for (int i = 1; i < ngpus; ++i)
        {
#pragma unroll
          for (int j = 0; j < pack_size; ++j)
          {
            add_reg.data[j] += ck_tile::type_convert<float>(tmp_smem[i * pack_size * tnum_gpu + pack_size * threadIdx.x + j]);
          }
        }
        P add_rslt;
#pragma unroll
        for (int i = 0; i < pack_size; ++i)
        {
          add_rslt.data[i] = ck_tile::type_convert<T>(add_reg.data[i]);
        }
        // the sum overwrites group 0's own slot, which every group reads below
        *(reinterpret_cast<P*>(&tmp_smem[0]) + lane_id) = add_rslt;
      }
      __syncthreads();

      // cross device store
      P rslt = *(reinterpret_cast<P*>(&tmp_smem[0]) + lane_id);
      tmps[warp_id][rank * part + idx] = rslt;
      // Every group must have read its result before group 0 overwrites the
      // same shared-memory slot with the next iteration's input.
      __syncthreads();
    }
    end_sync<ngpus, true>(sg, self_sg, rank);
  }

  /**
   * Block-wide sum of the first `reduce_range` floats of `smem_addr`; the total
   * ends up in smem_addr[0].
   *
   * Must be called by every thread of the block (it contains __syncthreads()),
   * and the caller must have synchronized after filling the buffer, because the
   * first step already reads entries written by other threads.
   * The final stage reads smem_addr[threadIdx.x + 32] for threadIdx.x < 32, so
   * at least 64 entries are required; the halving assumes reduce_range is a
   * power of two.
   */
  template <int reduce_range>
  DINLINE void smemReduceSum(float* smem_addr)
  {
    // a warp executes the same instruction
#pragma unroll
    // tree reduction with a barrier per step while more than 64 partial sums remain
    for (int stride = reduce_range / 2; stride > 32; stride >>= 1)
    {
      if (threadIdx.x < stride)
      {
        smem_addr[threadIdx.x] += smem_addr[threadIdx.x + stride];
      }
      __syncthreads();
    }
    // Last 64 -> 1 steps run without barriers among the first 32 threads; the
    // volatile pointer forces each step to go through shared memory.
    // NOTE(review): this relies on those 32 threads executing in lockstep - confirm for the target hardware.
    volatile float* v_smem = &smem_addr[0];
    if (threadIdx.x < 32)
    {
      v_smem[threadIdx.x] += v_smem[threadIdx.x + 32];
      v_smem[threadIdx.x] += v_smem[threadIdx.x + 16];
      v_smem[threadIdx.x] += v_smem[threadIdx.x + 8];
      v_smem[threadIdx.x] += v_smem[threadIdx.x + 4];
      v_smem[threadIdx.x] += v_smem[threadIdx.x + 2];
      v_smem[threadIdx.x] += v_smem[threadIdx.x + 1];
    }
    // make smem_addr[0] visible to the whole block before returning
    __syncthreads();
  }

  /*
   * the n dim of the input must be divisible by 4096 for dtype bf16
   * and by 2048 for dtype fp32
   * */
  /**
   * Second step of the fused all-reduce + RMSNorm: one block per row, without
   * bounds checks.
   *
   * For each row the block loads the reduced data from this rank's temporary
   * buffer, adds `residual_inp`, writes the sum to `residual_out` and the
   * RMS-normalized, weight-scaled value to `results`.
   * Each thread handles `n_loop` packs per row, so a row is read as
   * n_loop * blockDim.x packs.
   *
   * @param sg    signals of all ranks; sg.signals[rank] locates the temporary buffer
   * @param eps   added to the mean of squares before rsqrtf
   * @param m     number of rows
   * @param n     number of elements per row (divisor of the sum of squares)
   */
  template <typename T, int tnum, int n_loop>
  __global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm_naive(
      RankSignals sg,
      T* __restrict__ residual_inp,
      T* __restrict__ residual_out,
      T* __restrict__ results,
      T* __restrict__ weight,
      float eps,
      int rank,
      int m,
      int n
  )
  {
    constexpr int pack_size = packed_t<T>::P::size;
    using P = typename packed_t<T>::P;
    using A = typename packed_t<T>::A;
    __shared__ float smem[tnum];
    P* tmps = get_tmp_buf<P>(sg.signals[rank]);

    for (int bid = blockIdx.x; bid < m; bid += gridDim.x)
    {
      float square_sum = 0.0f;
#if AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32
      A rmsnorm_inp[n_loop];
#else
      P rmsnorm_inp[n_loop];
#endif
      P w_arr[n_loop];
#pragma unroll
      for (int n_iter = 0; n_iter < n_loop; ++n_iter)
      {
        int read_idx = bid * n_loop * blockDim.x + n_iter * blockDim.x + threadIdx.x;
        P reduce_out_pack = tmps[read_idx];
        P residual_inp_pack = *(reinterpret_cast<P*>(residual_inp) + read_idx);
        w_arr[n_iter] = *(reinterpret_cast<P*>(weight) + n_iter * blockDim.x + threadIdx.x);
        A reduce_pack;
#pragma unroll
        for (int i = 0; i < pack_size; ++i)
        {
          float res_inp = ck_tile::type_convert<float>(residual_inp_pack.data[i]);
          float ar_out = ck_tile::type_convert<float>(reduce_out_pack.data[i]);
          float rms_inp = res_inp + ar_out;
#if AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32
          rmsnorm_inp[n_iter].data[i] = rms_inp;
#else
          rmsnorm_inp[n_iter].data[i] = ck_tile::type_convert<T>(rms_inp);
#endif
          reduce_pack.data[i] = rms_inp * rms_inp;
        }
        square_sum += packReduce<AddFunctor, float, pack_size>(reduce_pack);
      }
      smem[threadIdx.x] = square_sum;
      __syncthreads();
      smemReduceSum<tnum>(&smem[0]);
      square_sum = smem[0];
      // All threads must have read smem[0] before thread 0 overwrites it with
      // its partial sum for the next row.
      __syncthreads();
      float denom = rsqrtf(square_sum / n + eps);
#pragma unroll
      for (int n_iter = 0; n_iter < n_loop; ++n_iter)
      {
        P rmsnorm_rslt;
        P residual_pack;
#pragma unroll
        for (int i = 0; i < pack_size; ++i)
        {
#if AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32
          float x_f32 = rmsnorm_inp[n_iter].data[i];
#else
          float x_f32 = ck_tile::type_convert<float>(rmsnorm_inp[n_iter].data[i]);
#endif
          float w_f32 = ck_tile::type_convert<float>(w_arr[n_iter].data[i]);
          rmsnorm_rslt.data[i] = ck_tile::type_convert<T>(x_f32 * w_f32 * denom);
          residual_pack.data[i] = ck_tile::type_convert<T>(x_f32);
        }
        int write_idx = bid * n_loop * blockDim.x + n_iter * blockDim.x + threadIdx.x;
        *(reinterpret_cast<P*>(results) + write_idx) = rmsnorm_rslt;
        *(reinterpret_cast<P*>(residual_out) + write_idx) = residual_pack;
      }
    }
  }

  /*
   * block size can be 256 and 512
   * corresponding 2048 and 4096 elem per block
   * */
  /**
   * Second step of the fused all-reduce + RMSNorm: one block per row, with a
   * per-thread bounds check.
   *
   * For each row the block loads the reduced data from this rank's temporary
   * buffer, adds `residual_inp`, writes the sum to `residual_out` and the
   * RMS-normalized, weight-scaled value to `results`.
   * A row holds n / pack_size packs; threads whose pack index falls outside
   * that range skip the loads and stores but still take part in the reduction.
   *
   * @param sg    signals of all ranks; sg.signals[rank] locates the temporary buffer
   * @param eps   added to the mean of squares before rsqrtf
   * @param m     number of rows
   * @param n     number of elements per row
   */
  template <typename T, int tnum, int n_loop>
  __global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm(
      RankSignals sg,
      T* __restrict__ residual_inp,
      T* __restrict__ residual_out,
      T* __restrict__ results,
      T* __restrict__ weight,
      float eps,
      int rank,
      int m,
      int n
  )
  {
    constexpr int pack_size = packed_t<T>::P::size;
    using P = typename packed_t<T>::P;
    using A = typename packed_t<T>::A;
    __shared__ float smem[tnum];
    P* tmps = get_tmp_buf<P>(sg.signals[rank]);

    for (int bid = blockIdx.x; bid < m; bid += gridDim.x)
    {
      float square_sum = 0.0f;
#if AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32
      A rmsnorm_inp[n_loop];
#else
      P rmsnorm_inp[n_loop];
#endif
      P w_arr[n_loop];
#pragma unroll
      for (int n_iter = 0; n_iter < n_loop; ++n_iter)
      {
        if (n_iter * tnum + threadIdx.x < (n / pack_size))
        {
          int read_idx = bid * (n / pack_size) + n_iter * tnum + threadIdx.x;
          P reduce_out_pack = tmps[read_idx];
          P residual_inp_pack = *(reinterpret_cast<P*>(residual_inp) + read_idx);
          w_arr[n_iter] = *(reinterpret_cast<P*>(weight) + n_iter * tnum + threadIdx.x);
          A reduce_pack;
#pragma unroll
          for (int i = 0; i < pack_size; ++i)
          {
            float ar_out = ck_tile::type_convert<float>(reduce_out_pack.data[i]);
            float res_inp = ck_tile::type_convert<float>(residual_inp_pack.data[i]);
            float rms_inp = ar_out + res_inp;
#if AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32
            rmsnorm_inp[n_iter].data[i] = rms_inp;
#else
            rmsnorm_inp[n_iter].data[i] = ck_tile::type_convert<T>(rms_inp);
#endif
            reduce_pack.data[i] = rms_inp * rms_inp;
          }
          square_sum += packReduce<AddFunctor, float, pack_size>(reduce_pack);
        }
      }
      smem[threadIdx.x] = square_sum;
      __syncthreads();
      smemReduceSum<tnum>(&smem[0]);
      square_sum = smem[0];
      // All threads must have read smem[0] before thread 0 overwrites it with
      // its partial sum for the next row.
      __syncthreads();
      float denom = rsqrtf(square_sum / n + eps);
#pragma unroll
      for (int n_iter = 0; n_iter < n_loop; ++n_iter)
      {
        if (n_iter * tnum + threadIdx.x < (n / pack_size))
        {
          P rmsnorm_rslt;
          P residual_pack;
#pragma unroll
          for (int i = 0; i < pack_size; ++i)
          {
#if AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32
            float x_f32 = rmsnorm_inp[n_iter].data[i];
#else
            float x_f32 = ck_tile::type_convert<float>(rmsnorm_inp[n_iter].data[i]);
#endif
            float w_f32 = ck_tile::type_convert<float>(w_arr[n_iter].data[i]);
            rmsnorm_rslt.data[i] = ck_tile::type_convert<T>(x_f32 * w_f32 * denom);
            residual_pack.data[i] = ck_tile::type_convert<T>(x_f32);
          }
          int write_idx = bid * (n / pack_size) + n_iter * tnum + threadIdx.x;
          *(reinterpret_cast<P*>(results) + write_idx) = rmsnorm_rslt;
          *(reinterpret_cast<P*>(residual_out) + write_idx) = residual_pack;
        }
      }
    }
  }

  /**
   * Second step of the fused all-reduce + RMSNorm with one 64-thread group per
   * row instead of one block per row.
   *
   * Each group of 64 threads loads n_loop * 64 packs of one row from this rank's
   * temporary buffer, adds `residual_inp`, reduces the sum of squares inside the
   * group with warpReduce (no shared memory), then writes the sum to
   * `residual_out` and the normalized, weight-scaled value to `results`.
   */
  template <typename T, int n_loop>
  __global__ void __launch_bounds__(256, 1) local_device_load_rmsnorm_512n(
      RankSignals sg,
      T* __restrict__ residual_inp,
      T* __restrict__ residual_out,
      T* __restrict__ results,
      T* __restrict__ weight,
      float eps,
      int rank,
      int m,
      int n
  )
  {
    using P = typename packed_t<T>::P;
    using A = typename packed_t<T>::A;
    constexpr int pack_size = packed_t<T>::P::size;
    P* tmps = get_tmp_buf<P>(sg.signals[rank]);
    int lane_id = threadIdx.x % 64;
    int warp_id = threadIdx.x / 64;
    int warp_num = blockDim.x / 64;

    // every 64-thread group walks the rows with a grid-wide step
    for (int row = blockIdx.x * warp_num + warp_id; row < m; row += gridDim.x * warp_num)
    {
#if AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32
      A rmsnorm_inp[n_loop];
#else
      P rmsnorm_inp[n_loop];
#endif
      P w_arr[n_loop];
      float square_sum = 0.0f;
#pragma unroll
      for (int chunk = 0; chunk < n_loop; ++chunk)
      {
        int pack_idx = row * 64 * n_loop + chunk * 64 + lane_id;
        P reduced = tmps[pack_idx];
        P residual = *(reinterpret_cast<P*>(residual_inp) + pack_idx);
        w_arr[chunk] = *(reinterpret_cast<P*>(weight) + chunk * 64 + lane_id);
        A squares;
#pragma unroll
        for (int e = 0; e < pack_size; ++e)
        {
          float ar_out = ck_tile::type_convert<float>(reduced.data[e]);
          float res_inp = ck_tile::type_convert<float>(residual.data[e]);
          float rms_inp = ar_out + res_inp;
#if AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32
          rmsnorm_inp[chunk].data[e] = rms_inp;
#else
          rmsnorm_inp[chunk].data[e] = ck_tile::type_convert<T>(rms_inp);
#endif
          squares.data[e] = rms_inp * rms_inp;
        }
        square_sum += packReduce<AddFunctor, float, pack_size>(squares);
      }
      // sum of squares of the whole row, combined across the 64 threads
      square_sum = warpReduce<AddFunctor, float, 64>(square_sum);
      float denom = rsqrtf(square_sum / n + eps);
#pragma unroll
      for (int chunk = 0; chunk < n_loop; ++chunk)
      {
        P normalized;
        P residual_pack;
#pragma unroll
        for (int e = 0; e < pack_size; ++e)
        {
#if AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32
          float x_f32 = rmsnorm_inp[chunk].data[e];
#else
          float x_f32 = ck_tile::type_convert<float>(rmsnorm_inp[chunk].data[e]);
#endif
          float w_f32 = ck_tile::type_convert<float>(w_arr[chunk].data[e]);
          normalized.data[e] = ck_tile::type_convert<T>(x_f32 * w_f32 * denom);
          residual_pack.data[e] = ck_tile::type_convert<T>(x_f32);
        }
        int pack_idx = row * 64 * n_loop + chunk * 64 + lane_id;
        *(reinterpret_cast<P*>(results) + pack_idx) = normalized;
        *(reinterpret_cast<P*>(residual_out) + pack_idx) = residual_pack;
      }
    }
  }

  // Byte-array stand-in for hipIpcMemHandle_t, used as the key type of the
  // std::map of opened IPC handles in CustomAllreduce.
  using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
  // The code casts hipIpcMemHandle_t pointers to IPC_KEY pointers, so size and
  // alignment must match exactly.
  static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
  static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));

  class CustomAllreduce
  {
  public:
    // index of this process among the participating ranks
    int rank_;
    // number of participating ranks (taken from offsets.size() in the constructor)
    int world_size_;
    // set from the constructor's `fully_connected` flag; allreduce() uses it to pick a kernel
    bool full_nvlink_;

    // below are device pointers
    // per-rank Signal pointers: self_sg_ for this rank, IPC-opened pointers for peers
    RankSignals sg_;
    // maps a registered local buffer address to its RankData slot on the device
    std::unordered_map<void *, RankData *> buffers_;
    // this rank's own Signal (the `meta` constructor argument)
    Signal *self_sg_;

    // stores the registered device pointers from all ranks
    // base is the next free RankData slot, end is one past the last usable slot
    RankData *d_rank_data_base_, *d_rank_data_end_;
    // input addresses seen during stream capture that are not registered yet
    std::vector<void *> graph_unreg_buffers_;
    // a map from IPC handles to opened IPC pointers
    std::map<IPC_KEY, char *> ipc_handles_;

#ifdef DTK_ENV
    // event and small pinned host buffer created in the constructor
    // NOTE(review): their users are not visible in this part of the file
    hipEvent_t event_;
    void* buffer_ptr_;
    size_t buffer_size_;
#endif

    /**
     * meta is a pointer to device metadata and temporary buffer for allreduce.
     *
     * There's a total of sizeof(Signal) of prefix before the actual data,
     * so meta + 1 points to actual temporary buffer.
     *
     * note: this class does not own any device memory. Any required buffers
     * are passed in from the constructor
     */
    CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
                    const hipIpcMemHandle_t *handles,
                    const std::vector<int64_t> &offsets, int rank,
                    bool fully_connected = true)
        : rank_(rank),
          world_size_(offsets.size()),
          full_nvlink_(fully_connected),
          self_sg_(meta),
          d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
          d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData))
    {
#ifdef DTK_ENV
      // Check the HIP return codes like every other HIP call in this class, so a
      // failed event creation or host allocation is not silently ignored.
      HIP_CALL(hipEventCreateWithFlags(&event_, hipEventReleaseToSystem | hipEventDisableTiming));
      buffer_size_ = 4;
      HIP_CALL(hipHostMalloc(&buffer_ptr_, buffer_size_, hipHostMallocDefault));
#endif

      // Collect one Signal pointer per rank: our own `meta`, and for every peer
      // the IPC-opened allocation advanced by that peer's byte offset.
      for (int i = 0; i < world_size_; i++)
      {
        Signal *rank_sg;
        if (i != rank_)
        {
          char *handle = open_ipc_handle(&handles[i]);
          handle += offsets[i];
          rank_sg = (Signal *)handle;
        }
        else
        {
          rank_sg = self_sg_;
        }
        sg_.signals[i] = rank_sg;
      }
    }

    // Return the mapped address for an IPC handle, opening it on first use.
    // Handles are cached in ipc_handles_, so each distinct handle is opened once.
    char *open_ipc_handle(const void *ipc_handle)
    {
      auto insertion = ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
      auto entry = insertion.first;
      const bool first_time = insertion.second;
      if (!first_time)
      {
        return entry->second;
      }
      char *mapped;
      HIP_CALL(hipIpcOpenMemHandle((void **)&mapped,
                                     *((const hipIpcMemHandle_t *)ipc_handle),
                                     hipIpcMemLazyEnablePeerAccess));
      entry->second = mapped;
      return entry->second;
    }

    // Build the IPC metadata for every buffer recorded during graph capture.
    // Returns the concatenated raw IPC handles (sizeof(hipIpcMemHandle_t) bytes
    // each) and, per buffer, its byte offset from the start of its allocation.
    std::pair<std::vector<uint8_t>, std::vector<int64_t>>
    get_graph_buffer_ipc_meta()
    {
      const auto handle_sz = sizeof(hipIpcMemHandle_t);
      const auto num_buffers = graph_unreg_buffers_.size();
      std::vector<int64_t> offsets(num_buffers);
      std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
      for (size_t buf = 0; buf < num_buffers; ++buf)
      {
        void *buffer = graph_unreg_buffers_[buf];
        void *alloc_start;
        // note: must share the base address of each allocation, or we get wrong
        // address
        auto attr_status = hipPointerGetAttribute(&alloc_start,
#ifdef USE_ROCM
                                  HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
                                  CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
                                  (hipDeviceptr_t)buffer);
        if (attr_status != CUDA_SUCCESS)
        {
          throw std::runtime_error("failed to get pointer attr");
        }
        auto *handle_slot = (hipIpcMemHandle_t *)&handles[buf * handle_sz];
        HIP_CALL(hipIpcGetMemHandle(handle_slot, alloc_start));
        offsets[buf] = ((char *)buffer) - ((char *)alloc_start);
      }
      return std::make_pair(handles, offsets);
    }

    void check_rank_data_capacity(size_t num = 1)
    {
      if (d_rank_data_base_ + num > d_rank_data_end_)
        throw std::runtime_error(
            "Rank data buffer is overflowed by " +
            std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
    }

    // Register a buffer outside of graph capture: gather the buffer's address on
    // every rank (peers through their IPC handle plus offset, this rank through
    // `self`), copy the table into the next free device RankData slot and
    // remember that slot under `self`.
    void register_buffer(const std::vector<torch::Tensor> &handles,
                         const std::vector<int64_t> &offsets, void *self)
    {
      check_rank_data_capacity();
      RankData host_entry;
      for (int peer = 0; peer < world_size_; ++peer)
      {
        if (peer == rank_)
        {
          host_entry.ptrs[peer] = self;
          continue;
        }
        hipIpcMemHandle_t* peer_handle = (hipIpcMemHandle_t*)handles[peer].data_ptr();
        char *peer_ptr = open_ipc_handle((void*)peer_handle);
        peer_ptr += offsets[peer];
        host_entry.ptrs[peer] = peer_ptr;
      }
      // claim the slot before the copy
      RankData *device_slot = d_rank_data_base_++;
      HIP_CALL(
          hipMemcpy(device_slot, &host_entry, sizeof(RankData), hipMemcpyHostToDevice));
      buffers_[self] = device_slot;
    }

    // Look up the device RankData entry for `input`.
    // A registered buffer returns its stored entry. An unregistered buffer is only
    // accepted while `stream` is being captured: it is appended to
    // graph_unreg_buffers_ and gets the slot it will occupy once
    // register_graph_buffers() runs. Otherwise std::runtime_error is thrown.
    RankData *get_buffer_RD(hipStream_t stream, void *input)
    {
      auto registered = buffers_.find(input);
      if (registered != buffers_.end())
      {
        return registered->second;
      }

      hipStreamCaptureStatus capture_state;
      HIP_CALL(hipStreamIsCapturing(stream, &capture_state));
      if (capture_state != hipStreamCaptureStatusActive)
      {
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      }

      RankData *future_slot = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
      return future_slot;
    }

    // note: when registering graph buffers, we intentionally choose to not
    // deduplicate the addresses. That means if the allocator reuses some
    // addresses, they will be registered again. This is to account for the remote
    // possibility of different allocation patterns between ranks. For example,
    // rank 1 may get the same input address for the second allreduce, but rank 2
    // got a different address. IPC handles have internal reference counting
    // mechanism so overhead should be small.
    // Register every buffer recorded during graph capture.
    // handles[j] / offsets[j] hold rank j's IPC handles and byte offsets, one
    // entry per recorded buffer in recording order. The resulting tables are
    // copied to the device in one transfer and the pending list is cleared.
    void register_graph_buffers(
        const std::vector<torch::Tensor> &handles,
        const std::vector<torch::Tensor> &offsets)
    {
      const auto pending = graph_unreg_buffers_.size();
      check_rank_data_capacity(pending);
      std::vector<RankData> host_entries(pending);
      for (size_t buf = 0; buf < pending; ++buf)
      {
        RankData &entry = host_entries[buf];
        void *local_ptr = graph_unreg_buffers_[buf];
        for (int peer = 0; peer < world_size_; ++peer)
        {
          if (peer == rank_)
          {
            entry.ptrs[peer] = local_ptr;
            continue;
          }
          hipIpcMemHandle_t* peer_handle = (hipIpcMemHandle_t*)handles[peer].data_ptr() + buf;
          int64_t peer_offset = *((int64_t*)offsets[peer].data_ptr() + buf);
          char *peer_ptr = open_ipc_handle(peer_handle);
          peer_ptr += peer_offset;
          entry.ptrs[peer] = peer_ptr;
        }
      }
      HIP_CALL(hipMemcpy(d_rank_data_base_, host_entries.data(),
                           sizeof(RankData) * pending,
                           hipMemcpyHostToDevice));
      d_rank_data_base_ += pending;
      graph_unreg_buffers_.clear();
    }

    /*
     * call all reduce fp8 kernel
     * case size in single gpu: (128, 8192)
     * support 8 gpu only
     * should make ngpus as template param
     * should quant scale match hidden_dim when hidden_dim less than 128?
     * */
    /**
     * Launch allReduceQuantFp8 for `size` elements of T.
     *
     * The quantization group size is the largest of 128/64/32/16 that divides
     * the per-rank element count, falling back to 8. Packs hold 8 elements and
     * blocks have 256 threads.
     *
     * @throws std::runtime_error if world_size_ is not one of 2, 4, 6, 8
     */
    template <typename T>
    void runFp8QuantKernel(hipStream_t stream, T* input, T* output, int size)
    {
      RankData *ptrs = get_buffer_RD(stream, input);
      // 32 block 512 thread or 64 block 256 thread
      // Each DISPATHC_UNIT is one `case` of the switch in DISPATCH_CALL; it
      // launches the kernel and returns from this function.
#define DISPATHC_UNIT(pack_size, quant_scale, ngpus)                                                                             \
  do                                                                                                                             \
  {                                                                                                                              \
    case ngpus:                                                                                                                  \
    {                                                                                                                            \
      allReduceQuantFp8<T, quant_scale, pack_size, ngpus><<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size); \
      return ;                                                                                                                   \
    }                                                                                                                            \
  }while(0)

      // The grid is clamped to at least one block: for small inputs the division
      // yields 0, which is not a valid launch configuration. The kernel loops
      // over its range, so one block is enough.
#define DISPATCH_CALL(pack_size, block_size, quant_scale)                                                     \
  do                                                                                                          \
  {                                                                                                           \
    block.x = block_size;                                                                                     \
    grid.x = std::max(1, std::min((16384 / block_size), (single_device_size / (pack_size * block_size))));    \
    size /= pack_size;                                                                                        \
    switch (world_size_)                                                                                      \
    {                                                                                                         \
      DISPATHC_UNIT(pack_size, quant_scale, 2);                                                               \
      DISPATHC_UNIT(pack_size, quant_scale, 4);                                                               \
      DISPATHC_UNIT(pack_size, quant_scale, 6);                                                               \
      DISPATHC_UNIT(pack_size, quant_scale, 8);                                                               \
    }                                                                                                         \
  } while(0)

      int single_device_size = size / world_size_;
      dim3 grid, block;
      if (single_device_size % 128 == 0)
      {
        DISPATCH_CALL(8, 256, 128);
      }
      else if (single_device_size % 64 == 0)
      {
        DISPATCH_CALL(8, 256, 64);
      }
      else if (single_device_size % 32 == 0)
      {
        DISPATCH_CALL(8, 256, 32);
      }
      else if (single_device_size % 16 == 0)
      {
        DISPATCH_CALL(8, 256, 16);
      }
      else
      {
        DISPATCH_CALL(8, 256, 8);
      }
      // Every supported world size returns from inside DISPATCH_CALL; getting
      // here means no kernel was launched.
      throw std::runtime_error(
          "fp8 quant allreduce only supports num gpus in (2,4,6,8). Actual num "
          "gpus = " +
          std::to_string(world_size_));
    }

    /**
     * This is the result after careful grid search. Using 36 blocks give the best
     * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
     * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
     * Not quite sure the underlying reason, but my guess is that too many SMs
     * will cause contention on NVLink bus.
     */
    template <typename T>
    void allreduce(hipStream_t stream, T *input, T *output, int size,
#ifndef USE_ROCM
                   int threads = 512, int block_limit = 20){
#else
                   int threads = 512, int block_limit = 16)
    {
#endif
      // `size` is the element count; it must be a whole number of packs.
      auto d = packed_t<T>::P::size;
      if (size % d != 0)
        throw std::runtime_error(
            "custom allreduce currently requires input length to be multiple "
            "of " +
            std::to_string(d));
      if (block_limit > kMaxBlocks)
        throw std::runtime_error("max supported block limit is " +
                                 std::to_string(kMaxBlocks) + ". Got " +
                                 std::to_string(block_limit));

      RankData *ptrs = get_buffer_RD(stream, input);

      auto bytes = size * sizeof(T);
      // from here on `size` counts packs
      size /= d;
      int blocks = 16;
      bool call_1stage = false;
      bool call_2stage = false;
      // Two ranks always use the 1-stage kernel. Fully connected larger groups
      // use 1-stage for small messages and 2-stage otherwise.
      if (world_size_ == 2)
      {
        call_1stage = true;
      }
      else if (full_nvlink_)
      {
        if ((world_size_ <= 4 && bytes < 160 * 1024) || (world_size_ <= 8 && bytes < 80 * 1024))
        {
          call_1stage = true;
        }
        else
        {
          call_2stage = true;
        }
      }
      if (call_1stage)
      {
        blocks = std::min(kMaxBlocks, (size + (threads / world_size_) - 1) / (threads / world_size_));
      }
      else if (call_2stage)
      {
        blocks = std::min(kMaxBlocks, (size / world_size_ + (threads / world_size_) - 1) / (threads / world_size_));
      }

#ifdef DTK_ENV
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
#else
#define KL(ngpus, name)                                               \
    void* args[] = {&ptrs, &sg_, &self_sg_, &output, &rank_, &size};  \
    hipExtLaunchKernel(                                               \
        reinterpret_cast<void*>(name<T, ngpus>), dim3(blocks),        \
        dim3(threads), args, 0, stream, nullptr, nullptr,             \
        hipExtAddAcquireSystemScope);
#endif

      // The regular kernel is used only when the byte count is a multiple of 128
      // and the group does not have 6 ranks; otherwise the _naive variant runs.
#define dispatch(ngpus, name)                   \
    do                                          \
    {                                           \
      if (bytes % 128 == 0 && world_size_ != 6) \
      {                                         \
        KL(ngpus, name)                         \
      }                                         \
      else                                      \
      {                                         \
        KL(ngpus, name##_naive)                 \
      }                                         \
    } while(0)

      // Neither stage is selected only when world_size_ > 2 and the GPUs are not
      // fully connected. Fail loudly instead of returning without launching a
      // kernel and leaving `output` untouched.
#define REDUCE_CASE(ngpus)                                                  \
  case ngpus:                                                               \
  {                                                                         \
    if (call_1stage)                                                        \
    {                                                                       \
      dispatch(ngpus, cross_device_reduce_1stage);                          \
    }                                                                       \
    else if (call_2stage)                                                   \
    {                                                                       \
      dispatch(ngpus, cross_device_reduce_2stage);                          \
    }                                                                       \
    else                                                                    \
    {                                                                       \
      throw std::runtime_error(                                             \
          "custom allreduce requires fully connected gpus when num gpus "   \
          "> 2");                                                           \
    }                                                                       \
    break;                                                                  \
  }

      switch (world_size_)
      {
        REDUCE_CASE(2)
        REDUCE_CASE(4)
        REDUCE_CASE(6)
        REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
      }
#undef REDUCE_CASE
#undef KL
    }

  // Launch the all-gather kernel for `size` elements of T.
  // When `size` is a whole number of packs the vectorized kernel is used and
  // `size` is passed as a pack count; otherwise the element-wise naive kernel
  // runs. Both use 512-thread blocks and at most 80 blocks. An unsupported
  // world size only prints a message.
  template <typename T>
  void dispatchAllGather(hipStream_t stream, T* input, T* output, int size)
  {
    RankData* ptrs = get_buffer_RD(stream, input);
    auto d = packed_t<T>::P::size;
    dim3 block(512);

    if (size % d == 0)
    {
      // vectorized path: each block covers 512 / world_size_ packs
      size /= d;
      int packs_per_block = 512 / world_size_;
      int wanted_blocks = (size + packs_per_block - 1) / packs_per_block;
      dim3 grid(std::min(wanted_blocks, 80));
      switch (world_size_)
      {
        case 2:
          allgather_vec<T, 2><<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
          break;
        case 4:
          allgather_vec<T, 4><<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
          break;
        case 8:
          allgather_vec<T, 8><<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
          break;
        default:
          printf("allgather world_size error\n");
      }
      return;
    }

    // naive path: one element per thread
    int wanted_blocks = (size + 512 - 1) / 512;
    dim3 grid(std::min(wanted_blocks, 80));
    switch (world_size_)
    {
      case 2:
        allgather_naive<T, 2><<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
        break;
      case 4:
        allgather_naive<T, 4><<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
        break;
      case 8:
        allgather_naive<T, 8><<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
        break;
      default:
        printf("allgather world_size error\n");
    }
  }

  /**
   * Enqueues the fused all-reduce + RMSNorm pipeline on `stream`.
   *
   * Two kernels are issued back to back:
   *   1. reduce_scatter_cross_device_store<T, world_size_> over all m * n
   *      elements (512-thread blocks, at most 80 blocks);
   *   2. one local_device_load_rmsnorm* variant, selected from the byte width
   *      of n elements (n * sizeof(T)).
   *
   * Supported configurations:
   *   - world_size_ in {2, 4, 8}; otherwise a diagnostic is printed and
   *     nothing is launched;
   *   - n * sizeof(T) a multiple of 1024 within [1024, 32768]; otherwise a
   *     diagnostic is printed and step 2 is not launched (step 1 has already
   *     been enqueued at that point).
   *
   * @param stream        stream all kernels are enqueued on
   * @param input         buffer passed to get_buffer_RD() to obtain the RankData
   * @param residual_inp  forwarded unchanged to the step-2 kernel
   * @param residual_out  forwarded unchanged to the step-2 kernel
   * @param output        forwarded unchanged to the step-2 kernel
   * @param weight        forwarded unchanged to the step-2 kernel
   * @param eps           forwarded unchanged to the step-2 kernel
   * @param m             presumably the number of rows — TODO confirm against the kernels
   * @param n             presumably the elements per row — TODO confirm against the kernels
   * @throws std::runtime_error if m * n is not a multiple of packed_t<T>::P::size
   */
  template <typename T>
  void dispatchFusedAllReduceRMSNorm(hipStream_t stream, T* input, T* residual_inp, T* residual_out, T* output, T* weight, float eps, int m, int n)
  {
    auto d = packed_t<T>::P::size;
    int size = m * n;
    if (size % d != 0)
    {
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    }
    RankData* ptrs = get_buffer_RD(stream, input);

    // Value-initialised so that a failed query leaves a CU count of 0 (which
    // setGrid below treats as "no cap") instead of an indeterminate value.
    // NOTE(review): the device properties are re-queried on every call;
    // consider caching the CU count if this shows up in profiles.
    hipDevice_t dev{};
    hipDeviceProp_t dev_prop{};
    hipGetDevice(&dev);
    hipGetDeviceProperties(&dev_prop, dev);
    const uint32_t num_cu = dev_prop.multiProcessorCount;

    // step 1, run reduce-scatter + allgather cross device save
    dim3 block(512);
    int block_num = ((size / world_size_) + 512 - 1) / 512;
    dim3 grid(std::min(block_num, 80));
#ifndef DTK_ENV
    // Argument pack for hipExtLaunchKernel; the order mirrors the kernel
    // parameter list used by the direct launches in the DTK_ENV branch.
    void* args[] = {&ptrs, &sg_, &self_sg_, &rank_, &size};
#endif
    switch (world_size_)
    {
#ifdef DTK_ENV
      case 8:
        reduce_scatter_cross_device_store<T, 8><<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, rank_, size);
        break;
      case 4:
        reduce_scatter_cross_device_store<T, 4><<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, rank_, size);
        break;
      case 2:
        reduce_scatter_cross_device_store<T, 2><<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, rank_, size);
        break;
#else
      case 8:
        hipExtLaunchKernel(reinterpret_cast<void*>(reduce_scatter_cross_device_store<T, 8>),
                           grid, block, args, 0, stream, nullptr, nullptr, hipExtAddAcquireSystemScope);
        break;
      case 4:
        hipExtLaunchKernel(reinterpret_cast<void*>(reduce_scatter_cross_device_store<T, 4>),
                           grid, block, args, 0, stream, nullptr, nullptr, hipExtAddAcquireSystemScope);
        break;
      case 2:
        hipExtLaunchKernel(reinterpret_cast<void*>(reduce_scatter_cross_device_store<T, 2>),
                           grid, block, args, 0, stream, nullptr, nullptr, hipExtAddAcquireSystemScope);
        break;
#endif
      default:
        printf("fused allreduce rmsnorm world size error\n");
        // Step 1 was not launched, so step 2 would consume data that was
        // never produced; stop here instead of launching it.
        return;
    }

    // step 2, run allgather local device load + rmsnorm
    int n_bytes = n * sizeof(T);

    // Caps grid.x at (compute units * resident blocks per unit) for the chosen
    // kernel and the current block.x. When the cap cannot be determined (query
    // failed or reported 0) the uncapped grid size is used, so the launch
    // configuration is never left at 0 blocks because of the cap.
    auto setGrid = [&](int naive_grid_size, const void* kernel_ptr)
    {
      int occupancy = 0;
      hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel_ptr, block.x, 0);
      // Signed 64-bit arithmetic avoids the int/uint32_t mixed comparison.
      const long long resident_cap = static_cast<long long>(num_cu) * occupancy;
      const bool cap_applies = resident_cap > 0 && resident_cap < naive_grid_size;
      grid.x = static_cast<uint32_t>(cap_applies ? resident_cap : naive_grid_size);
    };

// Sizes the grid for `template_kernel` and launches it. Expects a local
// `naive_grid_size` to be in scope at the expansion site.
#define launch_fused_allreduce_rmsnorm(template_kernel)                                                               \
    do                                                                                                                \
    {                                                                                                                 \
      auto kernel_ptr = reinterpret_cast<const void*>(template_kernel);                                               \
      setGrid(naive_grid_size, kernel_ptr);                                                                           \
      template_kernel<<<grid, block, 0, stream>>>(sg_, residual_inp, residual_out, output, weight, eps, rank_, m, n); \
    } while (0)

    if (n_bytes % 1024 == 0)
    {
      if (8192 <= n_bytes && n_bytes <= 32768)
      {
        // 512-thread blocks, one grid slot per m.
        int naive_grid_size = m;
        int n_loop = n_bytes / 8192; // 1, 2, 3, 4
        if (n_bytes % 8192 == 0)
        {
          // Exact multiple of 8192 bytes: "naive" kernel variants.
          switch (n_loop)
          {
            case 1:
              launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 512, 1>));
              break;
            case 2:
              launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 512, 2>));
              break;
            case 3:
              launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 512, 3>));
              break;
            case 4:
              launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 512, 4>));
              break;
          }
        }
        else
        {
          // Partial trailing 8192-byte chunk: round the loop count up (2..4).
          n_loop += 1;
          switch (n_loop)
          {
            case 2:
              launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm<T, 512, 2>));
              break;
            case 3:
              launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm<T, 512, 3>));
              break;
            case 4:
              launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm<T, 512, 4>));
              break;
          }
        }
      }
      else if (4096 <= n_bytes && n_bytes < 8192)
      {
        block.x = 256;
        int naive_grid_size = m;
        if (n_bytes == 4096)
        {
          // naive n_loop = 1
          launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 256, 1>));
        }
        else
        {
          // n_loop = 2
          launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm<T, 256, 2>));
        }
      }
      else if (1024 <= n_bytes && n_bytes < 4096)
      {
        block.x = 256;
        // ceil(m / 4) grid slots for this kernel family.
        int naive_grid_size = (m + 3) / 4;
        int n_loop = n_bytes / 1024; // 1, 2, 3
        switch (n_loop)
        {
          case 1:
            launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_512n<T, 1>));
            break;
          case 2:
            launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_512n<T, 2>));
            break;
          case 3:
            launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_512n<T, 3>));
            break;
        }
      }
      else
      {
        printf("fused allreduce rmsnorm shape size error\n");
      }
    }
    else
    {
      printf("fused allreduce rmsnorm shape error\n");
    }

// Keep the helper macro local to this function.
#undef launch_fused_allreduce_rmsnorm
  }

  // Tears down the resources held by this instance: under DTK_ENV the host
  // buffer (if one is set) and the event, then every pointer stored in
  // ipc_handles_ is closed via hipIpcCloseMemHandle.
  ~CustomAllreduce()
  {
#ifdef DTK_ENV
    if (buffer_ptr_)
    {
      hipHostFree(buffer_ptr_);
    }
    hipEventDestroy(event_);
#endif

    // Bind by const reference: the entries are only read here.
    for (const auto& [key, mem_ptr] : ipc_handles_)
    {
      HIP_CALL(hipIpcCloseMemHandle(mem_ptr));
    }
  }
}; // class CustomAllreduce
/**
 * To inspect PTX/SASS, copy-paste this header file into Compiler Explorer and
 * add a template instantiation:
 *   template void aiter::CustomAllreduce::allreduce<half>(hipStream_t, half *,
 *                                                         half *, int, int, int);
 */
} // namespace aiter