Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
18bd7587
Unverified
Commit
18bd7587
authored
Nov 01, 2024
by
youkaichao
Committed by
GitHub
Nov 01, 2024
Browse files
[1/N] pass the complete config from engine to executor (#9933)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
598b6d7b
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
65 additions
and
137 deletions
+65
-137
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+21
-29
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+1
-6
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+13
-24
vllm/executor/xpu_executor.py
vllm/executor/xpu_executor.py
+8
-36
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+21
-41
No files found.
vllm/engine/async_llm_engine.py
View file @
18bd7587
...
...
@@ -680,7 +680,7 @@ class AsyncLLMEngine(EngineClient):
# Create the async LLM engine.
engine
=
cls
(
**
engine_config
.
to_dict
()
,
vllm_config
=
engine_config
,
executor_class
=
executor_class
,
log_requests
=
not
engine_args
.
disable_log_requests
,
log_stats
=
not
engine_args
.
disable_log_stats
,
...
...
vllm/engine/llm_engine.py
View file @
18bd7587
...
...
@@ -13,11 +13,8 @@ import torch
from
typing_extensions
import
TypeIs
,
TypeVar
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.core.scheduler
import
(
ScheduledSequenceGroup
,
Scheduler
,
SchedulerOutputs
)
from
vllm.engine.arg_utils
import
EngineArgs
...
...
@@ -222,17 +219,7 @@ class LLMEngine:
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
decoding_config
:
Optional
[
DecodingConfig
],
observability_config
:
Optional
[
ObservabilityConfig
],
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
],
vllm_config
:
EngineConfig
,
executor_class
:
Type
[
ExecutorBase
],
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
...
...
@@ -240,6 +227,22 @@ class LLMEngine:
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
)
->
None
:
# TODO: remove the local variables and use self.* throughout the class.
model_config
=
self
.
model_config
=
vllm_config
.
model_config
cache_config
=
self
.
cache_config
=
vllm_config
.
cache_config
lora_config
=
self
.
lora_config
=
vllm_config
.
lora_config
parallel_config
=
self
.
parallel_config
=
vllm_config
.
parallel_config
scheduler_config
=
self
.
scheduler_config
=
vllm_config
.
scheduler_config
device_config
=
self
.
device_config
=
vllm_config
.
device_config
speculative_config
=
self
.
speculative_config
=
vllm_config
.
speculative_config
# noqa
load_config
=
self
.
load_config
=
vllm_config
.
load_config
decoding_config
=
self
.
decoding_config
=
vllm_config
.
decoding_config
or
DecodingConfig
(
# noqa
)
prompt_adapter_config
=
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
# noqa
observability_config
=
self
.
observability_config
=
vllm_config
.
observability_config
or
ObservabilityConfig
(
# noqa
)
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
...
...
@@ -340,18 +343,7 @@ class LLMEngine:
self
.
input_processor
=
input_registry
.
create_input_processor
(
model_config
)
self
.
model_executor
=
executor_class
(
model_config
=
model_config
,
cache_config
=
cache_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
speculative_config
=
speculative_config
,
load_config
=
load_config
,
prompt_adapter_config
=
prompt_adapter_config
,
observability_config
=
self
.
observability_config
,
)
self
.
model_executor
=
executor_class
(
vllm_config
=
vllm_config
,
)
if
self
.
model_config
.
task
!=
"embedding"
:
self
.
_initialize_kv_caches
()
...
...
@@ -582,7 +574,7 @@ class LLMEngine:
executor_class
=
cls
.
_get_executor_cls
(
engine_config
)
# Create the LLM engine.
engine
=
cls
(
**
engine_config
.
to_dict
()
,
vllm_config
=
engine_config
,
executor_class
=
executor_class
,
log_stats
=
not
engine_args
.
disable_log_stats
,
usage_context
=
usage_context
,
...
...
vllm/engine/multiprocessing/engine.py
View file @
18bd7587
...
...
@@ -7,8 +7,6 @@ import cloudpickle
import
zmq
from
vllm
import
AsyncEngineArgs
,
SamplingParams
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.engine.multiprocessing
import
(
ENGINE_DEAD_ERROR
,
IPC_DATA_EXT
,
...
...
@@ -30,9 +28,6 @@ if VLLM_USE_V1:
else
:
from
vllm.engine.llm_engine
import
LLMEngine
CONFIG_TYPE
=
Union
[
ModelConfig
,
DecodingConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
]
logger
=
init_logger
(
__name__
)
POLLING_TIMEOUT_MS
=
10000
...
...
@@ -130,7 +125,7 @@ class MQLLMEngine:
return
cls
(
ipc_path
=
ipc_path
,
use_async_sockets
=
use_async_sockets
,
**
engine_config
.
to_dict
()
,
vllm_config
=
engine_config
,
executor_class
=
executor_class
,
log_requests
=
not
engine_args
.
disable_log_requests
,
log_stats
=
not
engine_args
.
disable_log_stats
,
...
...
vllm/executor/executor_base.py
View file @
18bd7587
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Set
,
Tuple
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.config
import
EngineConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
...
@@ -23,27 +20,19 @@ class ExecutorBase(ABC):
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
],
observability_config
:
Optional
[
ObservabilityConfig
],
vllm_config
:
EngineConfig
,
)
->
None
:
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
load_config
=
load_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
speculative_config
=
speculative_config
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
observability_config
=
observability_config
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
lora_config
=
vllm_config
.
lora_config
self
.
load_config
=
vllm_config
.
load_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
device_config
=
vllm_config
.
device_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
_init_executor
()
@
abstractmethod
...
...
vllm/executor/xpu_executor.py
View file @
18bd7587
...
...
@@ -2,10 +2,7 @@ from typing import Callable, List, Optional, Tuple, Type, Union
import
torch
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.config
import
ModelConfig
,
ParallelConfig
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
...
...
@@ -21,38 +18,13 @@ class XPUExecutor(GPUExecutor):
uses_ray
:
bool
=
False
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
observability_config
:
Optional
[
ObservabilityConfig
],
)
->
None
:
assert
device_config
.
device_type
==
"xpu"
assert
(
not
speculative_config
),
"Speculative decoding not yet supported for XPU backend"
model_config
=
_verify_and_get_model_config
(
model_config
)
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
load_config
=
load_config
self
.
lora_config
=
lora_config
self
.
parallel_config
=
_verify_and_get_parallel_config
(
parallel_config
)
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
speculative_config
=
None
self
.
observability_config
=
observability_config
# Instantiate the worker and load the model to GPU.
self
.
_init_executor
()
def
_init_executor
(
self
)
->
None
:
assert
self
.
device_config
.
device_type
==
"xpu"
assert
self
.
speculative_config
is
None
,
(
"Speculative decoding not yet supported for XPU backend"
)
self
.
model_config
=
_verify_and_get_model_config
(
self
.
model_config
)
GPUExecutor
.
_init_executor
(
self
)
def
_get_worker_module_and_class
(
self
)
->
Tuple
[
str
,
str
,
Optional
[
Callable
[[],
Type
[
WorkerBase
]]]]:
...
...
vllm/v1/engine/llm_engine.py
View file @
18bd7587
...
...
@@ -2,11 +2,8 @@ import time
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
...
...
@@ -35,17 +32,7 @@ class LLMEngine:
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
decoding_config
:
Optional
[
DecodingConfig
],
observability_config
:
Optional
[
ObservabilityConfig
],
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
],
vllm_config
:
EngineConfig
,
executor_class
:
Type
[
GPUExecutor
],
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
...
...
@@ -53,6 +40,22 @@ class LLMEngine:
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
)
->
None
:
# TODO: remove the local variables and use self.* throughout the class.
model_config
=
self
.
model_config
=
vllm_config
.
model_config
cache_config
=
self
.
cache_config
=
vllm_config
.
cache_config
lora_config
=
self
.
lora_config
=
vllm_config
.
lora_config
parallel_config
=
self
.
parallel_config
=
vllm_config
.
parallel_config
scheduler_config
=
self
.
scheduler_config
=
vllm_config
.
scheduler_config
device_config
=
self
.
device_config
=
vllm_config
.
device_config
speculative_config
=
self
.
speculative_config
=
vllm_config
.
speculative_config
# noqa
load_config
=
self
.
load_config
=
vllm_config
.
load_config
decoding_config
=
self
.
decoding_config
=
vllm_config
.
decoding_config
or
DecodingConfig
(
# noqa
)
prompt_adapter_config
=
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
# noqa
observability_config
=
self
.
observability_config
=
vllm_config
.
observability_config
or
ObservabilityConfig
(
# noqa
)
# Override the configs for V1.
# FIXME
if
usage_context
==
UsageContext
.
LLM_CLASS
:
...
...
@@ -112,18 +115,6 @@ class LLMEngine:
model_config
.
mm_processor_kwargs
,
)
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
speculative_config
=
speculative_config
self
.
load_config
=
load_config
self
.
decoding_config
=
decoding_config
or
DecodingConfig
()
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
)
self
.
log_stats
=
log_stats
assert
not
self
.
model_config
.
skip_tokenizer_init
...
...
@@ -154,18 +145,7 @@ class LLMEngine:
# Request id -> RequestOutput
self
.
request_outputs
:
Dict
[
str
,
RequestOutput
]
=
{}
self
.
model_executor
=
executor_class
(
model_config
=
model_config
,
cache_config
=
cache_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
speculative_config
=
speculative_config
,
load_config
=
load_config
,
prompt_adapter_config
=
prompt_adapter_config
,
observability_config
=
self
.
observability_config
,
)
self
.
model_executor
=
executor_class
(
vllm_config
=
vllm_config
)
assert
self
.
model_config
.
task
!=
"embedding"
self
.
_initialize_kv_caches
()
...
...
@@ -203,7 +183,7 @@ class LLMEngine:
executor_class
=
cls
.
_get_executor_cls
(
engine_config
)
# Create the LLM engine.
engine
=
cls
(
**
engine_config
.
to_dict
()
,
vllm_config
=
engine_config
,
executor_class
=
executor_class
,
log_stats
=
not
engine_args
.
disable_log_stats
,
usage_context
=
usage_context
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment