Commit b3ab1cdc authored by zhuwenwen's avatar zhuwenwen
Browse files

support baichuan awq and skip _rocm_C

parent 422af727
......@@ -344,6 +344,7 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP")
#
# _rocm_C extension
......@@ -362,6 +363,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3
WITH_SOABI)
endif()
]]
# vllm-flash-attn currently only supported on CUDA
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
......@@ -389,6 +391,7 @@ endif()
if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
#[[
else()
FetchContent_Declare(
vllm-flash-attn
......@@ -396,11 +399,13 @@ else()
GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
GIT_PROGRESS TRUE
)
]]
endif()
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set(VLLM_PARENT_BUILD ON)
#[[
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
......@@ -426,3 +431,4 @@ install(
)
# Nothing after vllm-flash-attn, see comment about macros above
]]
\ No newline at end of file
......@@ -532,8 +532,8 @@ if _build_core_ext():
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
# if _is_hip():
# ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda():
ext_modules.append(
......
......@@ -22,8 +22,8 @@ if not current_platform.is_tpu():
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
if current_platform.is_rocm():
import vllm._rocm_C # noqa: F401
# if current_platform.is_rocm():
# import vllm._rocm_C # noqa: F401
supports_moe_ops = False
with contextlib.suppress(ImportError):
......
......@@ -461,6 +461,47 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
weight.data=weight.data.reshape(ori_shape[1], -1)
if self.quant_method == "awq":
lay_key_words = [
"self_attn.W_pack.qweight",
"self_attn.o_proj.qweight",
"mlp.gate_up_proj.qweight",
"mlp.down_proj.qweight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
qweight =params_dict[layername]
qzeros=params_dict[layername.replace("qweight", "qzeros")]
scales=params_dict[layername.replace("qweight", "scales")]
zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
group_size= self.quant_config.group_size
dim_n = scales.data.shape[1]
dim_k = qweight.data.shape[0]
pad_group=2
_qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw)
#reshape
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0 and self.use_awq_pad:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment