Commit ae59e10f authored by zhuwenwen
Browse files

修复vit attn的导入问题,以及w4a16的gptq的接口问题

parent 2d8b3257
......@@ -385,7 +385,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.scales,
layer.g_idx,
layer.exllama_state == ExllamaState.READY,
- self.use_v2_format,
+ # self.use_v2_format,
self.quant_config.weight_bits,
)
if bias is not None:
......
......@@ -155,12 +155,12 @@ class ExllamaLinearKernel(MPLinearKernel):
# gptq_gemm supports GPTQv2 format by passing use_v2_format=True.
# However, the MPLinearLayerConfig doesn't contain format info.
# So hardcode GPTQv1 format here, to keep its behavior unchanged.
- use_v2_format = False
+ # use_v2_format = False
assert w_zp is not None, "Zero points are required by Exllama"
assert w_g_idx is not None, "Group index is required by Exllama"
output = ops.gptq_gemm(
- x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits
+ x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits
)
if bias is not None:
......
......@@ -26,7 +26,7 @@ elif current_platform.is_rocm():
try:
from vllm._custom_ops import reshape_and_cache_cuda
# from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
- from flash_attn import vllm_flash_attn_varlen_func
+ from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
except ImportError:
def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment