Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c6b636f9
Unverified
Commit
c6b636f9
authored
May 23, 2025
by
Mark McLoughlin
Committed by
GitHub
May 23, 2025
Browse files
[V1][Spec Decoding] Use model_loader.get_model() to load models (#18273)
Signed-off-by:
Mark McLoughlin
<
markmc@redhat.com
>
parent
04eb88dc
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
59 additions
and
135 deletions
+59
-135
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+6
-52
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+10
-3
vllm/model_executor/model_loader/base_loader.py
vllm/model_executor/model_loader/base_loader.py
+2
-1
vllm/model_executor/model_loader/bitsandbytes_loader.py
vllm/model_executor/model_loader/bitsandbytes_loader.py
+2
-3
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+4
-3
vllm/model_executor/model_loader/dummy_loader.py
vllm/model_executor/model_loader/dummy_loader.py
+2
-2
vllm/model_executor/model_loader/gguf_loader.py
vllm/model_executor/model_loader/gguf_loader.py
+2
-2
vllm/model_executor/model_loader/runai_streamer_loader.py
vllm/model_executor/model_loader/runai_streamer_loader.py
+2
-3
vllm/model_executor/model_loader/sharded_state_loader.py
vllm/model_executor/model_loader/sharded_state_loader.py
+2
-2
vllm/model_executor/model_loader/tensorizer_loader.py
vllm/model_executor/model_loader/tensorizer_loader.py
+2
-2
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+3
-1
vllm/model_executor/models/llama_eagle.py
vllm/model_executor/models/llama_eagle.py
+4
-2
vllm/model_executor/models/llama_eagle3.py
vllm/model_executor/models/llama_eagle3.py
+6
-5
vllm/model_executor/models/medusa.py
vllm/model_executor/models/medusa.py
+1
-4
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+6
-32
vllm/v1/spec_decode/medusa.py
vllm/v1/spec_decode/medusa.py
+5
-18
No files found.
tests/v1/spec_decode/test_eagle.py
View file @
c6b636f9
...
...
@@ -117,34 +117,13 @@ def test_prepare_inputs():
])
@
mock
.
patch
(
'vllm.v1.spec_decode.eagle.get_pp_group'
)
@
mock
.
patch
(
'vllm.v1.spec_decode.eagle.get_layers_from_vllm_config'
)
@
mock
.
patch
(
'vllm.v1.spec_decode.eagle.ModelRegistry'
)
@
mock
.
patch
(
'vllm.v1.spec_decode.eagle.get_model_loader'
)
@
mock
.
patch
(
'vllm.v1.spec_decode.eagle.set_default_torch_dtype'
)
@
mock
.
patch
(
'vllm.v1.spec_decode.eagle.set_current_vllm_config'
)
def
test_load_model
(
mock_set_config
,
mock_set_dtype
,
mock_get_loader
,
mock_registry
,
mock_get_layers
,
mock_get_pp_group
,
method
,
@
mock
.
patch
(
'vllm.v1.spec_decode.eagle.get_model'
)
def
test_load_model
(
mock_get_model
,
mock_get_layers
,
mock_get_pp_group
,
method
,
proposer_helper
,
draft_model_dir
,
target_attribute_path
):
# Setup mock for model class
mock_model_cls
=
mock
.
MagicMock
()
mock_registry
.
resolve_model_cls
.
return_value
=
(
mock_model_cls
,
"test_arch"
)
# Create a real context manager for mocks
class
MockContextManager
:
def
__init__
(
self
):
pass
def
__enter__
(
self
):
return
None
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
return
False
# Make the mocks return actual context manager objects
mock_set_dtype
.
return_value
=
MockContextManager
()
mock_set_config
.
return_value
=
MockContextManager
()
# Setup model mock
mock_model
=
mock
.
MagicMock
()
mock_get_model
.
return_value
=
mock_model
# Setup mocks for attention layers
target_attn_layers
=
{
...
...
@@ -164,25 +143,6 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
mock_pp_group
.
world_size
=
2
if
method
==
"eagle"
else
1
mock_get_pp_group
.
return_value
=
mock_pp_group
# Setup model loader mock
mock_loader
=
mock
.
MagicMock
()
mock_get_loader
.
return_value
=
mock_loader
# Setup model mock
mock_model
=
mock
.
MagicMock
()
mock_model_cls
.
return_value
=
mock_model
mock_model
.
to
.
return_value
=
mock_model
# Configure mock to test the attribute sharing path
if
method
==
"eagle"
:
# For eagle, test the lm_head path
mock_model
.
load_weights
.
return_value
=
{
"model.embed_tokens.weight"
:
torch
.
zeros
(
1
)
}
else
:
# For eagle3, test the embed_tokens path
mock_model
.
load_weights
.
return_value
=
{}
# Setup target model with the appropriate attributes
target_model
=
mock
.
MagicMock
()
...
...
@@ -204,13 +164,7 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
proposer
.
load_model
(
target_model
)
# Verify common interactions
mock_get_loader
.
assert_called_once
()
mock_model_cls
.
assert_called_once
()
mock_model
.
to
.
assert_called_once
()
mock_model
.
load_weights
.
assert_called_once
()
# Verify the loader was called with the right config
mock_get_loader
.
assert_called_once_with
(
proposer
.
vllm_config
.
load_config
)
mock_get_model
.
assert_called_once
()
# Verify the specific attribute sharing based on the method
if
method
==
"eagle"
:
...
...
vllm/model_executor/model_loader/__init__.py
View file @
c6b636f9
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
from
torch
import
nn
from
vllm.config
import
LoadConfig
,
LoadFormat
,
VllmConfig
from
vllm.config
import
LoadConfig
,
LoadFormat
,
ModelConfig
,
VllmConfig
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.bitsandbytes_loader
import
(
BitsAndBytesModelLoader
)
...
...
@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
return
DefaultModelLoader
(
load_config
)
def
get_model
(
*
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
get_model
(
*
,
vllm_config
:
VllmConfig
,
model_config
:
Optional
[
ModelConfig
]
=
None
)
->
nn
.
Module
:
loader
=
get_model_loader
(
vllm_config
.
load_config
)
return
loader
.
load_model
(
vllm_config
=
vllm_config
)
if
model_config
is
None
:
model_config
=
vllm_config
.
model_config
return
loader
.
load_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
)
__all__
=
[
...
...
vllm/model_executor/model_loader/base_loader.py
View file @
c6b636f9
...
...
@@ -18,6 +18,7 @@ class BaseModelLoader(ABC):
raise
NotImplementedError
@
abstractmethod
def
load_model
(
self
,
*
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
*
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
raise
NotImplementedError
vllm/model_executor/model_loader/bitsandbytes_loader.py
View file @
c6b636f9
...
...
@@ -569,10 +569,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
...
...
vllm/model_executor/model_loader/default_loader.py
View file @
c6b636f9
...
...
@@ -264,13 +264,14 @@ class DefaultModelLoader(BaseModelLoader):
fall_back_to_pt
=
True
,
allow_patterns_overrides
=
None
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
)
model
=
initialize_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
)
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
loaded_weights
=
model
.
load_weights
(
...
...
vllm/model_executor/model_loader/dummy_loader.py
View file @
c6b636f9
...
...
@@ -22,9 +22,9 @@ class DummyModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
pass
# Nothing to download
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
...
...
vllm/model_executor/model_loader/gguf_loader.py
View file @
c6b636f9
...
...
@@ -92,9 +92,9 @@ class GGUFModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
local_model_path
=
self
.
_prepare_weights
(
model_config
.
model
)
gguf_weights_map
=
self
.
_get_gguf_weights_map
(
model_config
)
# we can only know if tie word embeddings after mapping weights
...
...
vllm/model_executor/model_loader/runai_streamer_loader.py
View file @
c6b636f9
...
...
@@ -100,11 +100,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
"""Download model if necessary"""
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
"""Perform streaming of the model to destination"""
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
...
...
vllm/model_executor/model_loader/sharded_state_loader.py
View file @
c6b636f9
...
...
@@ -100,9 +100,9 @@ class ShardedStateLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
from
vllm.distributed
import
get_tensor_model_parallel_rank
...
...
vllm/model_executor/model_loader/tensorizer_loader.py
View file @
c6b636f9
...
...
@@ -93,8 +93,8 @@ class TensorizerLoader(BaseModelLoader):
with
self
.
tensorizer_config
.
open_stream
():
pass
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
model_config
=
vllm_config
.
model_config
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
parallel_config
=
vllm_config
.
parallel_config
self
.
_verify_config
(
model_config
,
parallel_config
)
...
...
vllm/model_executor/model_loader/utils.py
View file @
c6b636f9
...
...
@@ -42,9 +42,11 @@ def initialize_model(
*
,
prefix
:
str
=
""
,
model_class
:
Optional
[
type
[
nn
.
Module
]]
=
None
,
model_config
:
Optional
[
ModelConfig
]
=
None
,
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_config
=
vllm_config
.
model_config
if
model_config
is
None
:
model_config
=
vllm_config
.
model_config
if
model_class
is
None
:
model_class
,
_
=
get_model_architecture
(
model_config
)
...
...
vllm/model_executor/models/llama_eagle.py
View file @
c6b636f9
...
...
@@ -130,13 +130,15 @@ class LlamaModel(nn.Module):
class
EagleLlamaForCausalLM
(
LlamaForCausalLM
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
start_layer_id
:
int
=
0
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
nn
.
Module
.
__init__
(
self
)
self
.
config
=
vllm_config
.
\
speculative_config
.
draft_model_config
.
hf_config
target_layer_num
=
vllm_config
.
model_config
.
get_num_layers
(
vllm_config
.
parallel_config
)
self
.
model
=
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
"model"
,
start_layer_id
=
s
tart_layer_
id
)
start_layer_id
=
tar
ge
t_layer_
num
)
logit_scale
=
getattr
(
self
.
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
config
.
vocab_size
,
...
...
vllm/model_executor/models/llama_eagle3.py
View file @
c6b636f9
...
...
@@ -175,13 +175,15 @@ class LlamaModel(nn.Module):
class
Eagle3LlamaForCausalLM
(
LlamaForCausalLM
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
start_layer_id
:
int
=
0
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
nn
.
Module
.
__init__
(
self
)
self
.
config
=
vllm_config
.
\
speculative_config
.
draft_model_config
.
hf_config
target_layer_num
=
vllm_config
.
model_config
.
get_num_layers
(
vllm_config
.
parallel_config
)
self
.
model
=
LlamaModel
(
vllm_config
=
vllm_config
,
start_layer_id
=
start_layer_id
,
prefix
=
"model"
)
prefix
=
"model"
,
start_layer_id
=
target_layer_num
)
logit_scale
=
getattr
(
self
.
config
,
"logit_scale"
,
1.0
)
self
.
lm_head
=
ParallelLMHead
(
...
...
@@ -193,8 +195,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
self
.
logits_processor
=
LogitsProcessor
(
self
.
config
.
draft_vocab_size
,
scale
=
logit_scale
)
self
.
draft_id_to_target_id
=
nn
.
Parameter
(
torch
.
zeros
((
self
.
config
.
draft_vocab_size
),
dtype
=
torch
.
long
).
type
(
torch
.
LongTensor
),
torch
.
zeros
(
self
.
config
.
draft_vocab_size
,
dtype
=
torch
.
long
),
requires_grad
=
False
,
)
...
...
vllm/model_executor/models/medusa.py
View file @
c6b636f9
...
...
@@ -51,10 +51,7 @@ class Medusa(nn.Module):
needs to have truncated_vocab_size (=k) as an attribute."""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
if
hasattr
(
vllm_config
,
'draft_model_config'
):
config
=
vllm_config
.
draft_model_config
.
hf_config
else
:
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
speculative_config
.
draft_model_config
.
hf_config
super
().
__init__
()
self
.
config
=
config
self
.
blocks
=
nn
.
ModuleList
([
...
...
vllm/v1/spec_decode/eagle.py
View file @
c6b636f9
...
...
@@ -4,14 +4,11 @@ import torch.nn as nn
from
vllm.attention.layer
import
Attention
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
,
set_current_vllm_config
)
get_layers_from_vllm_config
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader.utils
import
(
process_weights_after_loading
,
set_default_torch_dtype
)
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
...
...
@@ -280,51 +277,28 @@ class EagleProposer:
return
cu_num_tokens
,
token_indices
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
loader
=
get_model_loader
(
self
.
vllm_config
.
load_config
)
target_layer_num
=
self
.
vllm_config
.
model_config
.
get_num_layers
(
self
.
vllm_config
.
parallel_config
)
draft_model_config
=
\
self
.
vllm_config
.
speculative_config
.
draft_model_config
target_attn_layer_names
=
set
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
())
draft_model_config
=
\
self
.
vllm_config
.
speculative_config
.
draft_model_config
# FIXME(lily): This does not handle with distributed inference.
target_device
=
self
.
vllm_config
.
device_config
.
device
# We need to set the vllm_config here to register attention
# layers in the forward context.
with
set_default_torch_dtype
(
draft_model_config
.
dtype
),
set_current_vllm_config
(
self
.
vllm_config
):
draft_model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
draft_model_config
.
architectures
)
self
.
model
=
draft_model_cls
(
vllm_config
=
self
.
vllm_config
,
start_layer_id
=
target_layer_num
).
to
(
target_device
)
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
model_config
=
draft_model_config
)
draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
target_attn_layer_names
)
assert
len
(
draft_attn_layer_names
)
==
1
self
.
attn_layer_name
=
next
(
iter
(
draft_attn_layer_names
))
loaded_weights
=
self
.
model
.
load_weights
(
loader
.
get_all_weights
(
draft_model_config
,
self
.
model
))
process_weights_after_loading
(
self
.
model
,
draft_model_config
,
target_device
)
# share embed_tokens with the target model if needed
if
get_pp_group
().
world_size
==
1
:
assert
"model.embed_tokens.weight"
not
in
loaded_weights
,
\
"For PP = 1, Eagle draft should share embed with target model"
logger
.
info
(
"The EAGLE head shares the same vocab embedding"
\
" with the target model."
)
self
.
model
.
model
.
embed_tokens
=
target_model
.
model
.
embed_tokens
else
:
assert
"model.embed_tokens.weight"
in
loaded_weights
,
\
"For PP > 1, Eagle draft checkpoint should its own copy of "
" the model.embed_tokens.weight"
logger
.
info
(
"Since PP > 1, the EAGLE head loaded its own vocab embedding"
\
" weights instead of sharing them with the target model."
...
...
vllm/v1/spec_decode/medusa.py
View file @
c6b636f9
...
...
@@ -3,12 +3,10 @@
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models.medusa
import
Medusa
from
vllm.model_executor.model_loader
import
get_model
from
vllm.v1.sample.metadata
import
SamplingMetadata
# Initialize logger
...
...
@@ -49,20 +47,9 @@ class MedusaProposer:
return
[
list
(
row
)
for
row
in
zip
(
*
draft_tokens
)]
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
# Get model loader and config
loader
=
get_model_loader
(
self
.
vllm_config
.
load_config
)
draft_config
=
self
.
vllm_config
.
speculative_config
.
draft_model_config
# Load model with proper dtype and config
with
set_default_torch_dtype
(
draft_config
.
dtype
),
\
set_current_vllm_config
(
self
.
vllm_config
):
self
.
model
=
Medusa
(
vllm_config
=
self
.
vllm_config
.
speculative_config
).
to
(
self
.
device
)
# Load model weights
weights
=
loader
.
get_all_weights
(
draft_config
,
self
.
model
)
self
.
model
.
load_weights
(
weights
)
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
model_config
=
self
.
vllm_config
.
speculative_config
.
draft_model_config
)
@
torch
.
inference_mode
()
def
dummy_run
(
self
,
num_tokens
:
int
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment