Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5757d90e
Unverified
Commit
5757d90e
authored
Apr 02, 2024
by
Cade Daniel
Committed by
GitHub
Apr 03, 2024
Browse files
[Speculative decoding] Adding configuration object for speculative decoding (#3706)
Co-authored-by:
Lily Liu
<
lilyliupku@gmail.com
>
parent
a3c226e7
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
394 additions
and
61 deletions
+394
-61
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+41
-0
tests/spec_decode/e2e/test_correctness.py
tests/spec_decode/e2e/test_correctness.py
+50
-0
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+8
-10
tests/worker/test_swap.py
tests/worker/test_swap.py
+8
-9
vllm/config.py
vllm/config.py
+187
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+44
-11
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+9
-10
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+28
-16
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+3
-1
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+6
-1
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+5
-1
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+5
-1
No files found.
tests/spec_decode/e2e/conftest.py
0 → 100644
View file @
5757d90e
import
pytest
from
tests.conftest
import
cleanup
from
vllm
import
LLM
from
vllm.model_executor.utils
import
set_random_seed
@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                           baseline_llm_kwargs, seed):
    """Fixture returning an LLM generator configured with the baseline kwargs.

    The common kwargs, the per-test common kwargs, ``baseline_llm_kwargs``
    and ``seed`` are forwarded unchanged to ``create_llm_generator``.
    """
    llm_generator = create_llm_generator(
        common_llm_kwargs=common_llm_kwargs,
        per_test_common_llm_kwargs=per_test_common_llm_kwargs,
        distinct_llm_kwargs=baseline_llm_kwargs,
        seed=seed,
    )
    return llm_generator
@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Fixture returning an LLM generator configured with the test kwargs.

    The common kwargs, the per-test common kwargs, ``test_llm_kwargs`` and
    ``seed`` are forwarded unchanged to ``create_llm_generator``.
    """
    llm_generator = create_llm_generator(
        common_llm_kwargs=common_llm_kwargs,
        per_test_common_llm_kwargs=per_test_common_llm_kwargs,
        distinct_llm_kwargs=test_llm_kwargs,
        seed=seed,
    )
    return llm_generator
def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    """Yield exactly one ``LLM`` built from the merged kwargs, then clean up.

    This is a generator function: the ``LLM`` is only constructed when the
    caller first iterates the returned generator.

    Args:
        common_llm_kwargs: Base keyword arguments for ``LLM``.
        per_test_common_llm_kwargs: Per-test keyword arguments; override
            entries of ``common_llm_kwargs`` with the same key.
        distinct_llm_kwargs: Keyword arguments specific to this LLM; override
            both of the dicts above on key collisions.
        seed: Value passed to ``set_random_seed`` after the LLM is built.

    Yields:
        The constructed ``LLM`` instance.
    """
    # Later dicts win on duplicate keys.
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }

    llm = LLM(**kwargs)
    set_random_seed(seed)

    try:
        yield llm
    finally:
        # Run the teardown even when the consumer raises or closes the
        # generator early (GeneratorExit is delivered at the ``yield``);
        # without ``finally`` the cleanup below would be skipped.
        # Drop our reference first so ``cleanup()`` is not holding the LLM
        # alive -- presumably it releases device/distributed state; verify
        # against tests.conftest.
        del llm
        cleanup()
tests/spec_decode/e2e/test_correctness.py
0 → 100644
View file @
5757d90e
import
pytest
from
vllm
import
SamplingParams
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",
        "speculative_model": "facebook/opt-125m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_config(test_llm_generator):
    """Check that a speculative-decoding config is rejected with an assertion.

    Generation is attempted with greedy sampling (temperature 0.0) for 1024
    tokens per prompt; the test passes only if an ``AssertionError`` matching
    the "not yet supported" message is raised.
    """
    sampling_params = SamplingParams(
        max_tokens=1024,
        ignore_eos=True,
        temperature=0.0,
    )
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    expected_error = pytest.raises(
        AssertionError,
        match="Speculative decoding not yet supported for GPU backend")
    with expected_error:
        get_token_ids_from_llm_generator(test_llm_generator, prompts,
                                         sampling_params)
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    """Run generation on each LLM from the generator and return token ids.

    Args:
        llm_generator: Iterable yielding objects with a
            ``generate(prompts, sampling_params, use_tqdm=...)`` method.
        prompts: Prompts passed through to ``generate``.
        sampling_params: Sampling parameters passed through to ``generate``.

    Returns:
        A list with one entry per generated output: the ``token_ids`` of that
        output's first completion. If the generator yields more than one LLM,
        only the result of the last one is returned.
    """
    for engine in llm_generator:
        results = engine.generate(prompts, sampling_params, use_tqdm=True)

        token_ids = []
        for result in results:
            # Only the first completion of each request is inspected.
            token_ids.append(result.outputs[0].token_ids)

        # Release our reference before the generator is resumed.
        del engine

    return token_ids
tests/spec_decode/utils.py
View file @
5757d90e
...
...
@@ -107,18 +107,16 @@ def create_worker(cls: type,
block_size
=
block_size
,
enforce_eager
=
enforce_eager
,
)
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
_
,
_
)
=
engine_args
.
create_engine_configs
()
engine_config
=
engine_args
.
create_engine_config
()
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
worker
=
cls
(
model_config
=
model_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
device_config
=
device_config
,
model_config
=
engine_config
.
model_config
,
parallel_config
=
engine_config
.
parallel_config
,
scheduler_config
=
engine_config
.
scheduler_config
,
device_config
=
engine_config
.
device_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
...
...
@@ -128,9 +126,9 @@ def create_worker(cls: type,
worker
.
init_device
()
worker
.
load_model
()
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
cache_config
.
num_cpu_blocks
=
0
worker
.
init_cache_engine
(
cache_config
)
engine_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
engine_config
.
cache_config
.
num_cpu_blocks
=
0
worker
.
init_cache_engine
(
engine_config
.
cache_config
)
worker
.
warm_up_model
()
return
worker
...
...
tests/worker/test_swap.py
View file @
5757d90e
...
...
@@ -10,19 +10,18 @@ def test_swap() -> None:
engine_args
=
EngineArgs
(
model
=
"facebook/opt-125m"
,
dtype
=
"half"
,
load_format
=
"dummy"
)
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
_
,
_
)
=
engine_args
.
create_engine_configs
()
cache_config
.
num_gpu_blocks
=
100
cache_config
.
num_cpu_blocks
=
100
engine_config
=
engine_args
.
create_engine_config
()
engine_config
.
cache_config
.
num_gpu_blocks
=
100
engine_config
.
cache_config
.
num_cpu_blocks
=
100
# Create the worker.
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
worker
=
Worker
(
model_config
=
model_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
device_config
=
device_config
,
model_config
=
engine_config
.
model_config
,
parallel_config
=
engine_config
.
parallel_config
,
scheduler_config
=
engine_config
.
scheduler_config
,
device_config
=
engine_config
.
device_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
...
...
@@ -32,7 +31,7 @@ def test_swap() -> None:
# Initialize the worker.
worker
.
init_device
()
worker
.
load_model
()
worker
.
init_cache_engine
(
cache_config
)
worker
.
init_cache_engine
(
engine_config
.
cache_config
)
worker
.
warm_up_model
()
# Randomly initialize the cache.
...
...
vllm/config.py
View file @
5757d90e
import
enum
import
json
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
fields
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Optional
,
Union
import
torch
...
...
@@ -617,6 +617,159 @@ class DeviceConfig:
self
.
device
=
torch
.
device
(
self
.
device_type
)
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        num_speculative_tokens: Optional[int],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        Speculative decoding is considered disabled when neither
        ``speculative_model`` nor ``num_speculative_tokens`` is given. When
        both are given, a draft ``ModelConfig`` and ``ParallelConfig`` are
        derived from the target configs and wrapped in a SpeculativeConfig.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model. Tokenizer, dtype, seed and loading options of the draft
                model are copied from it.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model.
                Currently unused: the draft dtype is taken from
                ``target_model_config.dtype``.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                both speculative arguments are provided, None if neither is.

        Raises:
            ValueError: If only one of ``speculative_model`` and
                ``num_speculative_tokens`` is provided, or if the resulting
                configuration fails validation.
        """
        if speculative_model is None and num_speculative_tokens is None:
            return None

        # Exactly one of the two was supplied. Reject both directions here;
        # otherwise a missing model name would only surface later as an
        # obscure failure while building the draft ModelConfig.
        if speculative_model is None or num_speculative_tokens is None:
            raise ValueError(
                "Expected both speculative_model and "
                "num_speculative_tokens to be provided, but found "
                f"{speculative_model=} and {num_speculative_tokens=}.")

        # TODO: The user should be able to specify revision/quantization/max
        # model len for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = None
        draft_max_model_len = None

        draft_model_config = ModelConfig(
            model=speculative_model,
            tokenizer=target_model_config.tokenizer,
            tokenizer_mode=target_model_config.tokenizer_mode,
            trust_remote_code=target_model_config.trust_remote_code,
            download_dir=target_model_config.download_dir,
            load_format=target_model_config.load_format,
            dtype=target_model_config.dtype,
            seed=target_model_config.seed,
            revision=draft_revision,
            code_revision=draft_code_revision,
            tokenizer_revision=target_model_config.tokenizer_revision,
            max_model_len=draft_max_model_len,
            quantization=draft_quantization,
            enforce_eager=target_model_config.enforce_eager,
            max_context_len_to_capture=target_model_config.
            max_context_len_to_capture,
            max_logprobs=target_model_config.max_logprobs,
        )

        draft_parallel_config = (
            SpeculativeConfig.create_draft_parallel_config(
                target_parallel_config))

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
        )

    @staticmethod
    def create_draft_parallel_config(
            target_parallel_config: ParallelConfig) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config. In the future the
        draft worker can have a different parallel strategy, e.g. TP=1.

        Args:
            target_parallel_config: ParallelConfig of the target model whose
                fields are copied.

        Returns:
            A new ParallelConfig built from the target's settings.
        """
        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
            worker_use_ray=target_parallel_config.worker_use_ray,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.

        Raises:
            ValueError: If ``num_speculative_tokens`` is not positive.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the stored arguments.

        Raises:
            ValueError: If ``num_speculative_tokens`` is less than or equal
                to zero.
        """
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        # The cross-check is skipped when no draft model config is set.
        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        # The local names are part of the output: the ``=`` specifier in the
        # f-string prints ``name=value``.
        draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
@
dataclass
class
LoRAConfig
:
max_lora_rank
:
int
...
...
@@ -838,3 +991,36 @@ def _get_and_verify_max_len(
"to incorrect model outputs or CUDA errors. Make sure the "
"value is correct and within the model context size."
)
return
int
(
max_model_len
)
@dataclass(frozen=True)
class EngineConfig:
    """Immutable bundle of every engine-related configuration object.

    Grouping the individual configs in one dataclass avoids threading each
    of them separately through the codebase.
    """
    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    speculative_config: Optional[SpeculativeConfig]

    def __post_init__(self):
        """Cross-validate the configs against each other after construction.

        The model and cache configs are checked against the parallel config;
        the LoRA config, when set, is checked against the model and scheduler
        configs.
        """
        parallel = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel)
        self.cache_config.verify_with_parallel_config(parallel)

        lora = self.lora_config
        if not lora:
            return
        lora.verify_with_model_config(self.model_config)
        lora.verify_with_scheduler_config(self.scheduler_config)

    def to_dict(self):
        """Return a field-name -> config mapping, suitable for ``**kwargs``.

        The mapping is shallow: values are the config objects themselves,
        not copies.
        """
        return {
            config_field.name: getattr(self, config_field.name)
            for config_field in fields(self)
        }
vllm/engine/arg_utils.py
View file @
5757d90e
import
argparse
import
dataclasses
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
TokenizerPoolConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TokenizerPoolConfig
,
VisionLanguageConfig
)
from
vllm.utils
import
str_to_int_tuple
...
...
@@ -61,9 +62,14 @@ class EngineArgs:
image_token_id
:
Optional
[
int
]
=
None
image_input_shape
:
Optional
[
str
]
=
None
image_feature_size
:
Optional
[
int
]
=
None
scheduler_delay_factor
:
float
=
0.0
enable_chunked_prefill
:
bool
=
False
# Speculative decoding configuration.
speculative_model
:
Optional
[
str
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
self
.
tokenizer
=
self
.
model
...
...
@@ -371,6 +377,20 @@ class EngineArgs:
default
=
False
,
help
=
'If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens'
)
parser
.
add_argument
(
'--speculative-model'
,
type
=
str
,
default
=
None
,
help
=
'The name of the draft model to be used in speculative decoding.'
)
parser
.
add_argument
(
'--num-speculative-tokens'
,
type
=
int
,
default
=
None
,
help
=
'The number of speculative tokens to sample from '
'the draft model in speculative decoding'
)
return
parser
@
classmethod
...
...
@@ -381,11 +401,7 @@ class EngineArgs:
engine_args
=
cls
(
**
{
attr
:
getattr
(
args
,
attr
)
for
attr
in
attrs
})
return
engine_args
def
create_engine_configs
(
self
,
)
->
Tuple
[
ModelConfig
,
CacheConfig
,
ParallelConfig
,
SchedulerConfig
,
DeviceConfig
,
Optional
[
LoRAConfig
],
Optional
[
VisionLanguageConfig
]]:
def
create_engine_config
(
self
,
)
->
EngineConfig
:
device_config
=
DeviceConfig
(
self
.
device
)
model_config
=
ModelConfig
(
self
.
model
,
self
.
tokenizer
,
self
.
tokenizer_mode
,
...
...
@@ -409,12 +425,23 @@ class EngineArgs:
self
.
tokenizer_pool_type
,
self
.
tokenizer_pool_extra_config
,
),
self
.
ray_workers_use_nsight
)
speculative_config
=
SpeculativeConfig
.
maybe_create_spec_config
(
target_model_config
=
model_config
,
target_parallel_config
=
parallel_config
,
target_dtype
=
self
.
dtype
,
speculative_model
=
self
.
speculative_model
,
num_speculative_tokens
=
self
.
num_speculative_tokens
,
)
scheduler_config
=
SchedulerConfig
(
self
.
max_num_batched_tokens
,
self
.
max_num_seqs
,
model_config
.
max_model_len
,
self
.
use_v2_block_manager
,
num_lookahead_slots
=
self
.
num_lookahead_slots
,
num_lookahead_slots
=
(
self
.
num_lookahead_slots
if
speculative_config
is
None
else
speculative_config
.
num_lookahead_slots
),
delay_factor
=
self
.
scheduler_delay_factor
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
)
...
...
@@ -442,8 +469,14 @@ class EngineArgs:
else
:
vision_language_config
=
None
return
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
lora_config
,
vision_language_config
)
return
EngineConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
vision_language_config
=
vision_language_config
,
speculative_config
=
speculative_config
)
@
dataclass
...
...
vllm/engine/async_llm_engine.py
View file @
5757d90e
...
...
@@ -328,28 +328,27 @@ class AsyncLLMEngine:
)
->
"AsyncLLMEngine"
:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs
=
engine_args
.
create_engine_configs
()
parallel_config
=
engine_configs
[
2
]
device_config
=
engine_configs
[
4
]
engine_config
=
engine_args
.
create_engine_config
()
if
device_config
.
device_type
==
"neuron"
:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
raise
NotImplementedError
(
"Neuron is not supported for "
"async engine yet."
)
elif
parallel_config
.
worker_use_ray
or
engine_args
.
engine_use_ray
:
initialize_ray_cluster
(
parallel_config
)
elif
(
engine_config
.
parallel_config
.
worker_use_ray
or
engine_args
.
engine_use_ray
):
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutorAsync
executor_class
=
RayGPUExecutorAsync
else
:
assert
parallel_config
.
world_size
==
1
,
(
assert
engine_config
.
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
executor_class
=
GPUExecutorAsync
# Create the async LLM engine.
engine
=
cls
(
parallel_config
.
worker_use_ray
,
engine_config
.
parallel_config
.
worker_use_ray
,
engine_args
.
engine_use_ray
,
*
engine_config
s
,
executor_class
,
*
*
engine_config
.
to_dict
()
,
executor_class
=
executor_class
,
log_requests
=
not
engine_args
.
disable_log_requests
,
log_stats
=
not
engine_args
.
disable_log_stats
,
max_log_len
=
engine_args
.
max_log_len
,
...
...
vllm/engine/llm_engine.py
View file @
5757d90e
...
...
@@ -5,7 +5,8 @@ from transformers import PreTrainedTokenizer
import
vllm
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VisionLanguageConfig
)
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics
import
StatLogger
,
Stats
...
...
@@ -52,6 +53,11 @@ class LLMEngine:
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
vision_language_config (Optional): The configuration related to vision
language models.
speculative_config (Optional): The configuration related to speculative
decoding.
executor_class: The model executor class for managing distributed
execution.
log_stats: Whether to log statistics.
...
...
@@ -66,7 +72,8 @@ class LLMEngine:
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
"VisionLanguageConfig"
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
executor_class
:
Type
[
ExecutorBase
],
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
...
...
@@ -74,6 +81,7 @@ class LLMEngine:
logger
.
info
(
f
"Initializing an LLM engine (v
{
vllm
.
__version__
}
) with config: "
f
"model=
{
model_config
.
model
!
r
}
, "
f
"speculative_config=
{
speculative_config
!
r
}
, "
f
"tokenizer=
{
model_config
.
tokenizer
!
r
}
, "
f
"tokenizer_mode=
{
model_config
.
tokenizer_mode
}
, "
f
"revision=
{
model_config
.
revision
}
, "
...
...
@@ -100,17 +108,23 @@ class LLMEngine:
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
speculative_config
=
speculative_config
self
.
log_stats
=
log_stats
self
.
_verify_args
()
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
self
.
seq_counter
=
Counter
()
self
.
model_executor
=
executor_class
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
lora_config
,
vision_language_config
)
self
.
model_executor
=
executor_class
(
model_config
=
model_config
,
cache_config
=
cache_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
vision_language_config
=
vision_language_config
,
speculative_config
=
speculative_config
,
)
# If usage stat is enabled, collect relevant info.
if
is_usage_stats_enabled
():
...
...
@@ -171,30 +185,28 @@ class LLMEngine:
)
->
"LLMEngine"
:
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs
=
engine_args
.
create_engine_configs
()
parallel_config
=
engine_configs
[
2
]
device_config
=
engine_configs
[
4
]
engine_config
=
engine_args
.
create_engine_config
()
# Initialize the cluster and specify the executor class.
if
device_config
.
device_type
==
"neuron"
:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutor
executor_class
=
NeuronExecutor
elif
device_config
.
device_type
==
"cpu"
:
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
from
vllm.executor.cpu_executor
import
CPUExecutor
executor_class
=
CPUExecutor
elif
parallel_config
.
worker_use_ray
:
initialize_ray_cluster
(
parallel_config
)
elif
engine_config
.
parallel_config
.
worker_use_ray
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutor
executor_class
=
RayGPUExecutor
else
:
assert
parallel_config
.
world_size
==
1
,
(
assert
engine_config
.
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
from
vllm.executor.gpu_executor
import
GPUExecutor
executor_class
=
GPUExecutor
# Create the LLM engine.
engine
=
cls
(
*
engine_config
s
,
*
*
engine_config
.
to_dict
()
,
executor_class
=
executor_class
,
log_stats
=
not
engine_args
.
disable_log_stats
,
usage_context
=
usage_context
,
...
...
vllm/executor/executor_base.py
View file @
5757d90e
...
...
@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
from
typing
import
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VisionLanguageConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
...
...
@@ -25,6 +26,7 @@ class ExecutorBase(ABC):
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
)
->
None
:
raise
NotImplementedError
...
...
vllm/executor/gpu_executor.py
View file @
5757d90e
from
typing
import
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VisionLanguageConfig
)
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.utils
import
check_block_size_valid
from
vllm.logger
import
init_logger
...
...
@@ -24,6 +25,7 @@ class GPUExecutor(ExecutorBase):
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
)
->
None
:
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
...
...
@@ -33,6 +35,9 @@ class GPUExecutor(ExecutorBase):
self
.
device_config
=
device_config
self
.
vision_language_config
=
vision_language_config
assert
(
not
speculative_config
),
"Speculative decoding not yet supported for GPU backend"
# Instantiate the worker and load the model to GPU.
self
.
_init_worker
()
...
...
vllm/executor/neuron_executor.py
View file @
5757d90e
from
typing
import
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VisionLanguageConfig
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
...
...
@@ -21,6 +22,7 @@ class NeuronExecutor(ExecutorBase):
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
)
->
None
:
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
...
...
@@ -28,6 +30,8 @@ class NeuronExecutor(ExecutorBase):
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
assert
(
not
speculative_config
),
"Speculative decoding not yet supported for Neuron backend."
# Set the number of GPU blocks to be the same as the maximum number of
# sequences that can be processed in a single batch. This is equivalent
...
...
vllm/executor/ray_gpu_executor.py
View file @
5757d90e
...
...
@@ -6,7 +6,8 @@ from collections import defaultdict
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VisionLanguageConfig
)
from
vllm.engine.ray_utils
import
RayWorkerVllm
,
ray
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.utils
import
check_block_size_valid
...
...
@@ -41,6 +42,7 @@ class RayGPUExecutor(ExecutorBase):
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
)
->
None
:
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
...
...
@@ -49,6 +51,8 @@ class RayGPUExecutor(ExecutorBase):
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
vision_language_config
=
vision_language_config
assert
(
not
speculative_config
),
"Speculative decoding not yet supported for RayGPU backend."
assert
self
.
parallel_config
.
worker_use_ray
placement_group
=
self
.
parallel_config
.
placement_group
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment