Commit f3deca99 authored by gaoqiong's avatar gaoqiong
Browse files

增加blockint8支持优化

parent 5c241fa9
...@@ -21,6 +21,91 @@ from vllm.platforms import current_platform ...@@ -21,6 +21,91 @@ from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
device_name = current_platform.get_device_name().replace(" ", "_")
if device_name=='BW200' or device_name=='K100_AI':
stage1_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 8},#11
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2},#12
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 4,"num_stages": 0,"num_warps": 2},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #32
]
stage2_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 8},#12
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #16
]
else:
stage1_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 4,"kpack": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 8},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#12
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #32
]
stage2_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#12
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #16
]
@triton.jit @triton.jit
def fused_moe_kernel_awq( def fused_moe_kernel_awq(
...@@ -1516,23 +1601,24 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1516,23 +1601,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938 # https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, if not use_int8_w8a8:
use_int8_w8a8=use_int8_w8a8, config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16, use_int8_w8a16=use_int8_w8a16,
dtype=hidden_states.dtype) use_int4_w4a16=use_int4_w4a16,
dtype=hidden_states.dtype)
get_config_func = functools.partial(
try_get_optimal_moe_config, get_config_func = functools.partial(
w1.shape, try_get_optimal_moe_config,
w2.shape, w1.shape,
topk_ids.shape[1], w2.shape,
config_dtype, topk_ids.shape[1],
block_shape=block_shape, config_dtype,
use_nn_moe=use_nn_moe, block_shape=block_shape,
) use_nn_moe=use_nn_moe,
)
config = get_config_func(M) config = get_config_func(M)
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device, device=hidden_states.device,
...@@ -1584,6 +1670,33 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1584,6 +1670,33 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
if m<=16:
config =stage1_best_config[m-1]
elif m<=32:
config =stage1_best_config[15]
elif m<=64:
config =stage1_best_config[16]
elif m<256:
config ={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0,
"num_warps": 4
}
else:
config ={
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 0,
"num_warps": 4
}
if moe_ep_size == 1: if moe_ep_size == 1:
if use_int4_w4a16: if use_int4_w4a16:
sorted_token_ids, expert_ids, num_tokens_post_padded = ( sorted_token_ids, expert_ids, num_tokens_post_padded = (
...@@ -1620,7 +1733,33 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1620,7 +1733,33 @@ def fused_experts_impl(hidden_states: torch.Tensor,
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
if use_int8_w8a8:
m1=intermediate_cache2.shape[0]
if m1<=16:
config =stage2_best_config[m1-1]
elif m1<=32:
config =stage2_best_config[15]
elif m1<=64:
config =stage2_best_config[16]
elif m1<256:
config ={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0,
"num_warps": 4
}
else:
config ={
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 0,
"num_warps": 4
}
invoke_fused_moe_kernel(intermediate_cache2, invoke_fused_moe_kernel(intermediate_cache2,
w2, w2,
intermediate_cache3, intermediate_cache3,
......
...@@ -9,12 +9,13 @@ from typing import Any, Dict, List, Optional, Tuple ...@@ -9,12 +9,13 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.utils import W8a8GetCacheJSON
# from sglang.srt.utils import get_device_name # from sglang.srt.utils import get_device_name
from vllm.platforms import current_platform from vllm.platforms import current_platform
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
W8A8_TRITONJSON=W8a8GetCacheJSON()
@triton.jit @triton.jit
def _per_token_quant_int8( def _per_token_quant_int8(
...@@ -335,16 +336,16 @@ def w8a8_block_int8_matmul( ...@@ -335,16 +336,16 @@ def w8a8_block_int8_matmul(
C_shape = A.shape[:-1] + (N,) C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype) C = A.new_empty(C_shape, dtype=output_dtype)
#configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1]) # configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
#if configs: # if configs:
# # If an optimal configuration map has been found, look up the # # If an optimal configuration map has been found, look up the
# # optimal config # # optimal config
# config = configs[min(configs.keys(), key=lambda x: abs(x - M))] # config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
#else: # else:
# Default config # #Default config
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] # #Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
#print("block_size[0]:{},block_size[1]:{}".format(block_size[0],block_size[1])) # #print("block_size[0]:{},block_size[1]:{}".format(block_size[0],block_size[1]))
# config = { # config = {
# "BLOCK_SIZE_M": 32, #64 # "BLOCK_SIZE_M": 32, #64
# "BLOCK_SIZE_N": block_size[0], # "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1], # "BLOCK_SIZE_K": block_size[1],
...@@ -352,43 +353,78 @@ def w8a8_block_int8_matmul( ...@@ -352,43 +353,78 @@ def w8a8_block_int8_matmul(
# "num_warps": 4, # "num_warps": 4,
# "num_stages": 3, # "num_stages": 3,
# } # }
#print("W8A8_TRITONJSON.triton_json_dict[0]:",W8A8_TRITONJSON.triton_json_dict[0])
if M<=64: if len(W8A8_TRITONJSON.triton_json_dict)==0:
config = { config=None
"BLOCK_SIZE_M": 16, #64 #print("len(W8A8_TRITONJSON.triton_json_dict)=0:",len(W8A8_TRITONJSON.triton_json_dict))
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128, elif f"1_{N}_{K}_block[{block_n},{block_k}]" in W8A8_TRITONJSON.triton_json_dict[0]:
"GROUP_SIZE_M": 2, if M<=16:
"num_warps": 4, m_=M
"num_stages": 0, elif M<=64:
} m_= (M + 3) & -4 #取值到最近的4的倍数
elif M<128: elif M<=160:
config = { m_=(M + 7) & -8
"BLOCK_SIZE_M": 32, #64
"BLOCK_SIZE_N": 64, elif M<200: #256
"BLOCK_SIZE_K": 128, m_=160
"GROUP_SIZE_M": 2, elif M<480: #512
"num_warps": 4, m_=256
"num_stages": 0, elif M<960: #1024
} m_=512
elif M<=256: elif M<2048:
config = { m_=1024
"BLOCK_SIZE_M": 64, #64 elif M<4096:
"BLOCK_SIZE_N": 64, m_=2048
"BLOCK_SIZE_K": 128, elif M<6000:
"GROUP_SIZE_M": 2, m_=4096
"num_warps": 4, else:
"num_stages": 0, m_=8192
} #print("==================m:{},n:{},k:{}".format(M,N,K))
else : config=W8A8_TRITONJSON.triton_json_dict[0][f"{m_}_{N}_{K}_block[{block_n},{block_k}]"]
config = {
"BLOCK_SIZE_M": 64, #64 else:
"BLOCK_SIZE_N": 128, config=None
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8, # print("m:{},n:{},k:{}".format(M,N,K))
"num_warps": 8, # print("config not found!")
"num_stages": 0,
} if M<=64:
config = {
"BLOCK_SIZE_M": 16, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M<128:
config = {
"BLOCK_SIZE_M": 32, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M<=256:
config = {
"BLOCK_SIZE_M": 64, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
else :
config = {
"BLOCK_SIZE_M": 64, #64
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 0,
}
def grid(META): def grid(META):
return ( return (
...@@ -475,8 +511,6 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.f ...@@ -475,8 +511,6 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.f
C = C.reshape(origin_C_shape).to(output_dtype) C = C.reshape(origin_C_shape).to(output_dtype)
return C return C
def apply_w8a8_block_int8_linear( def apply_w8a8_block_int8_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
...@@ -497,11 +531,6 @@ def apply_w8a8_block_int8_linear( ...@@ -497,11 +531,6 @@ def apply_w8a8_block_int8_linear(
q_input, weight, x_scale, weight_scale, block_size, q_input, weight, x_scale, weight_scale, block_size,
output_dtype=input.dtype output_dtype=input.dtype
) )
# output = native_w8a8_block_int8_matmul(
# q_input, weight, x_scale, weight_scale, block_size,
# output_dtype=input.dtype
# )
if bias is not None: if bias is not None:
......
...@@ -53,6 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -53,6 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import W8a8GetCacheJSON
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
...@@ -677,6 +678,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -677,6 +678,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1' self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1'
self.tritonsingleton= W8a8GetCacheJSON()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
...@@ -948,7 +950,48 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -948,7 +950,48 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
scales=params_dict[layername.replace("qweight", "scales")] scales=params_dict[layername.replace("qweight", "scales")]
sz_tensor = self.restore_qzeros_tensor(qzeros, scales) sz_tensor = self.restore_qzeros_tensor(qzeros, scales)
scales.data = sz_tensor scales.data = sz_tensor
if hasattr(self.config, "quantization_config") and self.config.quantization_config["quant_method"] == "blockwise_int8":
lay_key_words = [
"self_attn.q_a_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
"mlp.shared_experts.gate_up_proj.weight",
"mlp.shared_experts.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
n=weight_data.shape[0]
if len(matched_key_words) < 9 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
#print("n:{},k:{}".format(n,k))
json_file=self.tritonsingleton.get_blockint8json_name(n,k,128,128)
configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,128,128)
if configs_dict:
all_json.update(configs_dict)
self.tritonsingleton.triton_json_dict.append(all_json)
#print("self.tritonsingleton.triton_json_dict[0].shape:",len(self.tritonsingleton.triton_json_dict[0]))
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
# ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params return loaded_params
......
...@@ -1611,6 +1611,38 @@ class W8a8GetCacheJSON: ...@@ -1611,6 +1611,38 @@ class W8a8GetCacheJSON:
device_name = current_platform.get_device_name().replace(" ", "_") device_name = current_platform.get_device_name().replace(" ", "_")
return self.triton_json_dir+f"/W8A8_{n}_{k}_{device_name}.json" return self.triton_json_dir+f"/W8A8_{n}_{k}_{device_name}.json"
def get_blockint8_triton_cache(self,file_path,n,k,block_n,block_k):
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
return None
#把所有的cache解析成key:config的形式:[M_N_K]:[config]
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_value={
'BLOCK_SIZE_M': int(sub_value["BLOCK_SIZE_M"]),
'BLOCK_SIZE_N': int(sub_value["BLOCK_SIZE_N"]),
'BLOCK_SIZE_K': int(sub_value["BLOCK_SIZE_K"]),
'GROUP_SIZE_M': int(sub_value["GROUP_SIZE_M"]),
'kpack': int(sub_value["kpack"]),
'num_stages':int(sub_value['num_stages']),
'num_warps':int(sub_value['num_warps']),
}
configs_dict[configs_key]=configs_value
return configs_dict
def get_blockint8json_name(self,n,k,block_n,block_k):
from vllm.platforms import current_platform
device_name = current_platform.get_device_name().replace(" ", "_")
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{device_name}.json"
# Adapted from: https://stackoverflow.com/a/47212782/5082708 # Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping[str, T], Generic[T]): class LazyDict(Mapping[str, T], Generic[T]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment