Commit 66b3ded6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-update' into 'v0.9.2-dev'

修复w8a8 triton config 择优位运算可能引发torch compile 编译错误,修复smquant w8a8 权重后处理位置

See merge request dcutoolkit/deeplearing/vllm!320
parents 7d5faa43 16d49763
......@@ -39,7 +39,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
cutlass_fp4_supported)
from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
import os
from vllm import _custom_ops as ops
......@@ -616,33 +616,10 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
k=layer.weight.shape[1]
if self.w8a8_strategy==1:
if [n,k] not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append([n,k])
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
self.tritonsingleton.triton_json_dict.update(configs_dict)
for key, value in configs_dict.items():
m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
elif self.w8a8_strategy==3:
layer.weight.data = layer.weight.data.T
else:
weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight
self.tritonsingleton.gen_model_json()
layer.scheme.process_weights_after_loading(layer)
def create_weights(self, layer: torch.nn.Module,
......
......@@ -18,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops
logger = init_logger(__name__)
......@@ -29,6 +31,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.input_symmetric = input_symmetric
......@@ -108,6 +111,30 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
k=layer.weight.shape[1]
if self.w8a8_strategy==1:
if [n,k] not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append([n,k])
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
self.tritonsingleton.triton_json_dict.update(configs_dict)
for key, value in configs_dict.items():
m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
elif self.w8a8_strategy==3:
layer.weight.data = layer.weight.data.T
else:
weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight
self.tritonsingleton.gen_model_json()
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
......
......@@ -455,11 +455,11 @@ def apply_int8_linear(
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
if m<=16:
m_=m
#best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
elif m<=64:
m_= (m + 3) & -4 #取值到最近的4的倍数
m_= (m //4) * 4 #取值到最近的4的倍数
elif m<=160:
m_=(m + 7) & -8
m_=(m // 8) *8
elif m<200: #256
m_=160
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment