Commit b3ab1cdc authored by zhuwenwen
Browse files

support baichuan awq and skip _rocm_C

parent 422af727
...@@ -344,6 +344,7 @@ define_gpu_extension_target( ...@@ -344,6 +344,7 @@ define_gpu_extension_target(
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "HIP")
# #
# _rocm_C extension # _rocm_C extension
...@@ -362,6 +363,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") ...@@ -362,6 +363,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
endif() endif()
]]
# vllm-flash-attn currently only supported on CUDA # vllm-flash-attn currently only supported on CUDA
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda") if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
...@@ -389,6 +391,7 @@ endif() ...@@ -389,6 +391,7 @@ endif()
if(VLLM_FLASH_ATTN_SRC_DIR) if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR}) FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
#[[
else() else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
...@@ -396,11 +399,13 @@ else() ...@@ -396,11 +399,13 @@ else()
GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
) )
]]
endif() endif()
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization. # Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set(VLLM_PARENT_BUILD ON) set(VLLM_PARENT_BUILD ON)
#[[
# Ensure the vllm/vllm_flash_attn directory exists before installation # Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c) install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
...@@ -426,3 +431,4 @@ install( ...@@ -426,3 +431,4 @@ install(
) )
# Nothing after vllm-flash-attn, see comment about macros above # Nothing after vllm-flash-attn, see comment about macros above
]]
\ No newline at end of file
...@@ -532,8 +532,8 @@ if _build_core_ext(): ...@@ -532,8 +532,8 @@ if _build_core_ext():
if _is_cuda() or _is_hip(): if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C")) ext_modules.append(CMakeExtension(name="vllm._moe_C"))
if _is_hip(): # if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C")) # ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda(): if _is_cuda():
ext_modules.append( ext_modules.append(
......
...@@ -22,8 +22,8 @@ if not current_platform.is_tpu(): ...@@ -22,8 +22,8 @@ if not current_platform.is_tpu():
except ImportError as e: except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e) logger.warning("Failed to import from vllm._C with %r", e)
if current_platform.is_rocm(): # if current_platform.is_rocm():
import vllm._rocm_C # noqa: F401 # import vllm._rocm_C # noqa: F401
supports_moe_ops = False supports_moe_ops = False
with contextlib.suppress(ImportError): with contextlib.suppress(ImportError):
......
...@@ -461,6 +461,47 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -461,6 +461,47 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
weight.data=weight.data.reshape(ori_shape[1], -1) weight.data=weight.data.reshape(ori_shape[1], -1)
if self.quant_method == "awq":
lay_key_words = [
"self_attn.W_pack.qweight",
"self_attn.o_proj.qweight",
"mlp.gate_up_proj.qweight",
"mlp.down_proj.qweight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
qweight =params_dict[layername]
qzeros=params_dict[layername.replace("qweight", "qzeros")]
scales=params_dict[layername.replace("qweight", "scales")]
zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
group_size= self.quant_config.group_size
dim_n = scales.data.shape[1]
dim_k = qweight.data.shape[0]
pad_group=2
_qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw)
#reshape
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0 and self.use_awq_pad:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B.""" """Baichuan 13B and Baichuan2 7B/13B."""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment