Commit 1bbb2f94 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.11.0-dev_yql_2.11' into 'v0.11.0-dev'

修复CompressedTensors的w8a16的支持,新增awq_marlin gemm的qwen72B的支持

See merge request dcutoolkit/deeplearing/vllm!429
parents a4d28758 3d4721f2
...@@ -420,6 +420,24 @@ def awq_gemm_fake(input: torch.Tensor, weight: torch.Tensor, ...@@ -420,6 +420,24 @@ def awq_gemm_fake(input: torch.Tensor, weight: torch.Tensor,
splikspacesize:int) -> torch.Tensor: splikspacesize:int) -> torch.Tensor:
return torch.empty((m, n), dtype=input.dtype, device=input.device) return torch.empty((m, n), dtype=input.dtype, device=input.device)
def awq_gemm_marlin_weight_repack(weight_trans: torch.Tensor, N: int,
                                  K: int) -> torch.Tensor:
    """Repack an AWQ weight tensor through the lightop marlin repack kernel.

    Thin wrapper that forwards its arguments unchanged to
    ``lightop.awq_gemm_marlin_weight_repack``.

    Args:
        weight_trans: Weight tensor handed to the kernel.
        N: Forwarded as-is; presumably the output-feature count
            (TODO confirm against lightop).
        K: Forwarded as-is; presumably the input-feature count
            (TODO confirm against lightop).

    Returns:
        The tensor returned by the lightop kernel.
    """
    return lightop.awq_gemm_marlin_weight_repack(weight_trans, N, K)
def awq_gemm_marlin_weight_repack_fake(weight_trans: torch.Tensor, N: int,
                                       K: int) -> torch.Tensor:
    """Shape-only stand-in for ``awq_gemm_marlin_weight_repack``.

    Allocates an uninitialized tensor instead of running the kernel, so it
    can be used where only the output metadata is needed.

    Args:
        weight_trans: Source of the output dtype and device.
        N: First dimension of the returned tensor.
        K: Second dimension of the returned tensor.

    Returns:
        An uninitialized ``(N, K)`` tensor with the dtype and device of
        ``weight_trans``.
    """
    # NOTE(review): (N, K) is taken on trust here -- confirm it matches the
    # shape the real lightop repack kernel actually returns.
    return torch.empty((N, K),
                       dtype=weight_trans.dtype,
                       device=weight_trans.device)
def gemm_awq_w4a16_marlin(input: torch.Tensor, weight: torch.Tensor,
                          zeros_and_scales: torch.Tensor,
                          m: int, n: int, k: int) -> torch.Tensor:
    """Run the lightop AWQ w4a16 marlin GEMM.

    Thin wrapper around ``lightop.gemm_awq_w4a16_marlin``.

    Args:
        input: Activation tensor passed to the kernel.
        weight: Weight tensor passed to the kernel.
        zeros_and_scales: Zero-point/scale tensor passed to the kernel.
        m: Not forwarded to the kernel; the fake implementation with the
            same signature uses it for the output shape.
        n: Not forwarded to the kernel; see ``m``.
        k: Not forwarded to the kernel; unused by this wrapper.

    Returns:
        The tensor returned by the lightop kernel.
    """
    # Only the tensors are forwarded; m/n/k exist so that this signature
    # lines up with gemm_awq_w4a16_marlin_fake.
    return lightop.gemm_awq_w4a16_marlin(input, weight, zeros_and_scales)
def gemm_awq_w4a16_marlin_fake(input: torch.Tensor, weight: torch.Tensor,
                               zeros_and_scales: torch.Tensor,
                               m: int, n: int, k: int) -> torch.Tensor:
    """Shape-only stand-in for ``gemm_awq_w4a16_marlin``.

    Allocates an uninitialized output instead of running the GEMM.

    Args:
        input: Source of the output dtype and device.
        weight: Unused; present for signature parity with the real op.
        zeros_and_scales: Unused; present for signature parity.
        m: Number of rows of the returned tensor.
        n: Number of columns of the returned tensor.
        k: Unused; present for signature parity.

    Returns:
        An uninitialized ``(m, n)`` tensor with the dtype and device of
        ``input``.
    """
    return torch.empty((m, n), dtype=input.dtype, device=input.device)
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor, def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int): group_size: int):
...@@ -2299,3 +2317,17 @@ direct_register_custom_op( ...@@ -2299,3 +2317,17 @@ direct_register_custom_op(
mutates_args=[], mutates_args=[],
fake_impl=gptq_gemm_fake_, fake_impl=gptq_gemm_fake_,
) )
# Register the marlin weight repack as a custom op (invoked elsewhere as
# torch.ops.vllm.awq_gemm_marlin_weight_repack), paired with its fake impl.
direct_register_custom_op(
    op_name="awq_gemm_marlin_weight_repack",
    op_func=awq_gemm_marlin_weight_repack,
    fake_impl=awq_gemm_marlin_weight_repack_fake,
    mutates_args=[],  # declared as mutating none of its arguments
)
# Register the AWQ w4a16 marlin GEMM as a custom op (invoked elsewhere as
# torch.ops.vllm.gemm_awq_w4a16_marlin), paired with its fake impl.
direct_register_custom_op(
    op_name="gemm_awq_w4a16_marlin",
    op_func=gemm_awq_w4a16_marlin,
    fake_impl=gemm_awq_w4a16_marlin_fake,
    mutates_args=[],  # declared as mutating none of its arguments
)
\ No newline at end of file
...@@ -96,6 +96,7 @@ if TYPE_CHECKING: ...@@ -96,6 +96,7 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_WITH_STACK: bool = True VLLM_TORCH_PROFILER_WITH_STACK: bool = True
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
VLLM_USE_TRITON_AWQ: bool = False VLLM_USE_TRITON_AWQ: bool = False
AWQ_GEMM_MARLIN: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = [] VLLM_DISABLED_KERNELS: list[str] = []
...@@ -909,6 +910,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -909,6 +910,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_AWQ": "VLLM_USE_TRITON_AWQ":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
# If set, vLLM will use marlin implementations of AWQ.
"AWQ_GEMM_MARLIN":
lambda: bool(int(os.getenv("AWQ_GEMM_MARLIN", "0"))),
# If set, allow loading or unloading lora adapters in runtime, # If set, allow loading or unloading lora adapters in runtime,
"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "VLLM_ALLOW_RUNTIME_LORA_UPDATING":
lambda: lambda:
...@@ -1765,6 +1770,7 @@ def compute_hash() -> str: ...@@ -1765,6 +1770,7 @@ def compute_hash() -> str:
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
"VLLM_USE_TRITON_FLASH_ATTN", "VLLM_USE_TRITON_FLASH_ATTN",
"VLLM_USE_TRITON_AWQ", "VLLM_USE_TRITON_AWQ",
"AWQ_GEMM_MARLIN",
"VLLM_DP_RANK", "VLLM_DP_RANK",
"VLLM_DP_SIZE", "VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE", "VLLM_USE_STANDALONE_COMPILE",
......
...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter) PackedvLLMParameter)
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
import lightop
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
triton_configs_dict={} triton_configs_dict={}
...@@ -205,8 +206,9 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -205,8 +206,9 @@ class AWQLinearMethod(LinearMethodBase):
""" """
def __init__(self, quant_config: AWQConfig): def __init__(self, quant_config: AWQConfig):
if not envs.AWQ_GEMM_MARLIN and not envs.VLLM_USE_TRITON_AWQ:
self.awqsingleton= AWQShareWorkSpace()
self.quant_config = quant_config self.quant_config = quant_config
self.awqsingleton= AWQShareWorkSpace()
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
...@@ -303,7 +305,10 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -303,7 +305,10 @@ class AWQLinearMethod(LinearMethodBase):
sz = torch.cat((sz,zeros_and_scalse_pad),dim=1).contiguous() sz = torch.cat((sz,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad = torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad = torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
_qw=torch.cat((_qw,qweight_pad),dim=1).contiguous() _qw=torch.cat((_qw,qweight_pad),dim=1).contiguous()
if envs.AWQ_GEMM_MARLIN:
_qw =torch.ops.vllm.awq_gemm_marlin_weight_repack(_qw, dim_n, dim_k)
layer.qweight = torch.nn.Parameter(_qw, requires_grad=False) layer.qweight = torch.nn.Parameter(_qw, requires_grad=False)
layer.zeros_and_scales = torch.nn.Parameter(sz, requires_grad=False) layer.zeros_and_scales = torch.nn.Parameter(sz, requires_grad=False)
layer.qzeros = None layer.qzeros = None
...@@ -326,12 +331,13 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -326,12 +331,13 @@ class AWQLinearMethod(LinearMethodBase):
qzeros = layer.qzeros qzeros = layer.qzeros
scales = layer.scales scales = layer.scales
pack_factor = self.quant_config.pack_factor pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[0] * 1, ))
reshaped_x = x.reshape(-1, x.shape[-1])
reshaped_x = x.reshape(-1, x.shape[-1])
m = reshaped_x.shape[0] m = reshaped_x.shape[0]
k = reshaped_x.shape[-1] k = reshaped_x.shape[-1]
n = qweight.shape[0] n = zeros_and_scales.shape[0]
out_shape = (x.shape[:-1] + (n, ))
if self.use_awq_pad: if self.use_awq_pad:
if k % 4096 == 0: if k % 4096 == 0:
...@@ -346,16 +352,19 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -346,16 +352,19 @@ class AWQLinearMethod(LinearMethodBase):
out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros, pack_factor, best_config) out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros, pack_factor, best_config)
out_shape = (x.shape[:-1] + (qweight.shape[1] * 8, )) out_shape = (x.shape[:-1] + (qweight.shape[1] * 8, ))
else: else:
out = torch.ops.vllm.awq_gemm(reshaped_x, if envs.AWQ_GEMM_MARLIN:
qweight, out = torch.ops.vllm.gemm_awq_w4a16_marlin(reshaped_x, qweight, zeros_and_scales, m, n, k)
zeros_and_scales, else:
m, out = torch.ops.vllm.awq_gemm(reshaped_x,
n, qweight,
k, zeros_and_scales,
self.quant_config.group_size, m,
padding_group, n,
self.awqsingleton.awqworkshapce, k,
self.awqsingleton.awqworkshapcesize) self.quant_config.group_size,
padding_group,
self.awqsingleton.awqworkshapce,
self.awqsingleton.awqworkshapcesize)
if bias is not None: if bias is not None:
out.add_(bias) out.add_(bias)
......
...@@ -730,7 +730,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -730,7 +730,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
) )
if weights_scheme is not None: if weights_scheme is not None:
num_bits = weights_scheme.num_bits num_bits = weights_scheme.num_bits
if num_bits == 4: if isinstance(layer.scheme, CompressedTensorsWNA16):
return layer.scheme.process_weights_after_loading(layer) return layer.scheme.process_weights_after_loading(layer)
n=layer.weight.shape[0] n=layer.weight.shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment