Commit a4df8463 authored by zhuwenwen
Browse files

Merge branch 'v0.14.1-dev_yql_1.28_2' into 'v0.14.1-dev'

修复vit attn的导入问题,以及w4a16的gptq的接口问题

See merge request dcutoolkit/deeplearing/vllm!395
parents 82277b17 8900b622
......@@ -385,7 +385,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.scales,
layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.use_v2_format,
# self.use_v2_format,
self.quant_config.weight_bits,
)
if bias is not None:
......
......@@ -155,12 +155,12 @@ class ExllamaLinearKernel(MPLinearKernel):
# gptq_gemm supports GPTQv2 format by passing use_v2_format=True.
# However, the MPLinearLayerConfig doesn't contain format info.
# So hardcode GPTQv1 format here, to keep its behavior unchanged.
use_v2_format = False
#use_v2_format = False
assert w_zp is not None, "Zero points are required by Exllama"
assert w_g_idx is not None, "Group index is required by Exllama"
output = ops.gptq_gemm(
x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits
x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits
)
if bias is not None:
......
......@@ -23,7 +23,7 @@ elif current_platform.is_xpu():
elif current_platform.is_rocm():
try:
from vllm._custom_ops import reshape_and_cache_cuda
from flash_attn import vllm_flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
except ImportError as e:
raise ImportError(
"Rocm platform requires upstream flash-attn "
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment