Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0fca3cdc
Unverified
Commit
0fca3cdc
authored
May 13, 2024
by
Woosuk Kwon
Committed by
GitHub
May 13, 2024
Browse files
[Misc] Enhance attention selector (#4751)
parent
e7c46b95
Changes
49
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
90 additions
and
27 deletions
+90
-27
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+13
-3
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+10
-4
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+15
-3
vllm/model_executor/models/xverse.py
vllm/model_executor/models/xverse.py
+10
-4
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+11
-3
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+11
-4
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+9
-1
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+0
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+11
-4
No files found.
vllm/model_executor/models/qwen2_moe.py
View file @
0fca3cdc
...
...
@@ -30,6 +30,7 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
...
...
@@ -187,6 +188,7 @@ class Qwen2MoeAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -238,7 +240,8 @@ class Qwen2MoeAttention(nn.Module):
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
)
def
forward
(
self
,
...
...
@@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -276,6 +280,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
)
if
(
config
.
num_experts
is
not
None
...
...
@@ -328,6 +333,7 @@ class Qwen2MoeModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -339,7 +345,10 @@ class Qwen2MoeModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
Qwen2MoeDecoderLayer
(
config
,
layer_idx
,
quant_config
=
quant_config
)
Qwen2MoeDecoderLayer
(
config
,
layer_idx
,
cache_config
,
quant_config
=
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -369,12 +378,13 @@ class Qwen2MoeForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
Qwen2MoeModel
(
config
,
quant_config
)
self
.
model
=
Qwen2MoeModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/stablelm.py
View file @
0fca3cdc
...
...
@@ -26,6 +26,7 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -72,6 +73,7 @@ class StablelmAttention(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -124,7 +126,8 @@ class StablelmAttention(nn.Module):
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_key_value_heads
)
num_kv_heads
=
self
.
num_key_value_heads
,
cache_config
=
cache_config
)
def
forward
(
self
,
...
...
@@ -146,10 +149,11 @@ class StablelmDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
self_attn
=
StablelmAttention
(
config
)
self
.
self_attn
=
StablelmAttention
(
config
,
cache_config
,
quant_config
)
self
.
mlp
=
StablelmMLP
(
config
,
quant_config
)
norm_eps
=
getattr
(
config
,
"norm_eps"
,
getattr
(
config
,
"layer_norm_eps"
,
1e-05
))
...
...
@@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
self
.
embed_tokens
=
VocabParallelEmbedding
(
...
...
@@ -195,7 +200,7 @@ class StableLMEpochModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
StablelmDecoderLayer
(
config
,
quant_config
)
StablelmDecoderLayer
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
norm_eps
=
getattr
(
config
,
"norm_eps"
,
...
...
@@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
StableLMEpochModel
(
config
,
quant_config
)
self
.
model
=
StableLMEpochModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/starcoder2.py
View file @
0fca3cdc
...
...
@@ -25,6 +25,7 @@ from torch import nn
from
transformers
import
Starcoder2Config
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -46,6 +47,7 @@ class Starcoder2Attention(nn.Module):
def
__init__
(
self
,
config
:
Starcoder2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -101,6 +103,7 @@ class Starcoder2Attention(nn.Module):
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
sliding_window
=
self
.
sliding_window
,
cache_config
=
cache_config
,
)
def
forward
(
...
...
@@ -150,10 +153,13 @@ class Starcoder2DecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
Starcoder2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
Starcoder2Attention
(
config
,
quant_config
=
quant_config
)
self
.
self_attn
=
Starcoder2Attention
(
config
,
cache_config
,
quant_config
=
quant_config
)
self
.
mlp
=
Starcoder2MLP
(
config
,
quant_config
=
quant_config
)
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
norm_epsilon
)
...
...
@@ -191,6 +197,7 @@ class Starcoder2Model(nn.Module):
def
__init__
(
self
,
config
:
Starcoder2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -201,7 +208,9 @@ class Starcoder2Model(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
Starcoder2DecoderLayer
(
config
,
quant_config
=
quant_config
)
Starcoder2DecoderLayer
(
config
,
cache_config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
norm_epsilon
)
...
...
@@ -226,10 +235,13 @@ class Starcoder2ForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
Starcoder2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
model
=
Starcoder2Model
(
config
,
quant_config
=
quant_config
)
self
.
model
=
Starcoder2Model
(
config
,
cache_config
,
quant_config
=
quant_config
)
self
.
vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
config
.
tie_word_embeddings
:
...
...
vllm/model_executor/models/xverse.py
View file @
0fca3cdc
...
...
@@ -27,7 +27,7 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
@@ -89,6 +89,7 @@ class XverseAttention(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
sliding_window
:
Optional
[
int
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -133,7 +134,8 @@ class XverseAttention(nn.Module):
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
sliding_window
=
sliding_window
)
sliding_window
=
sliding_window
,
cache_config
=
cache_config
)
def
forward
(
self
,
...
...
@@ -155,6 +157,7 @@ class XverseDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -175,6 +178,7 @@ class XverseDecoderLayer(nn.Module):
quant_config
=
quant_config
,
bias
=
getattr
(
config
,
"bias"
,
False
),
sliding_window
=
sliding_window
,
cache_config
=
cache_config
,
)
self
.
mlp
=
XverseMLP
(
hidden_size
=
self
.
hidden_size
,
...
...
@@ -221,6 +225,7 @@ class XverseModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
...
...
@@ -237,7 +242,7 @@ class XverseModel(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
layers
=
nn
.
ModuleList
([
XverseDecoderLayer
(
config
,
quant_config
)
XverseDecoderLayer
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -295,13 +300,14 @@ class XverseForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
XverseModel
(
config
,
quant_config
)
self
.
model
=
XverseModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/worker/cache_engine.py
View file @
0fca3cdc
...
...
@@ -31,7 +31,7 @@ class CacheEngine:
self
.
head_size
=
model_config
.
get_head_size
()
self
.
num_layers
=
model_config
.
get_num_layers
(
parallel_config
)
self
.
num_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
num_
kv_
heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
block_size
=
cache_config
.
block_size
self
.
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
...
...
@@ -43,7 +43,15 @@ class CacheEngine:
self
.
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
# Get attention backend.
self
.
attn_backend
=
get_attn_backend
(
model_config
.
dtype
)
self
.
attn_backend
=
get_attn_backend
(
model_config
.
get_num_attention_heads
(
parallel_config
),
self
.
head_size
,
self
.
num_kv_heads
,
model_config
.
get_sliding_window
(),
model_config
.
dtype
,
cache_config
.
cache_dtype
,
self
.
block_size
,
)
# Initialize the cache.
self
.
gpu_cache
=
self
.
_allocate_kv_cache
(
self
.
num_gpu_blocks
,
"cuda"
)
...
...
@@ -56,7 +64,7 @@ class CacheEngine:
)
->
List
[
torch
.
Tensor
]:
"""Allocates KV cache on the specified device."""
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
self
.
block_size
,
self
.
num_heads
,
self
.
head_size
)
num_blocks
,
self
.
block_size
,
self
.
num_
kv_
heads
,
self
.
head_size
)
pin_memory
=
is_pin_memory_available
()
if
device
==
"cpu"
else
False
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
self
.
num_layers
):
...
...
vllm/worker/cpu_model_runner.py
View file @
0fca3cdc
...
...
@@ -53,7 +53,15 @@ class CPUModelRunner:
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
)
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
)
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
...
...
@@ -66,7 +74,8 @@ class CPUModelRunner:
vision_language_config
=
self
.
vision_language_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
)
scheduler_config
=
self
.
scheduler_config
,
cache_config
=
self
.
cache_config
)
def
_prepare_prompt
(
self
,
...
...
@@ -158,7 +167,6 @@ class CPUModelRunner:
decode_metadata
=
None
,
block_tables
=
torch
.
tensor
([]),
slot_mapping
=
slot_mapping
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_input
)
...
...
@@ -242,7 +250,6 @@ class CPUModelRunner:
prefill_metadata
=
None
,
decode_metadata
=
None
,
block_tables
=
block_tables
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
...
...
vllm/worker/cpu_worker.py
View file @
0fca3cdc
...
...
@@ -53,7 +53,15 @@ class CPUCacheEngine:
self
.
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
# Get attention backend.
self
.
attn_backend
=
get_attn_backend
(
model_config
.
dtype
)
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
cache_config
.
cache_dtype
,
self
.
block_size
,
)
# Initialize the cache.
self
.
cpu_cache
=
self
.
_allocate_kv_cache
(
self
.
num_cpu_blocks
)
...
...
vllm/worker/embedding_model_runner.py
View file @
0fca3cdc
...
...
@@ -235,7 +235,6 @@ class EmbeddingModelRunner(ModelRunner):
num_decode_tokens
=
num_decode_tokens
,
prefill_metadata
=
prefill_attn_metadata
,
decode_metadata
=
decode_attn_metadata
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
pooling_metadata
,
...
...
vllm/worker/model_runner.py
View file @
0fca3cdc
...
...
@@ -141,10 +141,18 @@ class ModelRunner:
self
.
graph_block_tables
=
np
.
zeros
(
(
max
(
_BATCH_SIZES_TO_CAPTURE
),
self
.
get_max_block_per_batch
()),
dtype
=
np
.
int32
)
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
)
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
)
# Lazy initialization
self
.
model
:
torch
.
nn
.
Module
# Set after load_model
self
.
model
:
nn
.
Module
# Set after load_model
# Set if the backend is flashinfer.
self
.
flashinfer_workspace_buffer
:
torch
.
Tensor
# Set after load_model.
...
...
@@ -160,6 +168,7 @@ class ModelRunner:
vision_language_config
=
self
.
vision_language_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
cache_config
=
self
.
cache_config
,
)
self
.
model_memory_usage
=
m
.
consumed_memory
...
...
@@ -753,7 +762,6 @@ class ModelRunner:
num_decode_tokens
=
num_decode_tokens
,
prefill_metadata
=
prefill_attn_metadata
,
decode_metadata
=
decode_attn_metadata
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
...
...
@@ -965,7 +973,6 @@ class ModelRunner:
slot_mapping
=
slot_mapping
[:
batch_size
],
prefill_metadata
=
None
,
decode_metadata
=
decode_metadata
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
if
self
.
lora_config
:
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment