Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ffbef65c
Commit
ffbef65c
authored
Aug 01, 2024
by
zhuwenwen
Browse files
support fa pad
parent
1190e964
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
58 additions
and
17 deletions
+58
-17
README.md
README.md
+4
-1
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+7
-3
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+13
-3
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+7
-3
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+13
-3
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+14
-4
No files found.
README.md
View file @
ffbef65c
...
...
@@ -54,7 +54,10 @@ pip install setuptools wheel
```
shell
git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git
# 根据需要的分支进行切换
```
安装依赖:
```
shell
pip
install
-r
requirements-rocm.txt
```
-
提供2种源码编译方式(进入vllm目录):
```
1. 编译whl包并安装
...
...
vllm/model_executor/models/baichuan.py
View file @
ffbef65c
...
...
@@ -183,7 +183,7 @@ class BaiChuanAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
qkv
.
shape
[
-
1
]
==
12320
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
postion_embedding
!=
"ALIBI"
:
...
...
@@ -413,13 +413,17 @@ class BaiChuanBaseForCausalLM(nn.Module):
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attn.W_pack.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
weight
.
data
.
shape
[
0
]
==
12288
:
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
...
...
vllm/model_executor/models/chatglm.py
View file @
ffbef65c
...
...
@@ -106,7 +106,7 @@ class GLMAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
qkv
.
shape
[
-
1
]
==
12320
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
...
...
@@ -413,13 +413,23 @@ class ChatGLMForCausalLM(nn.Module):
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attention.query_key_value.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
lay_qkv_bias_words
=
[
"self_attention.query_key_value.bias"
]
qkv_bias_words
=
"|"
.
join
(
lay_qkv_bias_words
)
for
layername
,
weight
in
params_dict
.
items
():
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_bias_words
,
layername
)):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
weight
.
data
.
shape
[
0
]
==
12288
:
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
...
...
vllm/model_executor/models/llama.py
View file @
ffbef65c
...
...
@@ -161,7 +161,7 @@ class LlamaAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
qkv
.
shape
[
-
1
]
==
12320
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
...
@@ -456,13 +456,17 @@ class LlamaForCausalLM(nn.Module):
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attn.qkv_proj.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
weight
.
data
.
shape
[
0
]
==
12288
:
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
...
...
vllm/model_executor/models/qwen.py
View file @
ffbef65c
...
...
@@ -123,7 +123,7 @@ class QWenAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
qkv
.
shape
[
-
1
]
==
12320
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
...
@@ -312,13 +312,23 @@ class QWenLMHeadModel(nn.Module):
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"attn.c_attn.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
lay_qkv_bias_words
=
[
"attn.c_attn.bias"
]
qkv_bias_words
=
"|"
.
join
(
lay_qkv_bias_words
)
for
layername
,
weight
in
params_dict
.
items
():
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_bias_words
,
layername
)):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
weight
.
data
.
shape
[
0
]
==
12288
:
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
...
...
vllm/model_executor/models/qwen2.py
View file @
ffbef65c
...
...
@@ -153,7 +153,7 @@ class Qwen2Attention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
qkv
.
shape
[
-
1
]
==
12320
:
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
...
@@ -399,13 +399,23 @@ class Qwen2ForCausalLM(nn.Module):
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attn.qkv_proj.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
lay_qkv_bias_words
=
[
"self_attn.qkv_proj.bias"
]
qkv_bias_words
=
"|"
.
join
(
lay_qkv_bias_words
)
for
layername
,
weight
in
params_dict
.
items
():
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_bias_words
,
layername
)):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
weight
.
data
.
shape
[
0
]
==
12288
:
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment