Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b3ab1cdc
Commit
b3ab1cdc
authored
Oct 15, 2024
by
zhuwenwen
Browse files
support baichuan awq and skip _rocm_C
parent
422af727
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
4 deletions
+51
-4
CMakeLists.txt
CMakeLists.txt
+6
-0
setup.py
setup.py
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-2
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+41
-0
No files found.
CMakeLists.txt
View file @
b3ab1cdc
...
@@ -344,6 +344,7 @@ define_gpu_extension_target(
...
@@ -344,6 +344,7 @@ define_gpu_extension_target(
USE_SABI 3
USE_SABI 3
WITH_SOABI
)
WITH_SOABI
)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP")
if(VLLM_GPU_LANG STREQUAL "HIP")
#
#
# _rocm_C extension
# _rocm_C extension
...
@@ -362,6 +363,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
...
@@ -362,6 +363,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3
USE_SABI 3
WITH_SOABI)
WITH_SOABI)
endif()
endif()
]]
# vllm-flash-attn currently only supported on CUDA
# vllm-flash-attn currently only supported on CUDA
if
(
NOT VLLM_TARGET_DEVICE STREQUAL
"cuda"
)
if
(
NOT VLLM_TARGET_DEVICE STREQUAL
"cuda"
)
...
@@ -389,6 +391,7 @@ endif()
...
@@ -389,6 +391,7 @@ endif()
if
(
VLLM_FLASH_ATTN_SRC_DIR
)
if
(
VLLM_FLASH_ATTN_SRC_DIR
)
FetchContent_Declare
(
vllm-flash-attn SOURCE_DIR
${
VLLM_FLASH_ATTN_SRC_DIR
}
)
FetchContent_Declare
(
vllm-flash-attn SOURCE_DIR
${
VLLM_FLASH_ATTN_SRC_DIR
}
)
#[[
else()
else()
FetchContent_Declare(
FetchContent_Declare(
vllm-flash-attn
vllm-flash-attn
...
@@ -396,11 +399,13 @@ else()
...
@@ -396,11 +399,13 @@ else()
GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
GIT_PROGRESS TRUE
GIT_PROGRESS TRUE
)
)
]]
endif
()
endif
()
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set
(
VLLM_PARENT_BUILD ON
)
set
(
VLLM_PARENT_BUILD ON
)
#[[
# Ensure the vllm/vllm_flash_attn directory exists before installation
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
...
@@ -426,3 +431,4 @@ install(
...
@@ -426,3 +431,4 @@ install(
)
)
# Nothing after vllm-flash-attn, see comment about macros above
# Nothing after vllm-flash-attn, see comment about macros above
]]
\ No newline at end of file
setup.py
View file @
b3ab1cdc
...
@@ -532,8 +532,8 @@ if _build_core_ext():
...
@@ -532,8 +532,8 @@ if _build_core_ext():
if
_is_cuda
()
or
_is_hip
():
if
_is_cuda
()
or
_is_hip
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
if
_is_hip
():
#
if _is_hip():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._rocm_C"
))
#
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if
_is_cuda
():
if
_is_cuda
():
ext_modules
.
append
(
ext_modules
.
append
(
...
...
vllm/_custom_ops.py
View file @
b3ab1cdc
...
@@ -22,8 +22,8 @@ if not current_platform.is_tpu():
...
@@ -22,8 +22,8 @@ if not current_platform.is_tpu():
except
ImportError
as
e
:
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
if
current_platform
.
is_rocm
():
#
if current_platform.is_rocm():
import
vllm._rocm_C
# noqa: F401
#
import vllm._rocm_C # noqa: F401
supports_moe_ops
=
False
supports_moe_ops
=
False
with
contextlib
.
suppress
(
ImportError
):
with
contextlib
.
suppress
(
ImportError
):
...
...
vllm/model_executor/models/baichuan.py
View file @
b3ab1cdc
...
@@ -461,6 +461,47 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
...
@@ -461,6 +461,47 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
if
self
.
quant_method
==
"awq"
:
lay_key_words
=
[
"self_attn.W_pack.qweight"
,
"self_attn.o_proj.qweight"
,
"mlp.gate_up_proj.qweight"
,
"mlp.down_proj.qweight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
qweight
=
params_dict
[
layername
]
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
zeros_and_scalse
=
params_dict
[
layername
.
replace
(
"qweight"
,
"zeros_and_scales"
)]
group_size
=
self
.
quant_config
.
group_size
dim_n
=
scales
.
data
.
shape
[
1
]
dim_k
=
qweight
.
data
.
shape
[
0
]
pad_group
=
2
_qw
,
_sz
=
ops
.
convert_s4
(
qweight
,
qzeros
,
scales
,
int
(
group_size
))
sz
=
ops
.
sz_permute
(
_sz
).
reshape
(
-
1
,
dim_n
)
zeros_and_scalse
.
data
.
copy_
(
sz
)
qweight
.
data
.
copy_
(
_qw
)
#reshape
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
if
dim_k
%
4096
==
0
and
self
.
use_awq_pad
:
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
class
BaichuanForCausalLM
(
BaiChuanBaseForCausalLM
):
class
BaichuanForCausalLM
(
BaiChuanBaseForCausalLM
):
"""Baichuan 13B and Baichuan2 7B/13B."""
"""Baichuan 13B and Baichuan2 7B/13B."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment