Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ffa48c91
Unverified
Commit
ffa48c91
authored
Dec 11, 2024
by
Mor Zusman
Committed by
GitHub
Dec 10, 2024
Browse files
[Model] PP support for Mamba-like models (#10992)
Signed-off-by:
mzusman
<
mor.zusmann@gmail.com
>
parent
d5c5154f
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
229 additions
and
81 deletions
+229
-81
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+3
-3
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+4
-2
vllm/config.py
vllm/config.py
+44
-14
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+37
-0
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+65
-28
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+48
-20
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+10
-1
vllm/utils.py
vllm/utils.py
+5
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+4
-4
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+3
-3
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+6
-6
No files found.
docs/source/models/supported_models.rst
View file @
ffa48c91
...
...
@@ -128,7 +128,7 @@ Text Generation
- FalconMamba
- :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
- ✅︎
-
-
✅︎
* - :code:`GemmaForCausalLM`
- Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
...
...
@@ -193,7 +193,7 @@ Text Generation
- Jamba
- :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc.
- ✅︎
-
-
✅︎
* - :code:`LlamaForCausalLM`
- Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
- :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
...
...
@@ -203,7 +203,7 @@ Text Generation
- Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
-
-
-
✅︎
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.
...
...
tests/distributed/test_pipeline_parallel.py
View file @
ffa48c91
...
...
@@ -156,13 +156,13 @@ TEXT_GENERATION_MODELS = {
# "internlm/internlm-chat-7b": PPTestSettings.fast(),
"internlm/internlm2-chat-7b"
:
PPTestSettings
.
fast
(
trust_remote_code
=
True
),
"inceptionai/jais-13b-chat"
:
PPTestSettings
.
fast
(),
# TODO: Implement PP
# "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev"
:
PPTestSettings
.
fast
(),
"meta-llama/Meta-Llama-3-8B"
:
PPTestSettings
.
detailed
(),
"openbmb/MiniCPM-2B-sft-bf16"
:
PPTestSettings
.
fast
(
trust_remote_code
=
True
),
"openbmb/MiniCPM3-4B"
:
PPTestSettings
.
fast
(
trust_remote_code
=
True
),
# Uses Llama
# "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
"state-spaces/mamba-130m-hf"
:
PPTestSettings
.
fast
(),
"mistralai/Mixtral-8x7B-Instruct-v0.1"
:
PPTestSettings
.
fast
(
tp_base
=
4
),
"mosaicml/mpt-7b"
:
PPTestSettings
.
fast
(),
"nvidia/Minitron-8B-Base"
:
PPTestSettings
.
fast
(),
...
...
@@ -234,6 +234,8 @@ TEST_MODELS = [
"OpenGVLab/InternVL2-1B"
,
"microsoft/Phi-3-vision-128k-instruct"
,
"fixie-ai/ultravox-v0_3"
,
# [LANGUAGE GENERATION - HYBRID ARCH]
"ai21labs/Jamba-tiny-dev"
,
]
...
...
vllm/config.py
View file @
ffa48c91
...
...
@@ -27,8 +27,8 @@ from vllm.transformers_utils.config import (
ConfigFormat
,
get_config
,
get_hf_image_processor_config
,
get_hf_text_config
,
get_pooling_config
,
get_sentence_transformer_tokenizer_config
,
is_encoder_decoder
,
uses_mrope
)
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
print_warning_once
,
random_uuid
,
from
vllm.utils
import
(
GiB_bytes
,
LayerBlockType
,
cuda_device_count_stateless
,
get_cpu_memory
,
print_warning_once
,
random_uuid
,
resolve_obj_by_qualname
)
if
TYPE_CHECKING
:
...
...
@@ -284,6 +284,7 @@ class ModelConfig:
self
.
_verify_tokenizer_mode
()
self
.
is_attention_free
=
self
.
_init_attention_free
()
self
.
is_hybrid
=
self
.
_init_is_hybrid
()
self
.
has_inner_state
=
self
.
_init_has_inner_state
()
if
current_platform
.
is_neuron
():
...
...
@@ -340,6 +341,10 @@ class ModelConfig:
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
return
ModelRegistry
.
is_attention_free_model
(
architectures
)
def
_init_is_hybrid
(
self
)
->
bool
:
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
return
ModelRegistry
.
is_hybrid_model
(
architectures
)
def
_init_has_inner_state
(
self
)
->
bool
:
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
return
ModelRegistry
.
model_has_inner_state
(
architectures
)
...
...
@@ -669,26 +674,51 @@ class ModelConfig:
num_heads
=
getattr
(
self
.
hf_text_config
,
"num_attention_heads"
,
0
)
return
num_heads
//
parallel_config
.
tensor_parallel_size
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
def
get_layers_start_end_indices
(
self
,
parallel_config
:
"ParallelConfig"
)
->
Tuple
[
int
,
int
]:
from
vllm.distributed.utils
import
get_pp_indices
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
"num_hidden_layers"
,
0
)
pp_rank
=
parallel_config
.
rank
//
parallel_config
.
tensor_parallel_size
pp_size
=
parallel_config
.
pipeline_parallel_size
start
,
end
=
get_pp_indices
(
total_num_hidden_layers
,
pp_rank
,
pp_size
)
return
end
-
start
def
get_num_attention_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
if
self
.
is_attention_free
:
return
0
return
start
,
end
num_layers
=
self
.
get_num_layers
(
parallel_config
)
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
start
,
end
=
self
.
get_layers_start_end_indices
(
parallel_config
)
return
end
-
start
# Transformers supports layers_block_type @property
layers
=
getattr
(
self
.
hf_config
,
"layers_block_type"
,
[
"attention"
]
*
num_layers
)
return
len
([
t
for
t
in
layers
if
t
==
"attention"
])
def
get_num_layers_by_block_type
(
self
,
parallel_config
:
"ParallelConfig"
,
block_type
:
LayerBlockType
=
LayerBlockType
.
attention
,
)
->
int
:
# This function relies on 'layers_block_type' in hf_config,
# for w/o this attribute, we will need to have workarounds like so
attn_block_type
=
block_type
==
LayerBlockType
.
attention
is_transformer
=
not
self
.
is_hybrid
and
not
self
.
is_attention_free
start
,
end
=
self
.
get_layers_start_end_indices
(
parallel_config
)
if
is_transformer
:
# Handle the basic case first
return
end
-
start
if
attn_block_type
else
0
elif
self
.
is_attention_free
:
# Attention free
# Note that this code assumes there
# is only one type of attention-free block type.
return
0
if
attn_block_type
else
end
-
start
else
:
# Hybrid model
layers_block_type_value
=
getattr
(
self
.
hf_config
,
"layers_block_type"
,
None
)
if
layers_block_type_value
is
None
:
raise
ValueError
(
"The model is an hybrid without a"
"layers_block_type in the hf_config,"
"cannot determine the num of "
f
"
{
block_type
.
value
}
layers"
)
return
sum
(
t
==
block_type
.
value
for
t
in
layers_block_type_value
[
start
:
end
])
def
get_multimodal_config
(
self
)
->
"MultiModalConfig"
:
"""
...
...
vllm/model_executor/models/interfaces.py
View file @
ffa48c91
...
...
@@ -363,6 +363,43 @@ def is_attention_free(
return
isinstance
(
model
,
IsAttentionFree
)
@
runtime_checkable
class
IsHybrid
(
Protocol
):
"""The interface required for all models like Jamba that have both
attention and mamba blocks, indicates that
hf_config has 'layers_block_type'"""
is_hybrid
:
ClassVar
[
Literal
[
True
]]
=
True
"""
A flag that indicates this model has both mamba and attention blocks
, also indicates that the model's hf_config has
'layers_block_type' """
@
runtime_checkable
class
_IsHybridType
(
Protocol
):
is_hybrid
:
ClassVar
[
Literal
[
True
]]
@
overload
def
is_hybrid
(
model
:
object
)
->
TypeIs
[
IsHybrid
]:
...
@
overload
def
is_hybrid
(
model
:
Type
[
object
])
->
TypeIs
[
Type
[
IsHybrid
]]:
...
def
is_hybrid
(
model
:
Union
[
Type
[
object
],
object
]
)
->
Union
[
TypeIs
[
Type
[
IsHybrid
]],
TypeIs
[
IsHybrid
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_IsHybridType
)
return
isinstance
(
model
,
IsHybrid
)
@
runtime_checkable
class
SupportsCrossEncoding
(
Protocol
):
"""The interface required for all models that support cross encoding."""
...
...
vllm/model_executor/models/jamba.py
View file @
ffa48c91
...
...
@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.config
import
_BATCH_SIZES_TO_CAPTURE
,
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
...
...
@@ -25,9 +26,12 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.interfaces
import
HasInnerState
,
SupportsLoRA
from
.utils
import
maybe_prefix
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
...
...
@@ -281,16 +285,24 @@ class JambaModel(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
)
decoder_layers
=
[]
for
i
in
range
(
config
.
num_hidden_layers
):
layer_class
=
ALL_DECODER_LAYER_TYPES
[
config
.
layers_block_type
[
i
]]
decoder_layers
.
append
(
layer_class
(
config
,
layer_idx
=
i
,
cache_config
=
cache_config
,
def
get_layer
(
prefix
:
str
):
layer_idx
=
int
(
prefix
.
rsplit
(
"."
,
1
)[
1
])
layer_class
=
ALL_DECODER_LAYER_TYPES
[
config
.
layers_block_type
[
layer_idx
]]
return
layer_class
(
config
,
layer_idx
,
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
i
}
"
))
self
.
layers
=
nn
.
ModuleList
(
decoder_layers
)
prefix
=
prefix
,
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
get_layer
,
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
self
.
final_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -304,26 +316,34 @@ class JambaModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
kv_cache_index
=
0
mamba_cache_index
=
0
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
kv_cache
=
None
layer_mamba_cache_params
=
None
if
isinstance
(
layer
,
JambaAttentionDecoderLayer
):
kv_cache
=
kv_caches
[
(
i
-
self
.
config
.
attn_layer_offset
)
//
self
.
config
.
attn_layer_period
]
kv_cache
=
kv_caches
[
kv_cache_index
]
kv_cache_index
+=
1
if
isinstance
(
layer
,
JambaMambaDecoderLayer
):
current_state_layer
=
i
-
(
1
+
(
i
-
self
.
config
.
attn_layer_offset
)
//
self
.
config
.
attn_layer_period
)
current_state_layer
=
mamba_cache_index
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
current_state_layer
)
mamba_cache_index
+=
1
hidden_states
,
residual
=
layer
(
positions
=
positions
,
...
...
@@ -332,11 +352,17 @@ class JambaModel(nn.Module):
attn_metadata
=
attn_metadata
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
final_layernorm
(
hidden_states
,
residual
)
return
hidden_states
class
JambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
SupportsLoRA
):
class
JambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
SupportsLoRA
,
SupportsPP
,
IsHybrid
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -368,6 +394,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
super
().
__init__
()
self
.
config
=
config
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
scheduler_config
=
scheduler_config
self
.
model
=
JambaModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
...
...
@@ -390,6 +418,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
...
@@ -406,10 +437,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
layers_type
=
self
.
config
.
layers_block_type
num_mamba_layers
=
sum
(
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
num_mamba_layers
=
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
...
...
@@ -423,7 +452,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
state_indices_tensor
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
mamba_cache_params
,
inputs_embeds
)
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
...
...
@@ -504,8 +533,12 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -520,6 +553,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
if
weight_name
not
in
name
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
...
...
@@ -533,6 +568,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
...
...
vllm/model_executor/models/mamba.py
View file @
ffa48c91
...
...
@@ -8,6 +8,7 @@ from transformers import MambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
_BATCH_SIZES_TO_CAPTURE
,
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer
import
MambaMixer
...
...
@@ -18,13 +19,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
)
IsAttentionFree
,
SupportsPP
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.utils
import
maybe_prefix
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
...
...
@@ -95,15 +99,17 @@ class MambaModel(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
)
decoder_layers
=
[]
for
i
in
range
(
config
.
num_hidden_layers
):
decoder_layers
.
append
(
MambaDecoderLayer
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
))
self
.
layers
=
nn
.
ModuleList
(
decoder_layers
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
MambaDecoderLayer
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embeddings
(
input_ids
)
...
...
@@ -114,29 +120,40 @@ class MambaModel(nn.Module):
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
len
(
self
.
layer
s
)
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_
layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
))
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
))
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm_f
(
hidden_states
,
residual
)
return
hidden_states
class
MambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
):
class
MambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
,
SupportsPP
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -148,7 +165,9 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
super
().
__init__
()
self
.
config
=
config
self
.
vllm_config
=
vllm_config
self
.
scheduler_config
=
scheduler_config
self
.
model_config
=
vllm_config
.
model_config
self
.
backbone
=
MambaModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"backbone"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
...
...
@@ -174,6 +193,9 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
backbone
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
backbone
.
get_input_embeddings
(
input_ids
)
...
...
@@ -189,9 +211,12 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
max_batch_size
=
(
VllmConfig
.
get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
num_mamba_layers
=
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
lm_head
.
weight
.
dtype
,
self
.
config
.
num_hidden_layers
,
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
(
mamba_cache_tensors
,
...
...
@@ -204,7 +229,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
state_indices_tensor
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
mamba_cache_params
,
inputs_embeds
)
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
@@ -252,6 +278,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
...
...
vllm/model_executor/models/registry.py
View file @
ffa48c91
...
...
@@ -21,7 +21,7 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
.adapters
import
as_embedding_model
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
is_hybrid
,
supports_cross_encoding
,
supports_multimodal
,
supports_pp
)
from
.interfaces_base
import
is_pooling_model
,
is_text_generation_model
...
...
@@ -218,6 +218,7 @@ class _ModelInfo:
supports_pp
:
bool
has_inner_state
:
bool
is_attention_free
:
bool
is_hybrid
:
bool
@
staticmethod
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
...
...
@@ -239,6 +240,7 @@ class _ModelInfo:
supports_pp
=
supports_pp
(
model
),
has_inner_state
=
has_inner_state
(
model
),
is_attention_free
=
is_attention_free
(
model
),
is_hybrid
=
is_hybrid
(
model
),
)
...
...
@@ -484,6 +486,13 @@ class _ModelRegistry:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_attention_free
def
is_hybrid_model
(
self
,
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_hybrid
ModelRegistry
=
_ModelRegistry
({
model_arch
:
_LazyRegisteredModel
(
...
...
vllm/utils.py
View file @
ffa48c91
...
...
@@ -170,6 +170,11 @@ class Device(enum.Enum):
CPU
=
enum
.
auto
()
class
LayerBlockType
(
enum
.
Enum
):
attention
=
"attention"
mamba
=
"mamba"
class
Counter
:
def
__init__
(
self
,
start
:
int
=
0
)
->
None
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
ffa48c91
...
...
@@ -15,8 +15,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.sampling_params
import
SamplingType
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
cdiv
,
is_pin_memory_available
)
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
FlashAttentionMetadata
)
from
vllm.v1.outputs
import
ModelRunnerOutput
...
...
@@ -68,8 +68,8 @@ class GPUModelRunner:
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
# Model-related.
self
.
num_attn_layers
=
model_config
.
get_num_
attention_layers
(
parallel_config
)
self
.
num_attn_layers
=
model_config
.
get_num_
layers_by_block_type
(
parallel_config
,
LayerBlockType
.
attention
)
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
head_size
=
model_config
.
get_head_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
...
...
vllm/v1/worker/gpu_worker.py
View file @
ffa48c91
...
...
@@ -14,7 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
LayerBlockType
,
get_dtype_size
from
vllm.v1.core.scheduler
import
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
...
@@ -260,8 +260,8 @@ def _get_cache_block_size(
)
->
int
:
head_size
=
model_config
.
get_head_size
()
num_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
num_attention_layers
=
model_config
.
get_num_
attention_layers
(
parallel_config
)
num_attention_layers
=
model_config
.
get_num_
layers_by_block_type
(
parallel_config
,
LayerBlockType
.
attention
)
key_cache_block
=
cache_config
.
block_size
*
num_heads
*
head_size
value_cache_block
=
key_cache_block
...
...
vllm/worker/cache_engine.py
View file @
ffa48c91
...
...
@@ -6,8 +6,8 @@ import torch
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_siz
e
,
is_pin_memory_available
)
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
LayerBlockTyp
e
,
get_dtype_size
,
is_pin_memory_available
)
logger
=
init_logger
(
__name__
)
...
...
@@ -34,8 +34,8 @@ class CacheEngine:
self
.
head_size
=
model_config
.
get_head_size
()
# Models like Jamba, have mixed typed layers, E.g Mamba
self
.
num_attention_layers
=
model_config
.
get_num_
attention_layers
(
parallel_config
)
self
.
num_attention_layers
=
model_config
.
get_num_
layers_by_block_type
(
parallel_config
,
LayerBlockType
.
attention
)
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
block_size
=
cache_config
.
block_size
...
...
@@ -105,8 +105,8 @@ class CacheEngine:
)
->
int
:
head_size
=
model_config
.
get_head_size
()
num_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
num_attention_layers
=
model_config
.
get_num_
attention_layers
(
parallel_config
)
num_attention_layers
=
model_config
.
get_num_
layers_by_block_type
(
parallel_config
,
LayerBlockType
.
attention
)
key_cache_block
=
cache_config
.
block_size
*
num_heads
*
head_size
value_cache_block
=
key_cache_block
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment