Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
371b1251
Commit
371b1251
authored
Jul 06, 2024
by
zhuwenwen
Browse files
add fa pad
parent
1863c926
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
19 additions
and
0 deletions
+19
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+5
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+2
-0
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+3
-0
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+3
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+3
-0
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+3
-0
No files found.
vllm/model_executor/layers/linear.py
View file @
371b1251
...
...
@@ -607,6 +607,7 @@ class QKVParallelLinear(ColumnParallelLinear):
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -763,8 +764,12 @@ class QKVParallelLinear(ColumnParallelLinear):
assert
param_data_
.
shape
==
loaded_weight
.
shape
param_data_
.
copy_
(
loaded_weight
)
if
loaded_shard_id
==
"v"
and
len
(
param_data
.
shape
)
==
2
:
if
self
.
use_fa_pad
and
param_data
.
shape
[
0
]
==
12288
:
param_data
=
pad_weight
(
param
.
data
,
32
)
param_data
=
param_data
.
transpose
(
0
,
1
)
param
.
data
=
param_data
.
reshape
(
param_data
.
shape
[
1
],
-
1
)
if
self
.
use_fa_pad
and
param_data
.
shape
[
0
]
==
12288
and
loaded_shard_id
==
"v"
and
len
(
param_data
.
shape
)
==
1
:
param
.
data
=
pad_weight
(
param
.
data
,
32
)
else
:
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
vllm/model_executor/model_loader/utils.py
View file @
371b1251
...
...
@@ -27,6 +27,8 @@ def get_model_architecture(
os
.
environ
[
'LLAMA_NN'
]
=
'1'
if
os
.
getenv
(
'GEMM_PAD'
)
!=
'0'
:
os
.
environ
[
'GEMM_PAD'
]
=
'1'
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
os
.
environ
[
'FA_PAD'
]
=
'0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if
(
model_config
.
quantization
is
not
None
...
...
vllm/model_executor/models/baichuan.py
View file @
371b1251
...
...
@@ -24,6 +24,7 @@ from typing import Iterable, List, Optional, Tuple
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
import
os
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
...
@@ -178,6 +179,8 @@ class BaiChuanAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
postion_embedding
!=
"ALIBI"
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
...
vllm/model_executor/models/chatglm.py
View file @
371b1251
...
...
@@ -7,6 +7,7 @@ from typing import Iterable, List, Optional, Tuple
import
torch
from
torch
import
nn
from
torch.nn
import
LayerNorm
import
os
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
...
@@ -101,6 +102,8 @@ class GLMAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
context_layer
=
self
.
attn
(
...
...
vllm/model_executor/models/llama.py
View file @
371b1251
...
...
@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
import
torch
from
torch
import
nn
from
transformers
import
LlamaConfig
import
os
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
...
@@ -156,6 +157,8 @@ class LlamaAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
...
...
vllm/model_executor/models/qwen2.py
View file @
371b1251
...
...
@@ -27,6 +27,7 @@ from typing import Iterable, List, Optional, Tuple
import
torch
from
torch
import
nn
from
transformers
import
Qwen2Config
import
os
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
...
@@ -148,6 +149,8 @@ class Qwen2Attention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment