"vllm/vscode:/vscode.git/clone" did not exist on "2e9a2227ecee8990f0552518fc40dba67f1026b3"
Commit ae59e10f authored by zhuwenwen
Browse files

修复vit attn的导入问题,以及w4a16的gptq的接口问题

parent 2d8b3257
...@@ -385,7 +385,7 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -385,7 +385,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.scales, layer.scales,
layer.g_idx, layer.g_idx,
layer.exllama_state == ExllamaState.READY, layer.exllama_state == ExllamaState.READY,
self.use_v2_format, # self.use_v2_format,
self.quant_config.weight_bits, self.quant_config.weight_bits,
) )
if bias is not None: if bias is not None:
......
...@@ -155,12 +155,12 @@ class ExllamaLinearKernel(MPLinearKernel): ...@@ -155,12 +155,12 @@ class ExllamaLinearKernel(MPLinearKernel):
# gptq_gemm supports GPTQv2 format by passing use_v2_format=True. # gptq_gemm supports GPTQv2 format by passing use_v2_format=True.
# However, the MPLinearLayerConfig doesn't contain format info. # However, the MPLinearLayerConfig doesn't contain format info.
# So hardcode GPTQv1 format here, to keep its behavior unchanged. # So hardcode GPTQv1 format here, to keep its behavior unchanged.
use_v2_format = False # use_v2_format = False
assert w_zp is not None, "Zero points are required by Exllama" assert w_zp is not None, "Zero points are required by Exllama"
assert w_g_idx is not None, "Group index is required by Exllama" assert w_g_idx is not None, "Group index is required by Exllama"
output = ops.gptq_gemm( output = ops.gptq_gemm(
x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits
) )
if bias is not None: if bias is not None:
......
...@@ -26,7 +26,7 @@ elif current_platform.is_rocm(): ...@@ -26,7 +26,7 @@ elif current_platform.is_rocm():
try: try:
from vllm._custom_ops import reshape_and_cache_cuda from vllm._custom_ops import reshape_and_cache_cuda
# from flash_attn import flash_attn_varlen_func # type: ignore[no-redef] # from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
from flash_attn import vllm_flash_attn_varlen_func from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
except ImportError: except ImportError:
def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc] def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment