Commit 40b94473 authored by gaoqiong's avatar gaoqiong
Browse files

修改deepseelk block-int8 权重处理流程,增加per-channel bestconfig配置以及首次triton warmup代码

parent a68aef25
......@@ -23,6 +23,10 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.utils.int8_utils import (
apply_w8a8_block_int8_linear)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import W8a8GetCacheJSON
import os
from vllm import _custom_ops as ops
ACTIVATION_SCHEMES = ["static", "dynamic"]
......@@ -128,6 +132,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: BlockInt8Config):
self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
assert self.quant_config.weight_block_size is not None
assert self.quant_config.is_checkpoint_int8_serialized
......@@ -219,6 +224,27 @@ class BlockInt8LinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
# Use torch Parameter to avoid cuda graph capturing issue
n=layer.weight.shape[0]
k=layer.weight.shape[1]
block_n=self.quant_config.weight_block_size[0]
block_k=self.quant_config.weight_block_size[1]
block_size=[block_n,block_k]
#print("layer.weight.device:",layer.weight.device)
if {n,k} not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_blockint8json_name(n,k,block_n,block_k)
configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,block_n,block_k)
if configs_dict:
self.tritonsingleton.triton_json_dict.update(configs_dict)
for key, value in configs_dict.items():
m=int(key.split('_')[0])
#ops.triton_blockint8_gemm_helper(m=m,n=n,k=k,block_size=block_size,use_bias=False,out_dtype=torch.bfloat16,device=layer.weight.device,best_config=value)
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
layer.weight_scale_inv.data, requires_grad=False
......
......@@ -17,6 +17,12 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
import os
from vllm import _custom_ops as ops
W8A8_TRITONJSON=W8a8GetCacheJSON()
def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
......@@ -84,8 +90,30 @@ class W8A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: W8A8Int8Config):
self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
k=layer.weight.shape[1]
if self.w8a8_strategy==1:
if {n,k} not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
self.tritonsingleton.triton_json_dict.update(configs_dict)
for key, value in configs_dict.items():
m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
else:
weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
......@@ -128,18 +156,67 @@ class W8A8Int8LinearMethod(LinearMethodBase):
):
x_q, x_scale = per_token_quant_int8(x)
# return int8_scaled_mm(
# x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
# )
#return baseline_scaled_mm(x_q, layer.weight, x_scale, layer.weight_scale, x.dtype, bias)
best_config=None
return ops.triton_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,best_config=best_config)
if self.w8a8_strategy==1:
m=x_q.shape[0]
k=x_q.shape[1]
n=layer.weight.shape[1]
if len(W8A8_TRITONJSON.triton_json_dict)==0:
best_config=None
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
if m<=16:
m_=m
elif m<=64:
m_= (m + 3) & -4 #取值到最近的4的倍数
elif m<=160:
m_=(m + 7) & -8
elif m<200: #256
m_=160
elif m<480: #512
m_=256
elif m<960: #1024
m_=512
elif m<2048:
m_=1024
elif m<4096:
m_=2048
elif m<6000:
m_=4096
else:
m_=8192
best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
else:
best_config=None
if best_config==None:
print("m:{},n:{},k:{}".format(m,n,k))
print("config not found!")
return ops.triton_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,best_config=best_config)
elif self.w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
else:
return ops.rocblas_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
class W8A8Int8MoEMethod:
"""MoE method for INT8.
......
......@@ -53,7 +53,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import W8a8GetCacheJSON
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
......@@ -704,7 +703,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
os.environ['LM_NN'] = '0'
self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1'
self.tritonsingleton= W8a8GetCacheJSON()
self.config = config
self.quant_config = quant_config
......@@ -928,48 +926,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
sz_tensor = self.restore_qzeros_tensor(qzeros, scales)
scales.data = sz_tensor
if hasattr(self.config, "quantization_config") and self.config.quantization_config["quant_method"] == "blockwise_int8":
lay_key_words = [
"self_attn.q_a_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
"mlp.shared_experts.gate_up_proj.weight",
"mlp.shared_experts.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
n=weight_data.shape[0]
if len(matched_key_words) < 9 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
#print("n:{},k:{}".format(n,k))
json_file=self.tritonsingleton.get_blockint8json_name(n,k,128,128)
configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,128,128)
if configs_dict:
all_json.update(configs_dict)
self.tritonsingleton.triton_json_list.append(all_json)
#print("self.tritonsingleton.triton_json_dict[0].shape:",len(self.tritonsingleton.triton_json_dict[0]))
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
# ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment