Commit 2b47d4fa authored by zhuwenwen's avatar zhuwenwen
Browse files

支持fusemoe对int4的scale合zero合并读取操作

parent ba1ff372
{ {
"1": { "1": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 4, "num_stages": 2,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"2": { "2": {
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
"4": { "4": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
"8": { "8": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
...@@ -37,17 +37,17 @@ ...@@ -37,17 +37,17 @@
}, },
"16": { "16": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 1,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"24": { "24": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
...@@ -56,8 +56,8 @@ ...@@ -56,8 +56,8 @@
"32": { "32": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
...@@ -65,52 +65,52 @@ ...@@ -65,52 +65,52 @@
"48": { "48": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 1,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"64": { "64": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"96": { "96": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 1,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"128": { "128": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 1,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"256": { "256": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 1,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"512": { "512": {
"BLOCK_SIZE_M": 32, "BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
...@@ -118,11 +118,11 @@ ...@@ -118,11 +118,11 @@
}, },
"1024": { "1024": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 1,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"1536": { "1536": {
...@@ -130,13 +130,13 @@ ...@@ -130,13 +130,13 @@
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"2048": { "2048": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
...@@ -144,21 +144,21 @@ ...@@ -144,21 +144,21 @@
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"3072": { "3072": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 2, "num_stages": 1,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
}, },
"4096": { "4096": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 2, "num_stages": 1,
"num_ldmatrixes": 0 "num_ldmatrixes": 0
} }
} }
{
"1": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_ldmatrixes": 0
},
"6144": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 1,
"num_ldmatrixes": 0
}
}
{
"1": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_ldmatrixes": 0
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_ldmatrixes": 0
}
}
{
"1": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 0
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"num_ldmatrixes": 0
}
}
...@@ -23,6 +23,163 @@ from vllm.utils import direct_register_custom_op ...@@ -23,6 +23,163 @@ from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
@triton.jit
def fused_moe_kernel_awq(
# Pointers to matrices
a_ptr, # [4, 7168]
b_ptr, # [256, 512, 3584]
c_ptr, # (8, 8, 512)
b_scale_ptr, # (256, 512, 56)
b_zp_ptr, # (256, 256, 56)
topk_weights_ptr,
sorted_token_ids_ptr, # [0, 1, 2, 3, 4]
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N: tl.constexpr,
K: tl.constexpr,
EM, # pading后的总索引长度
num_valid_tokens, # 有效索引的上限
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk, #1
stride_bn,
stride_cm,
stride_cn,
stride_bse,
stride_bsk,#1
stride_bsn,
stride_bze,
stride_bzk,
stride_bzn,
block_k_diviable: tl.constexpr,
group_size: tl.constexpr, # 128
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) # [block_m]
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N)) % N # [block_n]
offs_k = tl.arange(0, BLOCK_SIZE_K) # 0, 1, 2, ...... , 127 # # [block_k]
offs_k2 = tl.arange(0, BLOCK_SIZE_K // 2) # 0, 1, 2, ...... , 127 # # [block_k]
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) # [block_m, block_k]
off_experts = tl.load(expert_ids_ptr + pid_m)
if use_int4_w4a16:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
# b_ptrs = b_ptr + off_experts * stride_be + \
# (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
b_ptrs = b_ptr + off_experts * stride_be + \
offs_bn[:, None] * stride_bn + (offs_k2[None, :]) * stride_bk
# tl.device_print("stride_bn",stride_bsn)>1
# tl.device_print("stride_bk",stride_bk)=1
b_shifter = (offs_k[:, None] % 2) * 4 # 0, 4
elif use_int8_w8a16:
b_ptrs = b_ptr + off_experts * stride_be + \
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
if not has_zp and use_int4_w4a16:
b_zp_num = 8
if not has_zp and use_int8_w8a16:
b_zp_num = 128
elif has_zp and use_int4_w4a16:
b_zp_shifter = (offs_bn[None, :] % 2) * 4 # 0, 4
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
if not block_k_diviable:
k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
k_other = 0.0
else:
k_mask = None
k_other = None
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs)
if use_int4_w4a16:
b = tl.interleave(b, b)
b= b.trans()
b = (b >> b_shifter) & 0xF
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
offs_bn[None, :] * stride_bsk + \
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsn
qzeros_scles = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
scales_int16 = tl.cast(qzeros_scles,tl.uint16)
b_scale = tl.cast(scales_int16,tl.float16,bitcast=True)
# tl.device_print("b_scale dequant",b_scale)
mid = qzeros_scles >> 16
# b_zp = tl.cast(mid,tl.float16,bitcast=False)
b_zp = tl.cast(mid,tl.float16)
# b_zp = tl.cast(zeros_int16,tl.float16,bitcast=False)
# tl.device_print("bzp",b_zp)
# We accumulate along the K dimension.
b = ((b - b_zp) * b_scale).to(tl.float16)
accumulator = tl.dot(a, b, acc=accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
if use_int4_w4a16:
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
else:
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit @triton.jit
def fused_moe_kernel_gptq_awq( def fused_moe_kernel_gptq_awq(
# Pointers to matrices # Pointers to matrices
...@@ -562,7 +719,7 @@ def moe_align_block_size_triton( ...@@ -562,7 +719,7 @@ def moe_align_block_size_triton(
def moe_align_block_size( def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int, topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: num_experts: int, num_token: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block Aligns the token distribution across experts to be compatible with block
size for matrix multiplication. size for matrix multiplication.
...@@ -600,11 +757,18 @@ def moe_align_block_size( ...@@ -600,11 +757,18 @@ def moe_align_block_size(
- The padding ensures that the total number of tokens is now divisible - The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations. by block_size for proper block matrix operations.
""" """
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) if num_token:
sorted_ids = torch.empty((max_num_tokens_padded, ), if num_token < block_size:
dtype=torch.int32, max_num_tokens_padded = min(topk_ids.numel() * block_size, topk_ids.numel() + num_experts * (block_size - 1))
device=topk_ids.device) else:
sorted_ids.fill_(topk_ids.numel()) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.full((max_num_tokens_padded,), fill_value=topk_ids.numel(), dtype=torch.int32, device=topk_ids.device)
else:
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ), expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32, dtype=torch.int32,
...@@ -754,7 +918,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -754,7 +918,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert B_scale is None assert B_scale is None
EM = sorted_token_ids.shape[0] EM = sorted_token_ids.shape[0]
if A.shape[0] < config["BLOCK_SIZE_M"]: if use_int4_w4a16:
EM = sorted_token_ids.shape[0]
elif A.shape[0] < config["BLOCK_SIZE_M"]:
# optimize for small batch_size. # optimize for small batch_size.
# We assume that top_ids of each token is unique, so # We assume that top_ids of each token is unique, so
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
...@@ -769,43 +935,82 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -769,43 +935,82 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert B_scale is not None and B_scale.ndim == 3 assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3 assert B_zp is None or B_zp.ndim == 3
fused_moe_kernel_gptq_awq[grid]( if os.environ.get('AWQ_MOE_SZ') == '1':
A, fused_moe_kernel_awq[grid](
B, A,
C, B,
B_scale, C,
B_zp, B_scale,
topk_weights, B_zp,
sorted_token_ids, topk_weights,
expert_ids, sorted_token_ids,
num_tokens_post_padded, expert_ids,
B.shape[1], num_tokens_post_padded,
A.shape[1], B.shape[1],
EM, A.shape[1],
topk_ids.numel(), EM,
A.stride(0), topk_ids.numel(),
A.stride(1), A.stride(0),
B.stride(0), A.stride(1),
B.stride(2), B.stride(0),
B.stride(1), B.stride(2),
C.stride(1), B.stride(1),
C.stride(2), C.stride(1),
B_scale.stride(0), C.stride(2),
B_scale.stride(2), B_scale.stride(0),
B_scale.stride(1), B_scale.stride(2),
B_zp.stride(0) if B_zp is not None else 0, B_scale.stride(1),
B_zp.stride(2) if B_zp is not None else 0, B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0, B_zp.stride(2) if B_zp is not None else 0,
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, B_zp.stride(1) if B_zp is not None else 0,
group_size=block_shape[1], block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
MUL_ROUTED_WEIGHT=mul_routed_weight, group_size=block_shape[1],
top_k=top_k, MUL_ROUTED_WEIGHT=mul_routed_weight,
compute_type=compute_type, top_k=top_k,
has_zp=B_zp is not None, compute_type=compute_type,
use_int4_w4a16=use_int4_w4a16, has_zp=B_zp is not None,
use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16,
**config, use_int8_w8a16=use_int8_w8a16,
) **config,
)
else:
fused_moe_kernel_gptq_awq[grid](
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1],
EM,
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0),
B_scale.stride(2),
B_scale.stride(1),
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
**config,
)
else: else:
fused_moe_kernel[grid]( fused_moe_kernel[grid](
...@@ -892,6 +1097,15 @@ def get_moe_configs( ...@@ -892,6 +1097,15 @@ def get_moe_configs(
config_file_path = os.path.join( config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120:
config_file_path_120 = config_file_path.replace(".json","_120.json")
if os.path.exists(config_file_path_120):
with open(config_file_path_120) as f:
logger.info("Using configuration from %s for MoE layer.",
config_file_path_120)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
if os.path.exists(config_file_path): if os.path.exists(config_file_path):
with open(config_file_path) as f: with open(config_file_path) as f:
logger.info("Using configuration from %s for MoE layer.", logger.info("Using configuration from %s for MoE layer.",
...@@ -1379,8 +1593,12 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1379,8 +1593,12 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
if moe_ep_size == 1: if moe_ep_size == 1:
sorted_token_ids, expert_ids, num_tokens_post_padded = ( if use_int4_w4a16:
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E, curr_hidden_states.shape[0]))
else:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
else: else:
sorted_token_ids, expert_ids, num_tokens_post_padded = ( sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_ep_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E, moe_ep_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E,
......
...@@ -105,10 +105,11 @@ def get_model_architecture( ...@@ -105,10 +105,11 @@ def get_model_architecture(
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1': if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0' os.environ['FA_PAD'] = '0'
# awq相关配置
try: try:
if os.getenv('AWQ_PAD') == '0' or ((torch.cuda.isCurrentDeviceEco(torch.cuda.current_device())) and os.getenv('AWQ_PAD') == None): if os.getenv('AWQ_MOE_SZ') == None:
os.environ['AWQ_PAD'] = '0' os.environ['AWQ_MOE_SZ'] = '1'
else: if os.getenv('AWQ_PAD') == None and (torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120):
os.environ['AWQ_PAD'] = '1' os.environ['AWQ_PAD'] = '1'
except Exception as e: except Exception as e:
if os.getenv('AWQ_PAD') != '0': if os.getenv('AWQ_PAD') != '0':
......
...@@ -669,6 +669,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -669,6 +669,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1'
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
...@@ -734,6 +735,26 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -734,6 +735,26 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
dtype=dtype, dtype=dtype,
device=device), device=device),
}) })
def restore_qzeros_tensor(self, qzeros, qscales):
low_bits = qzeros & 0x0F
high_bits = qzeros >> 4
zeors_tensor = torch.stack([low_bits, high_bits], dim=2).view(qzeros.shape[0], -1 , qzeros.shape[-1])
zeors_int16 = zeors_tensor.to(torch.int16)
assert zeors_int16.shape == qscales.shape
uint16_tensor1 = zeors_int16.view(torch.uint16)
uint16_tensor2 = qscales.view(torch.uint16)
uint32_tensor1 = uint16_tensor1.to(torch.int32) << 16
uint32_tensor2 = uint16_tensor2.to(torch.int32)
result_tensor = uint32_tensor1 + uint32_tensor2
result_tensor =result_tensor.view(torch.uint32)
result_tensor = result_tensor.transpose(1, 2).contiguous()
return result_tensor
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
...@@ -877,6 +898,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -877,6 +898,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
"mlp.shared_experts.down_proj.qweight" "mlp.shared_experts.down_proj.qweight"
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
# moe_gather_sz
moe_key_words = ["mlp.experts.w13_qweight", "mlp.experts.w2_qweight"]
moe_combined_words = "|".join(moe_key_words)
for layername in loaded_params: for layername in loaded_params:
weight = params_dict[layername] weight = params_dict[layername]
...@@ -910,6 +934,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -910,6 +934,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous() zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous() qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.use_w4a16_moe_sz:
matches_moe = re.findall(moe_combined_words, layername)
# sz.shape == s.shape.T
if matches_moe:
qzeros=params_dict[layername.replace("qweight", "qzeros")]
scales=params_dict[layername.replace("qweight", "scales")]
sz_tensor = self.restore_qzeros_tensor(qzeros, scales)
scales.data = sz_tensor
return loaded_params return loaded_params
......
...@@ -86,10 +86,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -86,10 +86,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_config = copy.deepcopy(vllm_config) draft_worker_config = copy.deepcopy(vllm_config)
draft_worker_config.model_config = speculative_config.draft_model_config draft_worker_config.model_config = speculative_config.draft_model_config
draft_worker_config.quant_config = VllmConfig._get_quantization_config( # draft_worker_config.quant_config = VllmConfig._get_quantization_config(
draft_worker_config.model_config, # draft_worker_config.model_config,
vllm_config.load_config, # vllm_config.load_config,
) # )
speculative_config.draft_parallel_config.worker_cls =\ speculative_config.draft_parallel_config.worker_cls =\
draft_worker_config.parallel_config.sd_worker_cls draft_worker_config.parallel_config.sd_worker_cls
draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment