Commit bb596f6e authored by xiaowei.zhang's avatar xiaowei.zhang
Browse files

1. Update MOE; 2. Update sglang mHC; 3. Update test scripts; 4 Add new

   ops.
parent d9ebb683
[submodule "3rdparty/composable_kernel"] [submodule "3rdparty/composable_kernel"]
path = 3rdparty/composable_kernel path = 3rdparty/composable_kernel
url = ../composable_kernel.git url = ../composable_kernel
branch = main branch = rel-5.7.1
[submodule "3rdparty/moe_c"] [submodule "3rdparty/moe_c"]
path = 3rdparty/moe_c path = 3rdparty/moe_c
url = ../moe.git url = ../Moe
branch = master branch = W8A8
Subproject commit 8d05eec5aa99d5fa0cc5f5ef372a2ce02036bb73 Subproject commit a3b6d4d4825e8cf1b29160b9aa5ff8dbea08c8ea
...@@ -71,6 +71,8 @@ from .ops.rope import * ...@@ -71,6 +71,8 @@ from .ops.rope import *
from .ops.topk import * from .ops.topk import *
# from .ops.mha import * # from .ops.mha import *
from .ops.gradlib import * from .ops.gradlib import *
from .ops.mhc import *
from .ops.grouped_gemm import *
# from .ops.trans_ragged_layout import * # from .ops.trans_ragged_layout import *
# from . import mla # from . import mla
from .utility import dtypes,fp4_utils from .utility import dtypes,fp4_utils
This diff is collapsed.
...@@ -121,6 +121,57 @@ gfx938,no_quant,torch.float16,32768,352,4096,129,9,0,0,asm,13001+23001,16397.301 ...@@ -121,6 +121,57 @@ gfx938,no_quant,torch.float16,32768,352,4096,129,9,0,0,asm,13001+23001,16397.301
gfx938,no_quant,torch.float16,40960,352,4096,129,9,0,0,asm,13001+23001,20398.5288 gfx938,no_quant,torch.float16,40960,352,4096,129,9,0,0,asm,13001+23001,20398.5288
gfx938,no_quant,torch.float16,49152,352,4096,129,9,0,0,asm,13001+23001,24396.6972 gfx938,no_quant,torch.float16,49152,352,4096,129,9,0,0,asm,13001+23001,24396.6972
gfx938,no_quant,torch.float16,65536,352,4096,129,9,0,0,asm,13001+23001,32435.0655 gfx938,no_quant,torch.float16,65536,352,4096,129,9,0,0,asm,13001+23001,32435.0655
gfx938,no_quant,torch.bfloat16,1,192,2048,128,8,0,0,asm,10007+20000,39.2876
gfx938,no_quant,torch.bfloat16,2,192,2048,128,8,0,0,asm,10009+20000,51.8687
gfx938,no_quant,torch.bfloat16,4,192,2048,128,8,0,0,asm,10006+20000,73.8475
gfx938,no_quant,torch.bfloat16,8,192,2048,128,8,0,0,asm,10006+20000,107.3038
gfx938,no_quant,torch.bfloat16,16,192,2048,128,8,0,0,asm,10007+20000,149.3254
gfx938,no_quant,torch.bfloat16,32,192,2048,128,8,0,0,asm,10006+20000,182.4199
gfx938,no_quant,torch.bfloat16,64,192,2048,128,8,0,0,asm,10006+20000,213.6452
gfx938,no_quant,torch.bfloat16,128,192,2048,128,8,0,0,asm,10006+20000,205.5664
gfx938,no_quant,torch.bfloat16,256,192,2048,128,8,0,0,asm,11004+21001,225.3643
gfx938,no_quant,torch.bfloat16,512,192,2048,128,8,0,0,asm,11004+21001,268.9685
gfx938,no_quant,torch.bfloat16,1024,192,2048,128,8,0,0,asm,12002+22001,373.4482
gfx938,no_quant,torch.bfloat16,2048,192,2048,128,8,0,0,asm,12001+22001,544.0333
gfx938,no_quant,torch.bfloat16,4096,192,2048,128,8,0,0,asm,13001+23001,859.873
gfx938,no_quant,torch.bfloat16,8192,192,2048,128,8,0,0,asm,13001+23001,1515.4337
gfx938,no_quant,torch.bfloat16,16384,192,2048,128,8,0,0,asm,13001+23001,2881.8408
gfx938,no_quant,torch.bfloat16,32768,192,2048,128,8,0,0,asm,13001+23001,5550.2244
gfx938,no_quant,torch.bfloat16,65536,192,2048,128,8,0,0,asm,13001+23001,10944.1702
gfx938,no_quant,torch.bfloat16,1,384,2048,128,8,0,0,asm,10001+20000,53.056
gfx938,no_quant,torch.bfloat16,2,384,2048,128,8,0,0,asm,10006+20000,77.3086
gfx938,no_quant,torch.bfloat16,4,384,2048,128,8,0,0,asm,10006+20000,112.6348
gfx938,no_quant,torch.bfloat16,8,384,2048,128,8,0,0,asm,10006+20000,177.2747
gfx938,no_quant,torch.bfloat16,16,384,2048,128,8,0,0,asm,10006+20000,260.6267
gfx938,no_quant,torch.bfloat16,32,384,2048,128,8,0,0,asm,10006+20000,320.2976
gfx938,no_quant,torch.bfloat16,64,384,2048,128,8,0,0,asm,10006+20000,367.4922
gfx938,no_quant,torch.bfloat16,128,384,2048,128,8,0,0,asm,10009+20000,364.0352
gfx938,no_quant,torch.bfloat16,256,384,2048,128,8,0,0,asm,11004+21001,391.2504
gfx938,no_quant,torch.bfloat16,512,384,2048,128,8,0,0,asm,12000+22001,455.2254
gfx938,no_quant,torch.bfloat16,1024,384,2048,128,8,0,0,asm,12001+22001,542.8131
gfx938,no_quant,torch.bfloat16,2048,384,2048,128,8,0,0,asm,13001+23001,709.3484
gfx938,no_quant,torch.bfloat16,4096,384,2048,128,8,0,0,asm,13001+23001,1144.2526
gfx938,no_quant,torch.bfloat16,8192,384,2048,128,8,0,0,asm,13001+23001,1982.3018
gfx938,no_quant,torch.bfloat16,16384,384,2048,128,8,0,0,asm,13001+23001,3922.8848
gfx938,no_quant,torch.bfloat16,32768,384,2048,128,8,0,0,asm,13001+23001,7601.1435
gfx938,no_quant,torch.bfloat16,65536,384,2048,128,8,0,0,asm,13001+23001,15053.8397
gfx938,no_quant,torch.bfloat16,1,768,2048,128,8,0,0,asm,10006+20000,75.2789
gfx938,no_quant,torch.bfloat16,2,768,2048,128,8,0,0,asm,10006+20000,119.599
gfx938,no_quant,torch.bfloat16,4,768,2048,128,8,0,0,asm,10007+20000,189.241
gfx938,no_quant,torch.bfloat16,8,768,2048,128,8,0,0,asm,10006+20000,311.8679
gfx938,no_quant,torch.bfloat16,16,768,2048,128,8,0,0,asm,10008+20000,465.2827
gfx938,no_quant,torch.bfloat16,32,768,2048,128,8,0,0,asm,10008+20000,574.3358
gfx938,no_quant,torch.bfloat16,64,768,2048,128,8,0,0,asm,10008+20000,659.7834
gfx938,no_quant,torch.bfloat16,128,768,2048,128,8,0,0,asm,10008+20000,672.0162
gfx938,no_quant,torch.bfloat16,256,768,2048,128,8,0,0,asm,11002+21001,716.3866
gfx938,no_quant,torch.bfloat16,512,768,2048,128,8,0,0,asm,12005+22001,802.6013
gfx938,no_quant,torch.bfloat16,1024,768,2048,128,8,0,0,asm,13001+23001,945.1779
gfx938,no_quant,torch.bfloat16,2048,768,2048,128,8,0,0,asm,13001+23001,1243.9816
gfx938,no_quant,torch.bfloat16,4096,768,2048,128,8,0,0,asm,13001+23001,1989.4641
gfx938,no_quant,torch.bfloat16,8192,768,2048,128,8,0,0,asm,13001+23001,3554.2789
gfx938,no_quant,torch.bfloat16,16384,768,2048,128,8,0,0,asm,13001+23001,6779.7759
gfx938,no_quant,torch.bfloat16,32768,768,2048,128,8,0,0,asm,13001+23001,13203.9373
gfx938,no_quant,torch.bfloat16,65536,768,2048,128,8,0,0,asm,13001+23001,26121.7552
gfx936,no_quant,torch.float16,1,256,3072,256,8,0,0,asm,10006+20000,56.4327 gfx936,no_quant,torch.float16,1,256,3072,256,8,0,0,asm,10006+20000,56.4327
gfx936,no_quant,torch.float16,2,256,3072,256,8,0,0,asm,10006+20000,85.2664 gfx936,no_quant,torch.float16,2,256,3072,256,8,0,0,asm,10006+20000,85.2664
gfx936,no_quant,torch.float16,4,256,3072,256,8,0,0,asm,10004+20000,148.02 gfx936,no_quant,torch.float16,4,256,3072,256,8,0,0,asm,10004+20000,148.02
...@@ -221,3 +272,31 @@ gfx936,no_quant,torch.float16,12288,128,3072,256,8,0,0,asm,13001+23001,2844.9048 ...@@ -221,3 +272,31 @@ gfx936,no_quant,torch.float16,12288,128,3072,256,8,0,0,asm,13001+23001,2844.9048
gfx936,no_quant,torch.float16,16384,128,3072,256,8,0,0,asm,13001+23001,3597.2571 gfx936,no_quant,torch.float16,16384,128,3072,256,8,0,0,asm,13001+23001,3597.2571
gfx936,no_quant,torch.float16,24576,128,3072,256,8,0,0,asm,13001+23001,5205.65 gfx936,no_quant,torch.float16,24576,128,3072,256,8,0,0,asm,13001+23001,5205.65
gfx936,no_quant,torch.float16,32768,128,3072,256,8,0,0,asm,13001+23001,6847.9883 gfx936,no_quant,torch.float16,32768,128,3072,256,8,0,0,asm,13001+23001,6847.9883
gfx936,no_quant,torch.bfloat16,1,384,2048,128,8,0,0,asm,10005+20000,57.5107
gfx936,no_quant,torch.bfloat16,2,384,2048,128,8,0,0,asm,10005+20000,86.1507
gfx936,no_quant,torch.bfloat16,4,384,2048,128,8,0,0,asm,10001+20000,137.9569
gfx936,no_quant,torch.bfloat16,8,384,2048,128,8,0,0,asm,10001+20000,230.5798
gfx936,no_quant,torch.bfloat16,16,384,2048,128,8,0,0,asm,10001+20000,352.5754
gfx936,no_quant,torch.bfloat16,32,384,2048,128,8,0,0,asm,10001+20000,436.6174
gfx936,no_quant,torch.bfloat16,48,384,2048,128,8,0,0,asm,10001+20001,490.5933
gfx936,no_quant,torch.bfloat16,64,384,2048,128,8,0,0,asm,10001+20001,508.85309
gfx936,no_quant,torch.bfloat16,96,384,2048,128,8,0,0,asm,10001+20001,510.02899
gfx936,no_quant,torch.bfloat16,128,384,2048,128,8,0,0,asm,10001+20001,517.6922
gfx936,no_quant,torch.bfloat16,200,384,2048,128,8,0,0,asm,11000+21001,564.4711
gfx936,no_quant,torch.bfloat16,256,384,2048,128,8,0,0,asm,11000+21001,580.3952
gfx936,no_quant,torch.bfloat16,384,384,2048,128,8,0,0,asm,11000+21001,635.8056
gfx936,no_quant,torch.bfloat16,460,384,2048,128,8,0,0,asm,11000+21001,672.3782
gfx936,no_quant,torch.bfloat16,512,384,2048,128,8,0,0,asm,11006+20002,695.6287
gfx936,no_quant,torch.bfloat16,798,384,2048,128,8,0,0,asm,12004+22001,731.5276
gfx936,no_quant,torch.bfloat16,1024,384,2048,128,8,0,0,asm,12000+22001,779.5612
gfx936,no_quant,torch.bfloat16,1280,384,2048,128,8,0,0,asm,12000+22001,832.0495
gfx936,no_quant,torch.bfloat16,1440,384,2048,128,8,0,0,asm,13000+22001,891.2665
gfx936,no_quant,torch.bfloat16,1560,384,2048,128,8,0,0,asm,12004+22001,882.4158
gfx936,no_quant,torch.bfloat16,1880,384,2048,128,8,0,0,asm,13000+22001,885.6415
gfx936,no_quant,torch.bfloat16,2000,384,2048,128,8,0,0,asm,13000+23001,919.6288
gfx936,no_quant,torch.bfloat16,2200,384,2048,128,8,0,0,asm,12005+22001,965.1782
gfx936,no_quant,torch.bfloat16,2400,384,2048,128,8,0,0,asm,12001+22001,999.2413
gfx936,no_quant,torch.bfloat16,2800,384,2048,128,8,0,0,asm,13001+23001,1065.1948
gfx936,no_quant,torch.bfloat16,3200,384,2048,128,8,0,0,asm,13001+23001,1126.6853
gfx936,no_quant,torch.bfloat16,3660,384,2048,128,8,0,0,asm,13001+23001,1216.6051
gfx936,no_quant,torch.bfloat16,4096,384,2048,128,8,0,0,asm,13001+23001,1259.6619
\ No newline at end of file
...@@ -1268,3 +1268,72 @@ gfx938,f8_w8a8_block,torch.float16,14336,256,4096,256,8,0,0,asm,13001+23000,5030 ...@@ -1268,3 +1268,72 @@ gfx938,f8_w8a8_block,torch.float16,14336,256,4096,256,8,0,0,asm,13001+23000,5030
gfx938,f8_w8a8_block,torch.float16,16384,256,4096,256,8,0,0,asm,13001+23000,5608.4473 gfx938,f8_w8a8_block,torch.float16,16384,256,4096,256,8,0,0,asm,13001+23000,5608.4473
gfx938,f8_w8a8_block,torch.float16,17408,256,4096,256,8,0,0,asm,13001+23000,6038.0465 gfx938,f8_w8a8_block,torch.float16,17408,256,4096,256,8,0,0,asm,13001+23000,6038.0465
gfx938,f8_w8a8_block,torch.float16,24576,256,4096,256,8,0,0,asm,13001+23000,8143.5178 gfx938,f8_w8a8_block,torch.float16,24576,256,4096,256,8,0,0,asm,13001+23000,8143.5178
gfx938,f8_w8a8_block,torch.float16,1,512,4096,256,8,0,0,asm,10007+20000,82.9571
gfx938,f8_w8a8_block,torch.float16,2,512,4096,256,8,0,0,asm,10001+20000,112.2287
gfx938,f8_w8a8_block,torch.float16,4,512,4096,256,8,0,0,asm,10002+20000,168.5064
gfx938,f8_w8a8_block,torch.float16,6,512,4096,256,8,0,0,asm,10002+20000,218.1063
gfx938,f8_w8a8_block,torch.float16,8,512,4096,256,8,0,0,asm,10002+20000,257.9547
gfx938,f8_w8a8_block,torch.float16,10,512,4096,256,8,0,0,asm,10002+20000,302.8811
gfx938,f8_w8a8_block,torch.float16,12,512,4096,256,8,0,0,asm,10002+20000,333.6432
gfx938,f8_w8a8_block,torch.float16,14,512,4096,256,8,0,0,asm,10002+20000,373.4999
gfx938,f8_w8a8_block,torch.float16,16,512,4096,256,8,0,0,asm,10002+20000,393.4746
gfx938,f8_w8a8_block,torch.float16,20,512,4096,256,8,0,0,asm,10002+20000,455.5854
gfx938,f8_w8a8_block,torch.float16,24,512,4096,256,8,0,0,asm,10001+20000,514.044
gfx938,f8_w8a8_block,torch.float16,28,512,4096,256,8,0,0,asm,10002+20000,582.8807
gfx938,f8_w8a8_block,torch.float16,32,512,4096,256,8,0,0,asm,10002+20000,618.847
gfx938,f8_w8a8_block,torch.float16,36,512,4096,256,8,0,0,asm,10002+20000,647.2482
gfx938,f8_w8a8_block,torch.float16,40,512,4096,256,8,0,0,asm,10001+20000,676.0396
gfx938,f8_w8a8_block,torch.float16,44,512,4096,256,8,0,0,asm,10002+20000,707.8039
gfx938,f8_w8a8_block,torch.float16,48,512,4096,256,8,0,0,asm,10002+20000,718.1197
gfx938,f8_w8a8_block,torch.float16,56,512,4096,256,8,0,0,asm,10001+20000,757.1511
gfx938,f8_w8a8_block,torch.float16,64,512,4096,256,8,0,0,asm,10002+20000,781.3533
gfx938,f8_w8a8_block,torch.float16,80,512,4096,256,8,0,0,asm,10002+20000,834.0691
gfx938,f8_w8a8_block,torch.float16,96,512,4096,256,8,0,0,asm,10002+20000,871.6102
gfx938,f8_w8a8_block,torch.float16,112,512,4096,256,8,0,0,asm,10002+20000,885.1175
gfx938,f8_w8a8_block,torch.float16,128,512,4096,256,8,0,0,asm,10002+20000,903.3996
gfx938,f8_w8a8_block,torch.float16,160,512,4096,256,8,0,0,asm,10002+20000,918.0774
gfx938,f8_w8a8_block,torch.float16,192,512,4096,256,8,0,0,asm,10001+20000,936.3676
gfx938,f8_w8a8_block,torch.float16,224,512,4096,256,8,0,0,asm,10002+20000,943.6689
gfx938,f8_w8a8_block,torch.float16,256,512,4096,256,8,0,0,asm,10002+20000,951.9721
gfx938,f8_w8a8_block,torch.float16,320,512,4096,256,8,0,0,asm,10002+20000,969.4878
gfx938,f8_w8a8_block,torch.float16,384,512,4096,256,8,0,0,asm,10006+20000,1018.5991999999999
gfx938,f8_w8a8_block,torch.float16,448,512,4096,256,8,0,0,asm,11008+21000,1043.5172
gfx938,f8_w8a8_block,torch.float16,512,512,4096,256,8,0,0,asm,11010+21000,1063.8287
gfx938,f8_w8a8_block,torch.float16,576,512,4096,256,8,0,0,asm,11010+21000,1103.7024
gfx938,f8_w8a8_block,torch.float16,640,512,4096,256,8,0,0,asm,11007+21000,1088.9825
gfx938,f8_w8a8_block,torch.float16,704,512,4096,256,8,0,0,asm,11010+21000,1061.7066
gfx938,f8_w8a8_block,torch.float16,768,512,4096,256,8,0,0,asm,11010+21000,1106.7677
gfx938,f8_w8a8_block,torch.float16,832,512,4096,256,8,0,0,asm,11010+21000,1127.4246
gfx938,f8_w8a8_block,torch.float16,896,512,4096,256,8,0,0,asm,11010+21000,1168.4432
gfx938,f8_w8a8_block,torch.float16,960,512,4096,256,8,0,0,asm,11010+21000,1173.7484
gfx938,f8_w8a8_block,torch.float16,1024,512,4096,256,8,0,0,asm,11010+21000,1219.298
gfx938,f8_w8a8_block,torch.float16,1152,512,4096,256,8,0,0,asm,12002+22000,1338.2958
gfx938,f8_w8a8_block,torch.float16,1280,512,4096,256,8,0,0,asm,12003+22000,1279.3149
gfx938,f8_w8a8_block,torch.float16,1408,512,4096,256,8,0,0,asm,12003+22000,1309.1085
gfx938,f8_w8a8_block,torch.float16,1536,512,4096,256,8,0,0,asm,12002+22000,1346.0515
gfx938,f8_w8a8_block,torch.float16,1664,512,4096,256,8,0,0,asm,12003+22000,1330.7168
gfx938,f8_w8a8_block,torch.float16,1792,512,4096,256,8,0,0,asm,12002+22000,1433.3104
gfx938,f8_w8a8_block,torch.float16,1920,512,4096,256,8,0,0,asm,12004+22000,1557.4367
gfx938,f8_w8a8_block,torch.float16,2048,512,4096,256,8,0,0,asm,12003+22000,1661.2428
gfx938,f8_w8a8_block,torch.float16,2304,512,4096,256,8,0,0,asm,12005+22000,1940.813
gfx938,f8_w8a8_block,torch.float16,2560,512,4096,256,8,0,0,asm,13001+22000,2072.7203
gfx938,f8_w8a8_block,torch.float16,2816,512,4096,256,8,0,0,asm,13001+22000,2095.2974
gfx938,f8_w8a8_block,torch.float16,3072,512,4096,256,8,0,0,asm,13001+22000,2144.6107
gfx938,f8_w8a8_block,torch.float16,3328,512,4096,256,8,0,0,asm,13001+22000,2175.5328
gfx938,f8_w8a8_block,torch.float16,3584,512,4096,256,8,0,0,asm,13001+22000,2234.2864
gfx938,f8_w8a8_block,torch.float16,3840,512,4096,256,8,0,0,asm,12005+22000,2454.0841
gfx938,f8_w8a8_block,torch.float16,4096,512,4096,256,8,0,0,asm,12005+22000,2680.5851
gfx938,f8_w8a8_block,torch.float16,4608,512,4096,256,8,0,0,asm,12005+22000,3087.0188
gfx938,f8_w8a8_block,torch.float16,5120,512,4096,256,8,0,0,asm,12005+22000,3226.8833
gfx938,f8_w8a8_block,torch.float16,5632,512,4096,256,8,0,0,asm,12005+22000,3405.6863
gfx938,f8_w8a8_block,torch.float16,6144,512,4096,256,8,0,0,asm,12005+22000,3731.7076
gfx938,f8_w8a8_block,torch.float16,6656,512,4096,256,8,0,0,asm,13001+23000,3931.2194
gfx938,f8_w8a8_block,torch.float16,7168,512,4096,256,8,0,0,asm,13001+23000,4010.6804
gfx938,f8_w8a8_block,torch.float16,7680,512,4096,256,8,0,0,asm,13001+23000,4195.0509
gfx938,f8_w8a8_block,torch.float16,8192,512,4096,256,8,0,0,asm,13001+23000,4642.5692
gfx938,f8_w8a8_block,torch.float16,10240,512,4096,256,8,0,0,asm,13001+23000,5698.2073
gfx938,f8_w8a8_block,torch.float16,12288,512,4096,256,8,0,0,asm,13001+23000,6601.8277
gfx938,f8_w8a8_block,torch.float16,14336,512,4096,256,8,0,0,asm,13001+23000,7572.0575
gfx938,f8_w8a8_block,torch.float16,16384,512,4096,256,8,0,0,asm,13001+23000,8551.2295
gfx938,f8_w8a8_block,torch.float16,17408,512,4096,256,8,0,0,asm,13001+23000,9230.6924
gfx938,f8_w8a8_block,torch.float16,24576,512,4096,256,8,0,0,asm,13001+23000,12357.5934
\ No newline at end of file
...@@ -19,26 +19,78 @@ from typing import Any, Dict, Optional, Union ...@@ -19,26 +19,78 @@ from typing import Any, Dict, Optional, Union
import torch import torch
import torch.distributed import torch.distributed
from .parallel_state import get_tp_group from .parallel_state import get_tp_group, get_custom_group, has_custom_group
def tensor_model_parallel_all_reduce( def tensor_model_parallel_all_reduce(
input_: torch.Tensor, open_fp8_quant: bool = False input_: torch.Tensor,
use_new: bool = True,
open_fp8_quant: bool = False,
prefill_support: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group.""" """All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_, open_fp8_quant) return get_tp_group().all_reduce(input_, use_new, open_fp8_quant, prefill_support)
def tensor_model_parallel_fused_allreduce_rmsnorm( def tensor_model_parallel_fused_allreduce_rmsnorm(
input_: torch.Tensor, residual_inp_: torch.Tensor, weight_: torch.Tensor, eps: float input_: torch.Tensor,
residual_inp_: torch.Tensor,
weight_: torch.Tensor,
eps: float,
prefill_support: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return get_tp_group().fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps) return get_tp_group().fused_allreduce_rmsnorm(
input_, residual_inp_, weight_, eps, prefill_support
)
def tensor_model_parallel_fused_allreduce_rmsnorm_quant(
input_: torch.Tensor,
residual_inp_: torch.Tensor,
weight_: torch.Tensor,
eps: float,
prefill_support: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return get_tp_group().fused_allreduce_rmsnorm_quant(
input_, residual_inp_, weight_, eps, prefill_support
)
def tensor_model_parallel_fused_allreduce_rmsnorm_quant_per_group(
input_: torch.Tensor,
residual_inp_: torch.Tensor,
weight_: torch.Tensor,
eps: float,
group_size: int = 128,
prefill_support: bool = False,
emit_bf16: bool = False,
):
return get_tp_group().fused_allreduce_rmsnorm_quant_per_group(
input_, residual_inp_, weight_, eps, group_size, prefill_support, emit_bf16=emit_bf16
)
def tensor_model_parallel_fused_qknorm_allreduce(
qkv_in: torch.Tensor,
q_w: torch.Tensor,
k_w: torch.Tensor,
eps: float,
):
return get_tp_group().fused_qknorm_allreduce(qkv_in, q_w, k_w, eps)
def tensor_model_parallel_custom_all_gather(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_custom_all_gather(input_: torch.Tensor) -> torch.Tensor:
return get_tp_group().custom_all_gather(input_) return get_tp_group().custom_all_gather(input_)
def tensor_model_parallel_reduce_scatter(
input_: torch.Tensor,
use_custom: bool = True,
dim: int = 0,
) -> torch.Tensor:
return get_tp_group().reduce_scatter_tensor(input_, use_custom, dim)
def tensor_model_parallel_all_gather( def tensor_model_parallel_all_gather(
input_: torch.Tensor, use_custom: bool = False, dim: int = -1 input_: torch.Tensor, use_custom: bool = False, dim: int = -1
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -59,3 +111,66 @@ def broadcast_tensor_dict( ...@@ -59,3 +111,66 @@ def broadcast_tensor_dict(
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
return tensor_dict return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src) return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
# ============================================================
# Custom group communication operations
# ============================================================
def _assert_has_custom_group():
assert has_custom_group(), (
"No custom group initialized. Call ensure_model_parallel_initialized "
"with custom_group_config to initialize custom groups."
)
def custom_all_reduce(
input_: torch.Tensor,
use_new: bool = True,
open_fp8_quant: bool = False,
group: Optional[str] = None,
) -> torch.Tensor:
"""All-reduce the input tensor across the user-specified custom group.
Args:
group: Name of the custom group. When only one custom group is
initialized this can be omitted. When multiple groups exist,
pass the group name to select which one to use.
"""
_assert_has_custom_group()
return get_custom_group(group).all_reduce(input_, use_new, open_fp8_quant)
def custom_all_gather(
input_: torch.Tensor,
use_custom: bool = True,
dim: int = 0,
group: Optional[str] = None,
) -> torch.Tensor:
"""All-gather the input tensor across the user-specified custom group.
Args:
group: Name of the custom group. When only one custom group is
initialized this can be omitted. When multiple groups exist,
pass the group name to select which one to use.
"""
_assert_has_custom_group()
return get_custom_group(group).all_gather(input_, use_custom, dim)
def custom_reduce_scatter(
input_: torch.Tensor,
use_custom: bool = True,
dim: int = 0,
group: Optional[str] = None,
) -> torch.Tensor:
"""Reduce-scatter the input tensor across the user-specified custom group.
Args:
group: Name of the custom group. When only one custom group is
initialized this can be omitted. When multiple groups exist,
pass the group name to select which one to use.
"""
_assert_has_custom_group()
return get_custom_group(group).reduce_scatter_tensor(input_, use_custom, dim)
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -11,6 +13,13 @@ from aiter import logger ...@@ -11,6 +13,13 @@ from aiter import logger
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
def _env_flag(name: str, default: bool) -> bool:
val = os.environ.get(name)
if val is None:
return default
return val.strip().lower() in ("1", "true", "yes", "on")
class CudaCommunicator(DeviceCommunicatorBase): class CudaCommunicator(DeviceCommunicatorBase):
def __init__( def __init__(
self, self,
...@@ -64,9 +73,20 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -64,9 +73,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
if use_custom_allreduce and self.world_size > 1: if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation. # Initialize a custom fast all-reduce implementation.
# AITER_AR_ENABLE_REG_CAPTURE controls whether inputs captured
# inside a CUDA graph are assumed to already live in the
# pre-registered IPC buffer (True, default), or whether the
# in-graph all-reduce should fall back to the unregistered
# copy-in path (False). Set this to "0" when callers cannot
# guarantee that captured input pointers were registered via
# ``CustomAllreduce.register_buffer``.
enable_register_for_capturing = _env_flag(
"AITER_AR_ENABLE_REG_CAPTURE", default=True
)
self.ca_comm = CustomAllreduce( self.ca_comm = CustomAllreduce(
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
enable_register_for_capturing=enable_register_for_capturing,
# symm_mem_enabled=( # symm_mem_enabled=(
# self.symm_mem_comm is not None and not self.symm_mem_comm.disabled # self.symm_mem_comm is not None and not self.symm_mem_comm.disabled
# ), # ),
...@@ -118,7 +138,13 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -118,7 +138,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.all2all_manager.__class__.__name__, self.all2all_manager.__class__.__name__,
) )
def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor: def all_reduce(
self,
input_,
use_new: bool = True,
ca_fp8_quant: bool = False,
prefill_support: bool = False,
) -> torch.Tensor:
# always try quick reduce first, then custom allreduce, # always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*) # and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm qr_comm = self.qr_comm
...@@ -137,7 +163,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -137,7 +163,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
and not ca_comm.disabled and not ca_comm.disabled
and ca_comm.should_custom_ar(input_) and ca_comm.should_custom_ar(input_)
): ):
out = ca_comm.custom_all_reduce(input_, ca_fp8_quant) out = ca_comm.custom_all_reduce(input_, use_new=use_new, open_fp8_quant=ca_fp8_quant)
assert out is not None assert out is not None
return out return out
symm_mem_comm = self.symm_mem_comm symm_mem_comm = self.symm_mem_comm
...@@ -159,7 +185,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -159,7 +185,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return out return out
def fused_allreduce_rmsnorm( def fused_allreduce_rmsnorm(
self, input_, res_inp_, weight_, eps self, input_, res_inp_, weight_, eps, prefill_support: bool = False
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
n = input_.shape[-1] n = input_.shape[-1]
can_use_fuse_ar_rms = ( can_use_fuse_ar_rms = (
...@@ -174,10 +200,12 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -174,10 +200,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
and ca_comm.should_custom_ar(input_) and ca_comm.should_custom_ar(input_)
and can_use_fuse_ar_rms and can_use_fuse_ar_rms
): ):
res_out, out = ca_comm.custom_fused_ar_rms(input_, res_inp_, weight_, eps) out, res_out = ca_comm.custom_fused_ar_rms(
input_, res_inp_, weight_, eps, use_1stage=prefill_support
)
assert out is not None assert out is not None
assert res_out is not None assert res_out is not None
return res_out, out return out, res_out
# call split kernel # call split kernel
ar_out = self.all_reduce(input_) ar_out = self.all_reduce(input_)
out = torch.empty_like(ar_out) out = torch.empty_like(ar_out)
...@@ -193,7 +221,138 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -193,7 +221,138 @@ class CudaCommunicator(DeviceCommunicatorBase):
eps, eps,
0, 0,
) )
return residual_out, out return out, residual_out
def fused_allreduce_rmsnorm_quant(
self,
input_,
res_inp_,
weight_,
eps,
prefill_support: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
total_bytes = input_.numel() * input_.element_size()
K = int(input_.shape[-1])
use_1stage = total_bytes <= 128 * 1024
# Hygon (gfx938/gfx946) kernel-level bug: the fused
# AR+RMSNorm+FP8-quant kernel produces 100% NaN for bf16 at
# K=4096 whenever use_1stage=False, regardless of which post-
# 1-stage path is dispatched (the 2-stage 128KB-512KB path AND
# the split >512KB path both fail). Empirically confirmed on
# shapes (17,4096), (32,4096), (128,4096), and (and very likely
# (512,4096)+ which we don't proceed to). The 2-stage kernel
# works for fp16 at the same K=4096, and for bf16 at K=7168 /
# K=8192; the failure is therefore K=4096-and-bf16-specific.
# Until the C++ kernel is fixed upstream, fall back to the
# Python split path (RMSNorm-only fused kernel + separate
# hip_quant) for the entire problematic configuration.
problematic_bf16_non_1stage = (
input_.dtype == torch.bfloat16
and not use_1stage
and K == 4096
)
if (
K in [512, 1024, 2048, 4096]
and total_bytes <= 4096 * 1024
and not problematic_bf16_non_1stage
):
out, res_out, scale_out = self.ca_comm.custom_fused_ar_rms_quant(
input_, res_inp_, weight_, eps, use_1stage
)
else:
out_, res_out = self.fused_allreduce_rmsnorm(
input_, res_inp_, weight_, eps, prefill_support
)
from aiter import get_hip_quant, QuantType
from aiter.utility.dtypes import fp8
hip_quant = get_hip_quant(QuantType.per_Token)
out, scale_out = hip_quant(out_, quant_dtype=fp8)
assert out is not None
assert res_out is not None
assert scale_out is not None
return out, res_out, scale_out
def fused_allreduce_rmsnorm_quant_per_group(
self,
input_,
res_inp_,
weight_,
eps,
group_size=128,
prefill_support: bool = False,
emit_bf16: bool = False,
):
total_bytes = input_.numel() * input_.element_size()
K = int(input_.shape[-1])
use_1stage = total_bytes <= 128 * 1024
out = res_out = scale_out = bf16_out = None
fused_ok = False
# See ``fused_allreduce_rmsnorm_quant`` for context, with one
# important difference: per-token quant's custom-kernel
# whitelist is K in {512, 1024, 2048, 4096}, so larger K values
# (6144 / 7168 / 8192) always go to the Python fallback there
# and never expose the kernel bug for those K. Per-group quant
# has a much wider whitelist (any K with K % group_size == 0
# and K <= 16384), so it surfaces the same bug at additional K
# values (K=4096 and K=6144 both empirically confirmed NaN;
# K=7168 / K=8192 untested but likely affected). Widen the
# fallback to all bf16 + non-1-stage configurations to be safe;
# the perf cost is limited since this only affects bf16 inputs
# whose total bytes exceed 128 KB (medium / large prefill).
problematic_bf16_non_1stage = (
input_.dtype == torch.bfloat16
and not use_1stage
)
if (
K % group_size == 0
and K <= 16384
and total_bytes < 8 * 1024 * 8192
and not problematic_bf16_non_1stage
):
try:
result = self.ca_comm.custom_fused_ar_rms_per_group_quant(
input_, res_inp_, weight_, eps, group_size, use_1stage,
emit_bf16=emit_bf16,
)
if emit_bf16:
out, res_out, scale_out, bf16_out = result
else:
out, res_out, scale_out = result
fused_ok = True
except Exception:
pass
if not fused_ok:
out_, res_out = self.fused_allreduce_rmsnorm(
input_, res_inp_, weight_, eps, prefill_support
)
from aiter import get_hip_quant, QuantType
from aiter.utility.dtypes import fp8
hip_quant = get_hip_quant(QuantType.per_1x128)
out, scale_out = hip_quant(out_, quant_dtype=fp8)
if emit_bf16:
bf16_out = out_
assert out is not None
assert res_out is not None
assert scale_out is not None
if emit_bf16:
assert bf16_out is not None
return out, res_out, scale_out, bf16_out
return out, res_out, scale_out
def fused_qknorm_allreduce(
self,
qkv_in,
q_w,
k_w,
eps,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q_out, k_out, v_out = self.ca_comm.custom_fused_qknorm_ar(
qkv_in, q_w, k_w, eps
)
assert q_out is not None
assert k_out is not None
assert v_out is not None
return q_out, k_out, v_out
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size world_size = self.world_size
......
...@@ -319,7 +319,11 @@ class GroupCoordinator: ...@@ -319,7 +319,11 @@ class GroupCoordinator:
yield graph_capture_context yield graph_capture_context
def all_reduce( def all_reduce(
self, input_: torch.Tensor, ca_fp8_quant: bool = False self,
input_: torch.Tensor,
use_new: bool = True,
open_fp8_quant: bool = False,
prefill_support: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
User-facing all-reduce function before we actually call the User-facing all-reduce function before we actually call the
...@@ -340,7 +344,7 @@ class GroupCoordinator: ...@@ -340,7 +344,7 @@ class GroupCoordinator:
return input_ return input_
return all_reduce_( return all_reduce_(
input_, group_name=self.unique_name, ca_fp8_quant=ca_fp8_quant input_, group_name=self.unique_name, ca_fp8_quant=open_fp8_quant
) )
def _all_reduce_out_place( def _all_reduce_out_place(
...@@ -348,7 +352,7 @@ class GroupCoordinator: ...@@ -348,7 +352,7 @@ class GroupCoordinator:
) -> torch.Tensor: ) -> torch.Tensor:
if self.device_communicator is None: if self.device_communicator is None:
raise ValueError("No device communicator found") raise ValueError("No device communicator found")
return self.device_communicator.all_reduce(input_, ca_fp8_quant) return self.device_communicator.all_reduce(input_, ca_fp8_quant=ca_fp8_quant)
def fused_allreduce_rmsnorm( def fused_allreduce_rmsnorm(
self, self,
...@@ -356,11 +360,54 @@ class GroupCoordinator: ...@@ -356,11 +360,54 @@ class GroupCoordinator:
residual_inp_: torch.Tensor, residual_inp_: torch.Tensor,
weight_: torch.Tensor, weight_: torch.Tensor,
eps: float, eps: float,
prefill_support: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return fused_allreduce_rmsnorm_( return fused_allreduce_rmsnorm_(
input_, residual_inp_, weight_, eps, group_name=self.unique_name input_, residual_inp_, weight_, eps, group_name=self.unique_name
) )
def fused_allreduce_rmsnorm_quant(
self,
input_: torch.Tensor,
residual_inp_: torch.Tensor,
weight_: torch.Tensor,
eps: float,
prefill_support: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.device_communicator is None:
raise ValueError("No device communicator found")
return self.device_communicator.fused_allreduce_rmsnorm_quant(
input_, residual_inp_, weight_, eps, prefill_support
)
def fused_allreduce_rmsnorm_quant_per_group(
self,
input_: torch.Tensor,
residual_inp_: torch.Tensor,
weight_: torch.Tensor,
eps: float,
group_size: int = 128,
prefill_support: bool = False,
emit_bf16: bool = False,
):
if self.device_communicator is None:
raise ValueError("No device communicator found")
return self.device_communicator.fused_allreduce_rmsnorm_quant_per_group(
input_, residual_inp_, weight_, eps, group_size, prefill_support,
emit_bf16=emit_bf16,
)
def fused_qknorm_allreduce(
self,
qkv_in: torch.Tensor,
q_w: torch.Tensor,
k_w: torch.Tensor,
eps: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.device_communicator is None:
raise ValueError("No device communicator found")
return self.device_communicator.fused_qknorm_allreduce(qkv_in, q_w, k_w, eps)
def _fused_allreduce_rmsnorm_out_place( def _fused_allreduce_rmsnorm_out_place(
self, self,
input_: torch.Tensor, input_: torch.Tensor,
...@@ -375,12 +422,31 @@ class GroupCoordinator: ...@@ -375,12 +422,31 @@ class GroupCoordinator:
) )
def _all_gather_out_place(self, input_: torch.Tensor) -> torch.Tensor: def _all_gather_out_place(self, input_: torch.Tensor) -> torch.Tensor:
ca_comm = self.device_communicator.ca_comm ca_comm = (
assert ca_comm is not None self.device_communicator.ca_comm
assert not ca_comm.disabled if self.device_communicator is not None
out = ca_comm.custom_all_gather(input_) else None
assert out is not None )
return out if ca_comm is not None and not ca_comm.disabled:
out = ca_comm.custom_all_gather(input_)
assert out is not None
return out
# Fallback: ca_comm unavailable (e.g. non-tp custom group).
# Try pynccl first (graph-safe), then torch.distributed.
world_size = self.world_size
out_shape = (input_.shape[0] * world_size,) + input_.shape[1:]
output_tensor = torch.empty(
out_shape, dtype=input_.dtype, device=input_.device
)
if self.device_communicator is not None:
pynccl_comm = self.device_communicator.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.all_gather(output_tensor, input_)
return output_tensor
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
return output_tensor
def custom_all_gather(self, input_: torch.Tensor) -> torch.Tensor: def custom_all_gather(self, input_: torch.Tensor) -> torch.Tensor:
return outplace_all_gather(input_, group_name=self.unique_name) return outplace_all_gather(input_, group_name=self.unique_name)
...@@ -390,6 +456,33 @@ class GroupCoordinator: ...@@ -390,6 +456,33 @@ class GroupCoordinator:
raise ValueError("No device communicator found") raise ValueError("No device communicator found")
return self.device_communicator.reduce_scatter(input_, dim) return self.device_communicator.reduce_scatter(input_, dim)
def reduce_scatter_tensor(
self,
input_: torch.Tensor,
use_custom: bool = True,
dim: int = 0,
):
world_size = self.world_size
assert world_size > 1, "error! world_size = 1"
assert (
input_.numel() % world_size == 0
), "input shape error, input.numel() % world_size should equals to 0"
if input_.shape[0] % world_size == 0:
out_dim0 = input_.shape[0] // world_size
out_shape = (out_dim0,) + input_.shape[1:]
else:
out_shape = (input_.numel() // world_size,)
if use_custom and self.device_communicator is not None:
return self.device_communicator.reduce_scatter(input_, dim)
else:
output_ = torch.empty(
out_shape, dtype=input_.dtype, device=input_.device
)
torch.distributed.reduce_scatter_tensor(
output_, input_, group=self.device_group
)
return output_
def all_gather( def all_gather(
self, input_: torch.Tensor, use_custom: bool = False, dim: int = -1 self, input_: torch.Tensor, use_custom: bool = False, dim: int = -1
...@@ -897,6 +990,71 @@ def get_ep_group() -> GroupCoordinator: ...@@ -897,6 +990,71 @@ def get_ep_group() -> GroupCoordinator:
return _EP return _EP
_CUSTOM: Dict[str, GroupCoordinator] = {}
def has_custom_group() -> bool:
"""Return whether any custom group is initialized."""
return bool(_CUSTOM)
def get_custom_group(
name: Optional[str] = None,
) -> "Union[GroupCoordinator, Dict[str, GroupCoordinator]]":
"""Get custom group coordinator(s).
- If only one custom group is initialized, returns the GroupCoordinator
instance directly (name is optional).
- If multiple custom groups are initialized and name is None, returns the
full dict so the caller can select by name.
- If name is given, returns that specific GroupCoordinator.
"""
assert _CUSTOM, "custom allreduce group is not initialized"
if name is not None:
assert name in _CUSTOM, (
f"custom group '{name}' not found, "
f"available: {list(_CUSTOM.keys())}"
)
return _CUSTOM[name]
if len(_CUSTOM) == 1:
return next(iter(_CUSTOM.values()))
return dict(_CUSTOM)
class CustomGroupConfig:
"""Configuration builder for custom communication groups.
Each group is defined by a rank list that can be:
- 1D List[int]: all ranks form a single communication group,
e.g. [0,1,2,3,4,5,6,7] → one TP8 group
- 2D List[List[int]]: multiple independent subgroups,
e.g. [[0,1,2,3],[4,5,6,7]] → two independent TP4 groups
Usage:
config = CustomGroupConfig()
config.add_group("tp_group", [[0,1,2,3],[4,5,6,7]])
ensure_model_parallel_initialized(..., custom_group_config=config.data())
Or pass a raw dict directly:
ensure_model_parallel_initialized(..., custom_group_config={
"tp_group": [[0,1,2,3],[4,5,6,7]],
})
"""
def __init__(self):
self._groups: Dict[str, List] = {}
def add_group(self, name: str, ranks: List) -> "CustomGroupConfig":
assert name not in self._groups, f"custom group '{name}' already exists"
assert ranks, f"custom group '{name}': ranks list must not be empty"
self._groups[name] = ranks
return self
def data(self) -> Dict[str, List]:
assert self._groups, "no custom groups have been added"
return dict(self._groups)
# kept for backward compatibility # kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group get_pipeline_model_parallel_group = get_pp_group
...@@ -996,6 +1154,7 @@ def initialize_model_parallel( ...@@ -996,6 +1154,7 @@ def initialize_model_parallel(
# decode_context_model_parallel_size: Optional[int] = 1, # decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None, backend: Optional[str] = None,
data_parallel_size: int = 1, data_parallel_size: int = 1,
custom_group_config: Optional[Dict[str, List]] = None,
) -> None: ) -> None:
""" """
Initialize model parallel groups. Initialize model parallel groups.
...@@ -1006,6 +1165,12 @@ def initialize_model_parallel( ...@@ -1006,6 +1165,12 @@ def initialize_model_parallel(
pipeline_model_parallel_size: number of GPUs used for pipeline model pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism. parallelism.
backend: name of torch distributed communication backend. backend: name of torch distributed communication backend.
custom_group_config: optional dict mapping group names to rank lists.
Each value can be:
- 1D List[int]: all ranks form a single group,
e.g. [0,1,2,3,4,5,6,7]
- 2D List[List[int]]: multiple independent subgroups,
e.g. [[0,1,2,3],[4,5,6,7]]
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
...@@ -1109,6 +1274,55 @@ def initialize_model_parallel( ...@@ -1109,6 +1274,55 @@ def initialize_model_parallel(
group_ranks, get_world_group().local_rank, backend, group_name="ep" group_ranks, get_world_group().local_rank, backend, group_name="ep"
) )
# Build the custom allreduce group(s) (optional).
global _CUSTOM
assert not _CUSTOM, "custom allreduce group is already initialized"
if custom_group_config is not None:
for gname, ranks in custom_group_config.items():
assert (
isinstance(ranks, list) and len(ranks) > 0
), f"custom group '{gname}': value must be a non-empty list"
if all(isinstance(r, int) for r in ranks):
group_ranks = [ranks]
elif all(isinstance(g, list) for g in ranks):
group_ranks = ranks
subgroup_size = len(group_ranks[0])
for g in group_ranks:
assert len(g) == subgroup_size, (
f"custom group '{gname}': all subgroups must "
f"have the same size, expected {subgroup_size} "
f"but got {len(g)}"
)
assert all(isinstance(r, int) for r in g), (
f"custom group '{gname}': subgroup elements "
f"must be integers"
)
else:
raise AssertionError(
f"custom group '{gname}': value must be List[int] "
f"(1D) or List[List[int]] (2D)"
)
all_ranks_flat = [r for g in group_ranks for r in g]
assert len(all_ranks_flat) == world_size, (
f"custom group '{gname}': total ranks "
f"({len(all_ranks_flat)}) must equal world_size ({world_size})"
)
assert len(set(all_ranks_flat)) == world_size, (
f"custom group '{gname}': contains duplicate ranks"
)
assert set(all_ranks_flat) == set(range(world_size)), (
f"custom group '{gname}': must cover all ranks 0..{world_size - 1}"
)
_CUSTOM[gname] = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
group_name=f"custom_{gname}",
)
logger.info( logger.info(
"rank %s in world size %s is assigned as " "rank %s in world size %s is assigned as "
"DP rank %s, PP rank %s, TP rank %s, EP rank %s", "DP rank %s, PP rank %s, TP rank %s, EP rank %s",
...@@ -1126,6 +1340,7 @@ def ensure_model_parallel_initialized( ...@@ -1126,6 +1340,7 @@ def ensure_model_parallel_initialized(
pipeline_model_parallel_size: int, pipeline_model_parallel_size: int,
backend: Optional[str] = None, backend: Optional[str] = None,
data_parallel_size: int = 1, data_parallel_size: int = 1,
custom_group_config: Optional[Dict[str, List]] = None,
) -> None: ) -> None:
"""Helper to initialize model parallel groups if they are not initialized, """Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
...@@ -1138,6 +1353,7 @@ def ensure_model_parallel_initialized( ...@@ -1138,6 +1353,7 @@ def ensure_model_parallel_initialized(
pipeline_model_parallel_size, pipeline_model_parallel_size,
backend, backend,
data_parallel_size, data_parallel_size,
custom_group_config,
) )
return return
...@@ -1209,6 +1425,11 @@ def destroy_model_parallel(): ...@@ -1209,6 +1425,11 @@ def destroy_model_parallel():
_PP.destroy() _PP.destroy()
_PP = None _PP = None
global _CUSTOM
for g in _CUSTOM.values():
g.destroy()
_CUSTOM.clear()
def destroy_distributed_environment(): def destroy_distributed_environment():
global _WORLD global _WORLD
......
...@@ -733,11 +733,11 @@ def torch_moe_blockscale( ...@@ -733,11 +733,11 @@ def torch_moe_blockscale(
# [expert, model_dim/blk_m, inter_dim/blk_k] # [expert, model_dim/blk_m, inter_dim/blk_k]
fc2_scale=None, fc2_scale=None,
expert_mask=None, expert_mask=None,
computeType=torch.float32,
): ):
computeType = dtypes.fp32 hidden_states = hidden_states.float().to(computeType)
hidden_states = hidden_states.to(computeType) w1 = w1.float().to(computeType)
w1 = w1.to(computeType) w2 = w2.float().to(computeType)
w2 = w2.to(computeType)
token_num, topk = topk_ids.shape token_num, topk = topk_ids.shape
expert, model_dim, inter_dim = w2.shape expert, model_dim, inter_dim = w2.shape
B, D = hidden_states.shape B, D = hidden_states.shape
...@@ -767,9 +767,8 @@ def torch_moe_blockscale( ...@@ -767,9 +767,8 @@ def torch_moe_blockscale(
nblk_n = inter_dim // blk_n nblk_n = inter_dim // blk_n
nblk_k = model_dim // blk_k nblk_k = model_dim // blk_k
if fc1_scale is not None: if fc1_scale is not None:
# gose to quant D_w8a8/w8a8 fc1_scale = fc1_scale.to(computeType)
# blk_n, blk_k = scale_blks fc2_scale = fc2_scale.to(computeType)
# expert, nblk_n, nblk_k = fc1_scale.shape
fc1_scale = rearrange( fc1_scale = rearrange(
fc1_scale.view(-1, 1) fc1_scale.view(-1, 1)
.repeat(1, blk_n * blk_k) .repeat(1, blk_n * blk_k)
......
...@@ -12,13 +12,13 @@ from aiter import logger ...@@ -12,13 +12,13 @@ from aiter import logger
from aiter import per_token_quant_hip, per_block_quant_wrapper, get_hip_quant from aiter import per_token_quant_hip, per_block_quant_wrapper, get_hip_quant
from aiter import ActivationType, QuantType, dtypes from aiter import ActivationType, QuantType, dtypes
from aiter import silu_and_mul,gelu_and_mul from aiter import silu_and_mul,gelu_and_mul
from aiter.ops.triton.fused_moe import ( from aiter.ops.triton.fused_moe import triton_moe_sum
triton_moe_sum, from aiter.ops.triton.moe_activation import (
triton_silu_and_mul, _normalize_activation_and_gate,
triton_gelu_and_mul, _apply_activation,
triton_relu2,
) )
from aiter.jit.core import AITER_ROOT_DIR from aiter.jit.core import AITER_ROOT_DIR
# from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size # from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
# from vllm.model_executor.layers.quantization.utils.int8_utils import ( # from vllm.model_executor.layers.quantization.utils.int8_utils import (
...@@ -111,6 +111,7 @@ def run_fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -111,6 +111,7 @@ def run_fused_experts_asm_impl(hidden_states: torch.Tensor,
dtype, dtype,
inplace, inplace,
activation, activation,
None, # is_gated
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a8,
use_int8_w4a8, use_int8_w4a8,
...@@ -181,6 +182,7 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -181,6 +182,7 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
dtype: torch.dtype, dtype: torch.dtype,
inplace: bool = False, inplace: bool = False,
activation: str = "silu", activation: str = "silu",
is_gated: Optional[bool] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w4a8: bool = False, use_int8_w4a8: bool = False,
...@@ -200,7 +202,12 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -200,7 +202,12 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
persist_cu: Optional[int] = 0, persist_cu: Optional[int] = 0,
use_shuffle: Optional[int] = 0, use_shuffle: Optional[int] = 0,
solution_id: Optional[str] = None, solution_id: Optional[str] = None,
routed_scaling_factor: Optional[float] = 1.0)-> torch.Tensor: routed_scaling_factor: Optional[float] = 1.0,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None)-> torch.Tensor:
activation, is_gated = _normalize_activation_and_gate(activation, is_gated)
# Check constraints. # Check constraints.
if use_int8_w4a8: if use_int8_w4a8:
assert block_shape[0] == 0 and block_shape[1] == 64, "[ERROR]ASM Fused MoE only support w4a8 block_shape=64 now." assert block_shape[0] == 0 and block_shape[1] == 64, "[ERROR]ASM Fused MoE only support w4a8 block_shape=64 now."
...@@ -342,14 +349,14 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -342,14 +349,14 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
2, 2,
config["SOL_ID1"], config["SOL_ID1"],
config["BLOCK_SIZE_M"]) config["BLOCK_SIZE_M"])
if activation == "silu": _apply_activation(
triton_silu_and_mul(d_silu,d_w1_out) activation=activation,
# silu_and_mul(d_silu,d_w1_out) is_gated=is_gated,
elif activation == "gelu": activated_out=d_silu,
triton_gelu_and_mul(d_silu,d_w1_out) ffn1_out_2d=d_w1_out.view(-1, N),
# gelu_and_mul(d_silu,d_w1_out) gemm1_alpha=gemm1_alpha,
else: gemm1_limit=gemm1_limit,
raise ValueError(f"Unsupported FusedMoe activation: {activation}") )
if dtype == torch.bfloat16: if dtype == torch.bfloat16:
if block_shape is not None and block_shape[1] == 32: if block_shape is not None and block_shape[1] == 32:
aiter.asm_fmoe_stage2(d_w2_out, aiter.asm_fmoe_stage2(d_w2_out,
...@@ -442,14 +449,15 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -442,14 +449,15 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
odtype, odtype,
config["PERSIST_GROUP1"], config["PERSIST_GROUP1"],
use_shuffle) use_shuffle)
if activation == "silu":
triton_silu_and_mul(d_silu,d_w1_out) _apply_activation(
# silu_and_mul(d_silu,d_w1_out) activation=activation,
elif activation == "gelu": is_gated=is_gated,
triton_gelu_and_mul(d_silu,d_w1_out) activated_out=d_silu,
# gelu_and_mul(d_silu,d_w1_out) ffn1_out_2d=d_w1_out.view(-1, N),
else: gemm1_alpha=gemm1_alpha,
raise ValueError(f"Unsupported FusedMoe activation: {activation}") gemm1_limit=gemm1_limit,
)
bridge_q,bridge_scale = per_token_quant_hip(d_silu) bridge_q,bridge_scale = per_token_quant_hip(d_silu)
#bridge_q,bridge_scale = per_token_quant_int8(d_silu) #bridge_q,bridge_scale = per_token_quant_int8(d_silu)
...@@ -515,14 +523,15 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -515,14 +523,15 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
config["SOL_ID1"], config["SOL_ID1"],
odtype, odtype,
config["PERSIST_GROUP1"]) config["PERSIST_GROUP1"])
if activation == "silu":
triton_silu_and_mul(d_silu,d_w1_out) _apply_activation(
# silu_and_mul(d_silu,d_w1_out) activation=activation,
elif activation == "gelu": is_gated=is_gated,
triton_gelu_and_mul(d_silu,d_w1_out) activated_out=d_silu,
# gelu_and_mul(d_silu,d_w1_out) ffn1_out_2d=d_w1_out.view(-1, N),
else: gemm1_alpha=gemm1_alpha,
raise ValueError(f"Unsupported FusedMoe activation: {activation}") gemm1_limit=gemm1_limit,
)
#quant_func = get_hip_quant(QuantType.per_1x64) #quant_func = get_hip_quant(QuantType.per_1x64)
#bridge_q,bridge_scale = quant_func(d_silu, quant_dtype=dtypes.i8) #bridge_q,bridge_scale = quant_func(d_silu, quant_dtype=dtypes.i8)
...@@ -586,14 +595,15 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -586,14 +595,15 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
odtype, odtype,
config["PERSIST_GROUP1"], config["PERSIST_GROUP1"],
use_shuffle) use_shuffle)
if activation == "silu":
triton_silu_and_mul(d_silu,d_w1_out) _apply_activation(
# silu_and_mul(d_silu,d_w1_out) activation=activation,
elif activation == "gelu": is_gated=is_gated,
triton_gelu_and_mul(d_silu,d_w1_out) activated_out=d_silu,
# gelu_and_mul(d_silu,d_w1_out) ffn1_out_2d=d_w1_out.view(-1, N),
else: gemm1_alpha=gemm1_alpha,
raise ValueError(f"Unsupported FusedMoe activation: {activation}") gemm1_limit=gemm1_limit,
)
#FIXME: aiter quant method performance is little worse than triton. Change it latter!! #FIXME: aiter quant method performance is little worse than triton. Change it latter!!
bridge_q, bridge_scale = per_block_quant_wrapper((1,block_shape[1]))(per_token_quant_hip)(d_silu, quant_dtype=torch.int8) bridge_q, bridge_scale = per_block_quant_wrapper((1,block_shape[1]))(per_token_quant_hip)(d_silu, quant_dtype=torch.int8)
...@@ -657,14 +667,15 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -657,14 +667,15 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
odtype, odtype,
config["PERSIST_GROUP1"], config["PERSIST_GROUP1"],
use_shuffle) use_shuffle)
if activation == "silu":
triton_silu_and_mul(d_silu,d_w1_out) _apply_activation(
# silu_and_mul(d_silu,d_w1_out) activation=activation,
elif activation == "gelu": is_gated=is_gated,
triton_gelu_and_mul(d_silu,d_w1_out) activated_out=d_silu,
# gelu_and_mul(d_silu,d_w1_out) ffn1_out_2d=d_w1_out.view(-1, N),
else: gemm1_alpha=gemm1_alpha,
raise ValueError(f"Unsupported FusedMoe activation: {activation}") gemm1_limit=gemm1_limit,
)
bridge_q,bridge_scale= per_token_quant_hip(d_silu, quant_dtype=torch.float8_e4m3fn) bridge_q,bridge_scale= per_token_quant_hip(d_silu, quant_dtype=torch.float8_e4m3fn)
aiter.asm_fmoe_a8(d_w2_out, aiter.asm_fmoe_a8(d_w2_out,
...@@ -726,14 +737,15 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -726,14 +737,15 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
odtype, odtype,
config["PERSIST_GROUP1"], config["PERSIST_GROUP1"],
use_shuffle) use_shuffle)
if activation == "silu":
triton_silu_and_mul(d_silu,d_w1_out) _apply_activation(
# silu_and_mul(d_silu,d_w1_out) activation=activation,
elif activation == "gelu": is_gated=is_gated,
triton_gelu_and_mul(d_silu,d_w1_out) activated_out=d_silu,
# gelu_and_mul(d_silu,d_w1_out) ffn1_out_2d=d_w1_out.view(-1, N),
else: gemm1_alpha=gemm1_alpha,
raise ValueError(f"Unsupported FusedMoe activation: {activation}") gemm1_limit=gemm1_limit,
)
bridge_q,bridge_scale = per_block_quant_wrapper((1,block_shape[1]))(per_token_quant_hip)(d_silu, quant_dtype=torch.float8_e4m3fn) bridge_q,bridge_scale = per_block_quant_wrapper((1,block_shape[1]))(per_token_quant_hip)(d_silu, quant_dtype=torch.float8_e4m3fn)
aiter.asm_fmoe_a8(d_w2_out, aiter.asm_fmoe_a8(d_w2_out,
...@@ -795,16 +807,14 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -795,16 +807,14 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
config["PERSIST_GROUP1"], config["PERSIST_GROUP1"],
use_shuffle) use_shuffle)
#return d_w1_out #return d_w1_out
if activation == "silu": _apply_activation(
triton_silu_and_mul(d_silu,d_w1_out) activation=activation,
# silu_and_mul(d_silu,d_w1_out) is_gated=is_gated,
elif activation == "gelu": activated_out=d_silu,
triton_gelu_and_mul(d_silu,d_w1_out) ffn1_out_2d=d_w1_out.view(-1, N),
# gelu_and_mul(d_silu,d_w1_out) gemm1_alpha=gemm1_alpha,
elif activation == "relu2": gemm1_limit=gemm1_limit,
triton_relu2(d_silu,d_w1_out) )
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
aiter.asm_fmoe_a8(d_w2_out, aiter.asm_fmoe_a8(d_w2_out,
d_silu, d_silu,
...@@ -1136,4 +1146,4 @@ def calculate_persist_groups(persist_cu, config, quant_type): ...@@ -1136,4 +1146,4 @@ def calculate_persist_groups(persist_cu, config, quant_type):
if config[f"SOL_ID{i}"] in sol_id_table: if config[f"SOL_ID{i}"] in sol_id_table:
config[f"PERSIST_GROUP{i}"] = persist_cu * sol_id_table[config[f'SOL_ID{i}']] config[f"PERSIST_GROUP{i}"] = persist_cu * sol_id_table[config[f'SOL_ID{i}']]
else: else:
config[f"PERSIST_GROUP{i}"] = persist_cu config[f"PERSIST_GROUP{i}"] = persist_cu
\ No newline at end of file
This diff is collapsed.
...@@ -552,7 +552,7 @@ def build_module( ...@@ -552,7 +552,7 @@ def build_module(
"-Wno-vla-cxx-extension", "-Wno-vla-cxx-extension",
"-Wno-undefined-func-template", "-Wno-undefined-func-template",
"-Wno-macro-redefined", "-Wno-macro-redefined",
"-Wno-missing-template-arg-list-after-template-kw", # "-Wno-missing-template-arg-list-after-template-kw",
"-fgpu-flush-denormals-to-zero", "-fgpu-flush-denormals-to-zero",
] ]
...@@ -794,6 +794,7 @@ def compile_ops( ...@@ -794,6 +794,7 @@ def compile_ops(
fc_name: Optional[str] = None, fc_name: Optional[str] = None,
gen_func: Optional[Callable[..., dict[str, Any]]] = None, gen_func: Optional[Callable[..., dict[str, Any]]] = None,
gen_fake: Optional[Callable[..., Any]] = None, gen_fake: Optional[Callable[..., Any]] = None,
develop: bool = False,
): ):
def decorator(func): def decorator(func):
func.arg_checked = False func.arg_checked = False
...@@ -897,12 +898,18 @@ def compile_ops( ...@@ -897,12 +898,18 @@ def compile_ops(
doc_str = re.sub(pattern, r"Optional[\1]", doc_str) doc_str = re.sub(pattern, r"Optional[\1]", doc_str)
for el in enum_types: for el in enum_types:
doc_str = re.sub(f" aiter.*{el} ", f" {el} ", doc_str) doc_str = re.sub(f" aiter.*{el} ", f" {el} ", doc_str)
try:
from ..utility.aiter_types import aiter_tensor_t as _aiter_tensor_t
except ImportError:
_aiter_tensor_t = None
namespace = { namespace = {
"List": List, "List": List,
"Optional": Optional, "Optional": Optional,
"torch": torch, "torch": torch,
"typing": typing, "typing": typing,
} }
if _aiter_tensor_t is not None:
namespace["aiter_tensor_t"] = _aiter_tensor_t
exec( exec(
f"from aiter import*\ndef {doc_str}: pass", f"from aiter import*\ndef {doc_str}: pass",
...@@ -955,13 +962,34 @@ def compile_ops( ...@@ -955,13 +962,34 @@ def compile_ops(
return True return True
if not func.arg_checked: if not func.arg_checked:
func.arg_checked = check_args() if develop:
func.arg_checked = True # skip type-check when develop=True; tensors are converted below
else:
func.arg_checked = check_args()
if AITER_LOG_MORE == 2: if AITER_LOG_MORE == 2:
from ..test_common import log_args from ..test_common import log_args
log_args(func, *args, **kwargs) log_args(func, *args, **kwargs)
# develop=True: convert torch.Tensor → pybind aiter_tensor_t and inject HIP stream.
# develop=False (default): all existing ops pass through unchanged.
if develop:
import torch
from ..utility.dtypes import torch_to_aiter_pybind
args = tuple(
torch_to_aiter_pybind(a) if isinstance(a, torch.Tensor) else a
for a in args
)
kwargs = {
k: (torch_to_aiter_pybind(v) if isinstance(v, torch.Tensor) else v)
for k, v in kwargs.items()
}
module._set_current_hip_stream(
torch.cuda.current_stream().cuda_stream
)
return op(*args, **kwargs) return op(*args, **kwargs)
@torch_compile_guard(device="cuda", gen_fake=gen_fake, calling_func_=func) @torch_compile_guard(device="cuda", gen_fake=gen_fake, calling_func_=func)
......
...@@ -80,6 +80,41 @@ ...@@ -80,6 +80,41 @@
"hipify": "True", "hipify": "True",
"blob_gen_cmd": "''" "blob_gen_cmd": "''"
}, },
"module_grouped_gemm": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/grouped_gemm_ck_pybind.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_ck/grouped_gemm_kernels.cu'",
"f'{CK_DIR}/example_hcu/ck_tile/19_grouped_gemm/grouped_gemm.cpp'",
"f'{CK_DIR}/example_hcu/ck_tile/19_grouped_gemm/instances/grouped_gemm_fp16.cpp'",
"f'{CK_DIR}/example_hcu/ck_tile/19_grouped_gemm/instances/grouped_gemm_bf16.cpp'",
"f'{CK_DIR}/example_hcu/ck_tile/19_grouped_gemm/instances/grouped_gemm_fp8.cpp'",
"f'{CK_DIR}/example_hcu/ck_tile/19_grouped_gemm/instances/grouped_gemm_int8.cpp'"
],
"flags_extra_cc": [
"'-DCK_TILE_GROUPED_GEMM_FAST_BUILD'",
"'-DCK_TILE_GROUPED_GEMM_FAST_FP16'",
"'-DCK_TILE_GROUPED_GEMM_FAST_BF16'",
"'-DCK_TILE_GROUPED_GEMM_FAST_FP8'",
"'-DCK_TILE_GROUPED_GEMM_FAST_INT8'",
"'-DCK_TILE_GROUPED_GEMM_FAST_RC_ONLY'"
],
"flags_extra_hip": [
"'-DCK_TILE_GROUPED_GEMM_FAST_BUILD'",
"'-DCK_TILE_GROUPED_GEMM_FAST_FP16'",
"'-DCK_TILE_GROUPED_GEMM_FAST_BF16'",
"'-DCK_TILE_GROUPED_GEMM_FAST_FP8'",
"'-DCK_TILE_GROUPED_GEMM_FAST_INT8'",
"'-DCK_TILE_GROUPED_GEMM_FAST_RC_ONLY'"
],
"extra_ldflags": "None",
"extra_include": [
"f'{CK_DIR}/example_hcu/ck_tile/19_grouped_gemm'",
"f'{CK_DIR}/example_hcu/ck_tile/19_grouped_gemm/instances'"
],
"verbose": "False",
"hipify": "False",
"blob_gen_cmd": "''"
},
"module_moe_utils":{ "module_moe_utils":{
"srcs": [ "srcs": [
"f'{AITER_CSRC_DIR}/pybind/moe_utils_pybind.cu'", "f'{AITER_CSRC_DIR}/pybind/moe_utils_pybind.cu'",
...@@ -351,7 +386,7 @@ ...@@ -351,7 +386,7 @@
"f'{MOE_C_DIR}/csrc_for_aiter'", "f'{MOE_C_DIR}/csrc_for_aiter'",
"f'{AITER_CSRC_DIR}/py_itfs_moe_c/moe_c.cu'" "f'{AITER_CSRC_DIR}/py_itfs_moe_c/moe_c.cu'"
], ],
"flags_extra_cc": ["' -mllvm -support-768-vgprs=true -mllvm -disable-machine-sink '" "flags_extra_cc": ["' -mllvm -support-768-vgprs=true -mllvm -disable-machine-sink -w'"
], ],
"flags_extra_hip": [], "flags_extra_hip": [],
"extra_ldflags": "None", "extra_ldflags": "None",
...@@ -386,5 +421,17 @@ ...@@ -386,5 +421,17 @@
"verbose": "False", "verbose": "False",
"hipify": "True", "hipify": "True",
"blob_gen_cmd": "''" "blob_gen_cmd": "''"
},
"module_mhc": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/mhc_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/mhc_kernels.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
} }
} }
This diff is collapsed.
{
"1": {
"BLOCK_SIZE_M": 16,
"MODE": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"3": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"5": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"6": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"7": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"9": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"10": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"11": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"12": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"13": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"14": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"15": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"MODE": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"MODE": 5
},
"1024": {
"BLOCK_SIZE_M": 16,
"MODE": 5
},
"2048": {
"BLOCK_SIZE_M": 16,
"MODE": 5
},
"4096": {
"BLOCK_SIZE_M": 16,
"MODE": 5
},
"8192": {
"BLOCK_SIZE_M": 16,
"MODE": 5
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment