Commit a098923a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.8.5.post1-dev-w8a8-debug' into v0.8.5.post1-opt1

parents 0d61a71c 44e87dde
......@@ -18,7 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
# For test
def native_per_token_group_quant_int8(x,
group_size,
......
......@@ -985,9 +985,22 @@ def triton_int8_gemm_helper(m: int,
per_out_channel_weight_quant: bool,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.float16,
device: str = "cuda",
best_config:Optional[list] = None):
return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config)
device: str = "cuda:0",
best_config:Optional[list] = None,
repeat:Optional[int] = 2):
return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config,repeat)
def triton_blockint8_gemm_helper(m: int,
n: int,
k: int,
block_size:list=[128,128],
use_bias: bool=False,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda:0",
best_config:Optional[dict] = None,
repeat:Optional[int] = 2):
return quant_tools.triton_blockint8_gemm_helper(m,n,k,block_size,use_bias,out_dtype,device,best_config,repeat)
def cutlass_scaled_mm_azp(a: torch.Tensor,
......
......@@ -19,112 +19,17 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8,get_w8a8moe_json)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__)
device_name = current_platform.get_device_name().replace(" ", "_")
if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120:
stage1_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 4,"kpack": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 8},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#12
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #32
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 2,"num_stages": 0,"num_warps": 4}, #256
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 2,"num_stages": 0,"num_warps": 4},#1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 2,"num_stages": 0,"num_warps": 8},#8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "kpack": 1,"num_stages": 0,"num_warps": 8}
]
stage2_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#12
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #32
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4} ,#256
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4},#1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4},# 8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}
]
else:
stage1_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 8},#11
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2},#12
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 4,"num_stages": 0,"num_warps": 2},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #32
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},#256
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},#1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},#8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},
]
stage2_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 8},#12
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #32
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4}, #256
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0, "num_warps": 4}, #1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0, "num_warps": 4}, #8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4}
]
@triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
......@@ -1610,6 +1515,33 @@ def fused_experts_impl(hidden_states: torch.Tensor,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False):
# Check constraints.
if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8= False,
use_int8_w8a8= True,
use_int8_w8a16= False,
use_int4_w4a16 = False,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe= False
)
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch"
......@@ -1641,7 +1573,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if not use_int8_w8a8:
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
......@@ -1712,23 +1644,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
if m<=16:
config =stage1_best_config[m-1]
elif m<=32:
config =stage1_best_config[15]
elif m<=64:
config =stage1_best_config[16]
elif m<=256:
config =stage1_best_config[17]
elif m<=1024:
config =stage1_best_config[18]
elif m<=8192:
config =stage1_best_config[19]
else:
config =stage1_best_config[20]
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
A=curr_hidden_states,
B=w1,
......@@ -1751,6 +1666,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
global_num_experts, expert_map))
enable_expert_parallel = (int)(expert_map is not None)
invoke_fused_moe_kernel(qcurr_hidden_states,
w1,
intermediate_cache1,
......@@ -1796,23 +1712,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
if m<=16:
config =stage2_best_config[m-1]
elif m<=32:
config =stage2_best_config[15]
elif m<=64:
config =stage2_best_config[16]
elif m<=256:
config =stage2_best_config[17]
elif m<=1024:
config =stage2_best_config[18]
elif m<=8192:
config =stage2_best_config[19]
else:
config =stage2_best_config[20]
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
......
......@@ -20,8 +20,10 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
PerTensorScaleParameter)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
from lmslim.layers.gemm.int8_utils import (
apply_w8a8_block_int8_linear)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import W8a8GetCacheJSON
......@@ -133,6 +135,8 @@ class BlockInt8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: BlockInt8Config):
self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.block_size=self.quant_config.weight_block_size
assert self.quant_config.weight_block_size is not None
assert self.quant_config.is_checkpoint_int8_serialized
......@@ -227,23 +231,18 @@ class BlockInt8LinearMethod(LinearMethodBase):
n=layer.weight.shape[0]
k=layer.weight.shape[1]
block_n=self.quant_config.weight_block_size[0]
block_k=self.quant_config.weight_block_size[1]
block_size=[block_n,block_k]
#print("layer.weight.device:",layer.weight.device)
if {n,k} not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_blockint8json_name(n,k,block_n,block_k)
configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,block_n,block_k)
json_file=self.tritonsingleton.get_blockint8json_name(n,k,self.block_size[0],self.block_size[1])
configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,self.block_size[0],self.block_size[1])
if configs_dict:
self.tritonsingleton.triton_json_dict.update(configs_dict)
for key, value in configs_dict.items():
m=int(key.split('_')[0])
#ops.triton_blockint8_gemm_helper(m=m,n=n,k=k,block_size=block_size,use_bias=False,out_dtype=torch.bfloat16,device=layer.weight.device,best_config=value)
ops.triton_blockint8_gemm_helper(m=m,n=n,k=k,block_size=self.block_size,use_bias=False,out_dtype=torch.bfloat16,device=layer.weight.device,best_config=value)
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
......@@ -256,6 +255,46 @@ class BlockInt8LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
M=x.shape[0]
K=x.shape[1]
N=layer.weight.shape[0]
#print("self.tritonsingleton.triton_json_dict:",self.tritonsingleton.triton_json_dict)
#Get the best config options
if len(self.tritonsingleton.triton_json_dict)==0:
config=None
elif f"1_{N}_{K}_block[{self.block_size[0]},{self.block_size[1]}]" in self.tritonsingleton.triton_json_dict:
if M<=16:
m_=M
elif M<=64:
m_= (M + 3) & -4 #取值到最近的4的倍数
elif M<=160:
m_=(M + 7) & -8
elif M<200: #256
m_=160
elif M<480: #512
m_=256
elif M<960: #1024
m_=512
elif M<2048:
m_=1024
elif M<4096:
m_=2048
elif M<6000:
m_=4096
else:
m_=8192
config=self.tritonsingleton.triton_json_dict[f"{m_}_{N}_{K}_block[{self.block_size[0]},{self.block_size[1]}]"]
else:
config=None
#print("m:{},n:{},k:{},config:{}".format(M,N,K,config))
return apply_w8a8_block_int8_linear(
input=x,
weight=layer.weight,
......@@ -263,6 +302,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
weight_scale=layer.weight_scale_inv,
input_scale=None,
bias=bias,
config=config
)
class BlockInt8MoEMethod:
......@@ -390,6 +430,7 @@ class BlockInt8MoEMethod:
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
# warmup and get moe block-int8 config
return
def apply(
......
......@@ -17,7 +17,10 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.model_executor.layers.fused_moe import fused_experts
from lmslim.layers.fused_moe.fuse_moe_int4 import fused_experts_w4a16
os.environ['W4A16_MOE_CUDA'] = os.environ.get('W4A16_MOE_CUDA', '0')
os.environ['W4A16_MOE_LMSLIM'] = os.environ.get('W4A16_MOE_LMSLIM', '1')
if os.environ['W4A16_MOE_CUDA'] == '1':
from vllm.model_executor.layers.quantization.utils.fused_moe_cuda import fused_experts_cuda
......@@ -180,7 +183,11 @@ class MoeWNA16Method(FusedMoEMethodBase):
def __init__(self, quant_config: MoeWNA16Config):
self.quant_config = quant_config
self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1'
self.use_w4a16_cuda = 0
self.use_moe_lmslim = 0
if self.use_w4a16_moe_sz:
self.use_w4a16_cuda = os.environ['W4A16_MOE_CUDA'] == '1'
self.use_moe_lmslim = os.environ['W4A16_MOE_LMSLIM'] == "1"
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
......@@ -352,6 +359,24 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
if self.use_moe_lmslim:
return fused_experts_w4a16(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int4_w4a16=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
block_shape=[0, layer.group_size])
if self.use_w4a16_cuda:
m = topk_ids.shape[0]
if m <= 512:
......@@ -380,6 +405,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
......
......@@ -17,8 +17,8 @@ from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from grouped_gemm import moe_gemm_w4a16
from grouped_gemm.ops import permute as permute_topK, unpermute as unpermute_topK
from grouped_gemm_int4 import moe_gemm_w4a16
from grouped_gemm_int4.ops import permute as permute_topK, unpermute as unpermute_topK
import torch.nn.functional as F
logger = init_logger(__name__)
device_name = current_platform.get_device_name()
......@@ -201,11 +201,11 @@ def config_cuda(M):
NearestM = min(reference_points, key=lambda x: abs(x - M))
if device_name == "K100_AI":
mode_1 = k100ai_gemm1_m_to_mode_dict.get(M, k100ai_gemm1_m_to_mode_dict[NearestM])
mode_2 = k100ai_gemm2_m_to_mode_dict.get(M, k100ai_gemm2_m_to_mode_dict[NearestM])
mode_1 = k100ai_gemm1_m_to_mode_dict.get(NearestM, k100ai_gemm1_m_to_mode_dict[32])
mode_2 = k100ai_gemm2_m_to_mode_dict.get(NearestM, k100ai_gemm2_m_to_mode_dict[32])
else:
mode_1 = bw_gemm1_m_to_mode_dict.get(M, k100ai_gemm1_m_to_mode_dict[NearestM])
mode_2 = bw_gemm2_m_to_mode_dict.get(M, k100ai_gemm2_m_to_mode_dict[NearestM])
mode_1 = bw_gemm1_m_to_mode_dict.get(NearestM, bw_gemm1_m_to_mode_dict[32])
mode_2 = bw_gemm2_m_to_mode_dict.get(NearestM, bw_gemm2_m_to_mode_dict[32])
return mode_1, mode_2
......@@ -315,7 +315,7 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor,
num_tokens_post_padded, # 实际专家数
expert_ids, # expert_id_vec
w1_scale, # scale_zero
64, # group_size
block_shape[1], # group_size
topk=topk, # topk
mode=mode_1) # mode=gemm1_mode
......@@ -329,10 +329,12 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor,
expert_ids, # expert_id_vec
w2_scale, # scale_zero
topk_weights, # topk_weights
64, # group_size
block_shape[1], # group_size
topk=topk, # topk
mode=mode_2) # mode=gemm2_mode
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states)
return out_hidden_states
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/blob/4cb53ecd0cffceb6dee5c011a58f65997a86f151/python/sglang/srt/layers/quantization/int8_kernel.py
import functools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice
from vllm.utils import W8a8GetCacheJSON
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
W8A8_TRITONJSON=W8a8GetCacheJSON()
def apply_w8a8_block_int8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1])
output = w8a8_block_int8_matmul(q_input,
weight,
x_scale,
weight_scale,
block_size,
output_dtype=input.dtype)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def input_to_int8(
x: torch.Tensor,
dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to int8 values with
tensor-wise quantization."""
iinfo = torch.iinfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
int8_min, int8_max = iinfo.min, iinfo.max
scale = int8_max / amax
x_scl_sat = (x * scale).clamp(min=int8_min, max=int8_max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
def block_dequant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
) -> torch.Tensor:
"""This function conducts block-wise dequantization.
The inputs are block-wise quantization tensor `x_q_block`,
block-wise quantization scale and the block size.
The outputs are dequantized tensor.
"""
block_n, block_k = block_size[0], block_size[1]
n, k = x_q_block.shape
n_tiles = (n + block_n - 1) // block_n
k_tiles = (k + block_k - 1) // block_k
assert n_tiles == x_s.shape[0]
assert k_tiles == x_s.shape[1]
x_dq_block = x_q_block.to(torch.float32)
for i in range(k_tiles):
for j in range(n_tiles):
x_dq_block[
j * block_n:min((j + 1) * block_n, n),
i * block_k:min((i + 1) * block_k, k),
] *= x_s[j][i]
return x_dq_block
@triton.jit
def _per_token_quant_int8(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
BLOCK: tl.constexpr,
):
# Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
row_id = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
def per_token_quant_int8(x):
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
assert x.is_contiguous()
_per_token_quant_int8[(M, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
@triton.jit
def _per_token_group_quant_int8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
# Stride of input
group_size,
# M,
# K,
# # Collums of input
# N,
SIZE,
# Avoid to divide zero
eps,
# Information for int8
int8_min,
int8_max,
# Meta-parameters
BLOCK: tl.constexpr,
s_num : tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into int8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * BLOCK
y_q_ptr += g_id * BLOCK
y_s_ptr += g_id * s_num
cols = tl.arange(0, BLOCK) # N <= BLOCK
s_cols = tl.arange(0, s_num)
mask = g_id * BLOCK + cols < SIZE
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
y = tl.reshape(y, (s_num, 128))
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
y_s = (_absmax / int8_max).reshape(s_num, 1)
y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)
y_q = tl.reshape(y_q, (s_num*128))
y_s = tl.reshape(y_s, (s_num))
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr + s_cols, y_s.to(y_s_ptr.dtype.element_ty))
def per_token_group_quant_int8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tenosr with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.int8`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
int8_max = iinfo.max
int8_min = iinfo.min
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
N = group_size
m = x.shape[0]
if m<=16:
config={"BLOCK":128,"s_num":1,"num_warps":1,"num_stages":1}
elif m<=256:
config={"BLOCK":1024,"s_num":8,"num_warps":4,"num_stages":1}
else:
config={"BLOCK":2048,"s_num":16,"num_warps":4,"num_stages":2}
grid = lambda META: (
triton.cdiv(x.numel(), META['BLOCK']),
)
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size, ),
device=x.device,
dtype=torch.float32,
)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_int8[grid](
x,
x_q,
x_s,
group_size,
# M,
# K,
# N,
x.numel(),
eps,
int8_min=int8_min,
int8_max=int8_max,
**config
)
return x_q, x_s
@triton.jit
def _w8a8_block_int8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization,
and store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_bsn = pid_n * BLOCK_SIZE_N // group_n
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# a_ptrs = A + (offs_am[:, None] * stride_am)
# b_ptrs = B + (offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
# offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
"Using configuration from %s for W8A8 Block INT8 kernel.",
config_file_path,
)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
("Using default W8A8 Block INT8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"),
config_file_path,
)
return None
def w8a8_block_int8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should be
2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
if len(W8A8_TRITONJSON.triton_json_dict)==0:
config=None
elif f"1_{N}_{K}_block[{block_n},{block_k}]" in W8A8_TRITONJSON.triton_json_dict:
if M<=16:
m_=M
elif M<=64:
m_= (M + 3) & -4 #取值到最近的4的倍数
elif M<=160:
m_=(M + 7) & -8
elif M<200: #256
m_=160
elif M<480: #512
m_=256
elif M<960: #1024
m_=512
elif M<2048:
m_=1024
elif M<4096:
m_=2048
elif M<6000:
m_=4096
else:
m_=8192
config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{N}_{K}_block[{block_n},{block_k}]"]
else:
config=None
if config==None:
# print("m:{},n:{},k:{}".format(M,N,K))
# print("config not found!")
if M<=64:
config = {
"BLOCK_SIZE_M": 16, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M<128:
config = {
"BLOCK_SIZE_M": 32, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M<=256:
config = {
"BLOCK_SIZE_M": 64, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
else :
config = {
"BLOCK_SIZE_M": 64, #64
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 0,
}
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_block_int8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
return C
def apply_w8a8_block_int8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1])
output = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=input.dtype
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def input_to_int8(
x: torch.Tensor, dtype: torch.dtype = torch.int8
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to
int8 values with tensor-wise quantization.
"""
iinfo = torch.iinfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
int8_min, int8_max = iinfo.min, iinfo.max
scale = int8_max / amax
x_scl_sat = (x * scale).clamp(min=int8_min, max=int8_max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
def block_dequant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
) -> torch.Tensor:
"""This function conducts block-wise dequantization.
The inputs are block-wise quantization tensor `x_q_block`,
block-wise quantization scale and the block size.
The outputs are dequantized tensor.
"""
block_n, block_k = block_size[0], block_size[1]
n, k = x_q_block.shape
n_tiles = (n + block_n - 1) // block_n
k_tiles = (k + block_k - 1) // block_k
assert n_tiles == x_s.shape[0]
assert k_tiles == x_s.shape[1]
x_dq_block = x_q_block.to(torch.float32)
for i in range(k_tiles):
for j in range(n_tiles):
x_dq_block[
j * block_n : min((j + 1) * block_n, n),
i * block_k : min((i + 1) * block_k, k),
] *= x_s[j][i]
return x_dq_block
\ No newline at end of file
......@@ -13,7 +13,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
from vllm import _custom_ops as ops
......
......@@ -60,11 +60,6 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
maybe_prefix)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.int8_utils import (
block_dequant as int8_block_dequant,
)
class DeepseekV2MLP(nn.Module):
def __init__(
......
......@@ -1735,42 +1735,8 @@ class W8a8GetCacheJSON:
else:
return None
def get_triton_cache_tune(self,file_path,n,k):
#tuning的时候使用,当文件不存在时候,则创建文件夹
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
folder_path = os.path.dirname(file_path)
os.makedirs(folder_path, exist_ok=True)
cachedata = {}
# 写入空数据到新的JSON文件
with open(file_path, 'w') as file:
json.dump(cachedata, file)
#把所有的cache解析成key:config的形式:[M_N_K]:[config]
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_value={
'SPLIT_K': int(sub_value["SPLIT_K"]),
'BLOCK_SIZE_M': int(sub_value["BLOCK_SIZE_M"]),
'BLOCK_SIZE_N': int(sub_value["BLOCK_SIZE_N"]),
'BLOCK_SIZE_K': int(sub_value["BLOCK_SIZE_K"]),
'GROUP_SIZE_M': int(sub_value["GROUP_SIZE_M"]),
'num_stages':int(sub_value['num_stages']),
'num_warps':int(sub_value['num_warps'])
}
configs_dict[configs_key]=configs_value
return configs_dict
def get_triton_cache(self,file_path,n,k):
#在非tuning的时候使用,当文件不存在则直接返回none
cache_json_file=file_path
if os.path.exists(file_path):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment