Commit 1bbb2f94 authored by zhuwenwen
Browse files

Merge branch 'v0.11.0-dev_yql_2.11' into 'v0.11.0-dev'

修复CompressedTensors的w8a16的支持,新增awq_marlin gemm的qwen72B的支持 (Fix W8A16 support in CompressedTensors; add awq_marlin GEMM support for Qwen 72B)

See merge request dcutoolkit/deeplearing/vllm!429
parents a4d28758 3d4721f2
......@@ -420,6 +420,24 @@ def awq_gemm_fake(input: torch.Tensor, weight: torch.Tensor,
splikspacesize:int) -> torch.Tensor:
return torch.empty((m, n), dtype=input.dtype, device=input.device)
def awq_gemm_marlin_weight_repack(weight_trans: torch.Tensor, N: int,
                                  K: int) -> torch.Tensor:
    """Repack an AWQ weight tensor with the lightop marlin repack kernel.

    Thin wrapper: all three arguments are forwarded unchanged to
    ``lightop.awq_gemm_marlin_weight_repack`` and its result is returned.

    Args:
        weight_trans: Weight tensor handed to the kernel.
            NOTE(review): presumably the packed, transposed AWQ qweight --
            confirm against the caller.
        N: First size argument forwarded to the kernel.
        K: Second size argument forwarded to the kernel.

    Returns:
        The tensor produced by the lightop kernel.
    """
    return lightop.awq_gemm_marlin_weight_repack(weight_trans, N, K)
def awq_gemm_marlin_weight_repack_fake(weight_trans: torch.Tensor, N: int,
                                       K: int) -> torch.Tensor:
    """Fake (shape-only) counterpart of ``awq_gemm_marlin_weight_repack``.

    No repacking is performed; only an output placeholder is allocated.

    Args:
        weight_trans: Tensor whose dtype and device the result inherits.
        N: Size of the first output dimension.
        K: Size of the second output dimension.

    Returns:
        An uninitialised ``(N, K)`` tensor with the dtype and device of
        ``weight_trans``.
    """
    return torch.empty((N, K),
                       dtype=weight_trans.dtype,
                       device=weight_trans.device)
def gemm_awq_w4a16_marlin(input: torch.Tensor, weight: torch.Tensor,
                          zeros_and_scales: torch.Tensor, m: int, n: int,
                          k: int) -> torch.Tensor:
    """Run the lightop AWQ w4a16 marlin GEMM kernel.

    Args:
        input: Activation tensor passed to the kernel.
        weight: Weight tensor passed to the kernel.
        zeros_and_scales: Third tensor passed to the kernel.
            NOTE(review): presumably the combined zero points and scales --
            confirm against the kernel's documentation.
        m: Not forwarded to the kernel.
        n: Not forwarded to the kernel.
        k: Not forwarded to the kernel.

    Returns:
        The tensor produced by ``lightop.gemm_awq_w4a16_marlin``.
    """
    # Only the three tensors reach the kernel; m, n and k are part of the
    # signature but unused here.
    return lightop.gemm_awq_w4a16_marlin(input, weight, zeros_and_scales)
def gemm_awq_w4a16_marlin_fake(input: torch.Tensor, weight: torch.Tensor,
                               zeros_and_scales: torch.Tensor, m: int, n: int,
                               k: int) -> torch.Tensor:
    """Fake (shape-only) counterpart of ``gemm_awq_w4a16_marlin``.

    No GEMM is computed; only an output placeholder is allocated.

    Args:
        input: Tensor whose dtype and device the result inherits.
        weight: Unused; present to mirror the real op's signature.
        zeros_and_scales: Unused; present to mirror the real op's signature.
        m: Size of the first output dimension.
        n: Size of the second output dimension.
        k: Unused; present to mirror the real op's signature.

    Returns:
        An uninitialised ``(m, n)`` tensor with the dtype and device of
        ``input``.
    """
    return torch.empty((m, n), dtype=input.dtype, device=input.device)
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
......@@ -2299,3 +2317,17 @@ direct_register_custom_op(
mutates_args=[],
fake_impl=gptq_gemm_fake_,
)
# Register the marlin weight-repack wrapper as a custom op, paired with its
# shape-only fake implementation. The op does not mutate any argument.
direct_register_custom_op(
    op_name="awq_gemm_marlin_weight_repack",
    op_func=awq_gemm_marlin_weight_repack,
    mutates_args=[],
    fake_impl=awq_gemm_marlin_weight_repack_fake,
)
# Register the AWQ w4a16 marlin GEMM wrapper as a custom op, paired with its
# shape-only fake implementation. The op does not mutate any argument.
direct_register_custom_op(
    op_name="gemm_awq_w4a16_marlin",
    op_func=gemm_awq_w4a16_marlin,
    mutates_args=[],
    fake_impl=gemm_awq_w4a16_marlin_fake,
)
\ No newline at end of file
......@@ -96,6 +96,7 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
VLLM_USE_TRITON_AWQ: bool = False
AWQ_GEMM_MARLIN: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
......@@ -909,6 +910,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_AWQ":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
# If set, vLLM will use marlin implementations of AWQ.
"AWQ_GEMM_MARLIN":
lambda: bool(int(os.getenv("AWQ_GEMM_MARLIN", "0"))),
# If set, allow loading or unloading lora adapters in runtime,
"VLLM_ALLOW_RUNTIME_LORA_UPDATING":
lambda:
......@@ -1765,6 +1770,7 @@ def compute_hash() -> str:
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
"VLLM_USE_TRITON_FLASH_ATTN",
"VLLM_USE_TRITON_AWQ",
"AWQ_GEMM_MARLIN",
"VLLM_DP_RANK",
"VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE",
......
......@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
import lightop
from vllm.logger import init_logger
logger = init_logger(__name__)
triton_configs_dict={}
......@@ -205,8 +206,9 @@ class AWQLinearMethod(LinearMethodBase):
"""
def __init__(self, quant_config: AWQConfig):
    """Store the AWQ config and set up per-method state.

    Args:
        quant_config: AWQ quantization configuration kept on the instance
            as ``self.quant_config``.
    """
    self.quant_config = quant_config
    # The shared workspace is only created when neither the marlin nor the
    # triton AWQ path is enabled through the environment.
    if not envs.AWQ_GEMM_MARLIN and not envs.VLLM_USE_TRITON_AWQ:
        self.awqsingleton = AWQShareWorkSpace()
    # Padding is opt-in via the AWQ_PAD environment variable (exactly '1').
    self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def create_weights(self, layer: torch.nn.Module,
......@@ -304,6 +306,9 @@ class AWQLinearMethod(LinearMethodBase):
qweight_pad = torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
_qw=torch.cat((_qw,qweight_pad),dim=1).contiguous()
if envs.AWQ_GEMM_MARLIN:
_qw =torch.ops.vllm.awq_gemm_marlin_weight_repack(_qw, dim_n, dim_k)
layer.qweight = torch.nn.Parameter(_qw, requires_grad=False)
layer.zeros_and_scales = torch.nn.Parameter(sz, requires_grad=False)
layer.qzeros = None
......@@ -326,12 +331,13 @@ class AWQLinearMethod(LinearMethodBase):
qzeros = layer.qzeros
scales = layer.scales
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[0] * 1, ))
reshaped_x = x.reshape(-1, x.shape[-1])
reshaped_x = x.reshape(-1, x.shape[-1])
m = reshaped_x.shape[0]
k = reshaped_x.shape[-1]
n = qweight.shape[0]
n = zeros_and_scales.shape[0]
out_shape = (x.shape[:-1] + (n, ))
if self.use_awq_pad:
if k % 4096 == 0:
......@@ -345,6 +351,9 @@ class AWQLinearMethod(LinearMethodBase):
best_config=getspec_config(m,n,k)
out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros, pack_factor, best_config)
out_shape = (x.shape[:-1] + (qweight.shape[1] * 8, ))
else:
if envs.AWQ_GEMM_MARLIN:
out = torch.ops.vllm.gemm_awq_w4a16_marlin(reshaped_x, qweight, zeros_and_scales, m, n, k)
else:
out = torch.ops.vllm.awq_gemm(reshaped_x,
qweight,
......
......@@ -730,7 +730,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
)
if weights_scheme is not None:
num_bits = weights_scheme.num_bits
if num_bits == 4:
if isinstance(layer.scheme, CompressedTensorsWNA16):
return layer.scheme.process_weights_after_loading(layer)
n=layer.weight.shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment