Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0b5e4e11
Commit
0b5e4e11
authored
Jul 24, 2024
by
zhuwenwen
Browse files
add gemm pad and fa pad for 7b model
parent
2d0a73a3
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
106 additions
and
4 deletions
+106
-4
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+5
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+7
-0
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+13
-1
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+13
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+13
-1
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+13
-1
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+13
-0
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+29
-0
No files found.
vllm/model_executor/layers/linear.py
View file @
0b5e4e11
...
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.utils
import
set_weight_attrs
import
os
from
vllm.model_executor.utils
import
gemm_bank_conf
logger
=
init_logger
(
__name__
)
...
...
@@ -88,6 +89,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
def
__init__
(
self
,
separate_bias_add
:
bool
=
False
):
self
.
separate_bias_add
=
separate_bias_add
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
...
...
@@ -114,6 +116,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
return
F
.
linear
(
x
,
weight
)
if
self
.
use_llama_nn
:
if
gemm_bank_conf
(
weight
.
shape
[
1
]
-
32
)
and
os
.
environ
[
'GEMM_PAD'
]
==
'1'
:
weight
=
weight
[:,:
-
32
]
if
bias
is
not
None
:
if
len
(
x
.
shape
)
==
2
:
return
torch
.
addmm
(
bias
,
x
,
weight
)
...
...
vllm/model_executor/model_loader/utils.py
View file @
0b5e4e11
...
...
@@ -26,8 +26,15 @@ def get_model_architecture(
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
if
os
.
getenv
(
'GEMM_PAD'
)
!=
'1'
:
os
.
environ
[
'GEMM_PAD'
]
=
'0'
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
os
.
environ
[
'FA_PAD'
]
=
'0'
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if
(
model_config
.
quantization
is
not
None
...
...
vllm/model_executor/models/baichuan.py
View file @
0b5e4e11
...
...
@@ -46,7 +46,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
...
...
@@ -181,6 +183,8 @@ class BaiChuanAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
qkv
.
shape
[
-
1
]
==
12320
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
postion_embedding
!=
"ALIBI"
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
...
@@ -330,6 +334,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
self
,
...
...
@@ -409,7 +415,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
weight
.
data
.
shape
[
0
]
==
12288
:
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
...
...
vllm/model_executor/models/chatglm.py
View file @
0b5e4e11
...
...
@@ -29,7 +29,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
class
GLMAttention
(
nn
.
Module
):
...
...
@@ -104,6 +106,8 @@ class GLMAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
qkv
.
shape
[
-
1
]
==
12320
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
context_layer
=
self
.
attn
(
...
...
@@ -357,6 +361,8 @@ class ChatGLMForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
self
,
...
...
@@ -409,7 +415,13 @@ class ChatGLMForCausalLM(nn.Module):
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
weight
.
data
.
shape
[
0
]
==
12288
:
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
...
...
vllm/model_executor/models/llama.py
View file @
0b5e4e11
...
...
@@ -52,6 +52,8 @@ from vllm.sequence import SamplerOutput
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
class
LlamaMLP
(
nn
.
Module
):
...
...
@@ -159,6 +161,8 @@ class LlamaAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
qkv
.
shape
[
-
1
]
==
12320
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
...
...
@@ -364,6 +368,8 @@ class LlamaForCausalLM(nn.Module):
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
self
,
...
...
@@ -452,7 +458,13 @@ class LlamaForCausalLM(nn.Module):
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
weight
.
data
.
shape
[
0
]
==
12288
:
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
...
...
vllm/model_executor/models/qwen.py
View file @
0b5e4e11
...
...
@@ -33,6 +33,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
class
QWenMLP
(
nn
.
Module
):
def
__init__
(
...
...
@@ -120,6 +123,8 @@ class QWenAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
qkv
.
shape
[
-
1
]
==
12320
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
...
...
@@ -202,7 +207,6 @@ class QWenModel(nn.Module):
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
ln_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
forward
(
self
,
...
...
@@ -242,6 +246,8 @@ class QWenLMHeadModel(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
self
,
...
...
@@ -309,6 +315,12 @@ class QWenLMHeadModel(nn.Module):
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
weight
.
data
.
shape
[
0
]
==
12288
:
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
...
...
vllm/model_executor/models/qwen2.py
View file @
0b5e4e11
...
...
@@ -50,6 +50,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
class
Qwen2MLP
(
nn
.
Module
):
def
__init__
(
...
...
@@ -150,6 +153,8 @@ class Qwen2Attention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
qkv
.
shape
[
-
1
]
==
12320
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
...
...
@@ -322,6 +327,8 @@ class Qwen2ForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
self
,
...
...
@@ -395,6 +402,12 @@ class Qwen2ForCausalLM(nn.Module):
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
weight
.
data
.
shape
[
0
]
==
12288
:
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
...
...
vllm/model_executor/utils.py
View file @
0b5e4e11
...
...
@@ -33,3 +33,32 @@ def set_weight_attrs(
assert
not
hasattr
(
weight
,
key
),
(
f
"Overwriting existing tensor attribute:
{
key
}
"
)
setattr
(
weight
,
key
,
value
)
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0) -> torch.Tensor:
    """Return a copy of ``weight`` with ``num_pad`` zeros appended along one axis.

    Args:
        weight: A 1-D or 2-D tensor. The padding is created with the same
            dtype and on the same device as this tensor.
        num_pad: Number of zero entries (1-D) or zero rows/columns (2-D)
            to append.
        pad_dim: Axis to extend for 2-D input: ``0`` appends rows, ``1``
            appends columns. Ignored for 1-D input, which is always
            extended along its only axis.

    Returns:
        A new contiguous tensor holding ``weight`` followed by the zeros.

    Raises:
        ValueError: If ``weight`` is not 1-D or 2-D, or if ``weight`` is 2-D
            and ``pad_dim`` is neither 0 nor 1.
    """
    ndim = weight.dim()
    if ndim == 1:
        # A vector has a single axis, so pad_dim is deliberately not consulted.
        axis = 0
    elif ndim == 2:
        if pad_dim not in (0, 1):
            raise ValueError("pad_dim must be 0 or 1")
        axis = int(pad_dim)
    else:
        raise ValueError("Weight tensor must be 1D or 2D")

    # The zero block matches weight on every axis except the one being extended.
    pad_shape = list(weight.shape)
    pad_shape[axis] = num_pad
    padding = torch.zeros(pad_shape, dtype=weight.dtype, device=weight.device)

    return torch.cat([weight, padding], dim=axis).contiguous()
def gemm_bank_conf(weight):
    """Report whether a size is both a multiple of 2048 and a power of two.

    Despite its name, ``weight`` is used as a plain integer (the callers in
    this change pass a tensor dimension such as ``weight.shape[0]``), not as
    a tensor.

    Args:
        weight: Integer size to test.

    Returns:
        ``True`` when ``weight`` is a non-zero power of two that is also
        divisible by 2048 (i.e. 2048, 4096, 8192, ...); ``False`` otherwise.
    """
    is_mul_of_2048 = weight % 2048 == 0
    # x & (x - 1) clears the lowest set bit; the result is zero only when a
    # single bit was set. Zero itself is excluded explicitly.
    is_power_of_two = (weight & (weight - 1)) == 0 and weight != 0
    return is_mul_of_2048 and is_power_of_two
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment