Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bbf55c48
Unverified
Commit
bbf55c48
authored
Aug 17, 2024
by
Roger Wang
Committed by
GitHub
Aug 17, 2024
Browse files
[VLM] Refactor `MultiModalConfig` initialization and profiling (#7530)
parent
1ef13cf9
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
34 additions
and
61 deletions
+34
-61
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+2
-5
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+2
-4
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+9
-14
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+9
-14
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+1
-4
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+1
-4
vllm/worker/utils.py
vllm/worker/utils.py
+1
-1
vllm/worker/worker.py
vllm/worker/worker.py
+2
-5
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+7
-10
No files found.
vllm/worker/cpu_worker.py
View file @
bbf55c48
...
@@ -7,8 +7,8 @@ import torch.distributed
...
@@ -7,8 +7,8 @@ import torch.distributed
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
get_attn_backend
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModa
lConfig
,
P
arallel
Config
,
ModelConfig
,
Paralle
lConfig
,
P
romptAdapter
Config
,
PromptAdapterConfig
,
SchedulerConfig
)
SchedulerConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -132,7 +132,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -132,7 +132,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
distributed_init_method
:
str
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
...
@@ -148,7 +147,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -148,7 +147,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self
.
distributed_init_method
=
distributed_init_method
self
.
distributed_init_method
=
distributed_init_method
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
multimodal_config
=
multimodal_config
self
.
is_driver_worker
=
is_driver_worker
self
.
is_driver_worker
=
is_driver_worker
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
assert
self
.
rank
==
0
,
"The driver worker must have rank 0."
assert
self
.
rank
==
0
,
"The driver worker must have rank 0."
...
@@ -173,7 +171,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -173,7 +171,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
cache_config
,
cache_config
,
load_config
=
self
.
load_config
,
load_config
=
self
.
load_config
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
multimodal_config
=
self
.
multimodal_config
,
kv_cache_dtype
=
kv_cache_dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
prompt_adapter_config
=
self
.
prompt_adapter_config
,
prompt_adapter_config
=
self
.
prompt_adapter_config
,
is_driver_worker
=
is_driver_worker
)
is_driver_worker
=
is_driver_worker
)
...
...
vllm/worker/embedding_model_runner.py
View file @
bbf55c48
...
@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
...
@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import
torch
import
torch
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MultiModalInputs
from
vllm.multimodal
import
MultiModalInputs
...
@@ -44,7 +44,6 @@ class EmbeddingModelRunner(
...
@@ -44,7 +44,6 @@ class EmbeddingModelRunner(
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
):
):
super
().
__init__
(
model_config
,
super
().
__init__
(
model_config
,
...
@@ -57,7 +56,6 @@ class EmbeddingModelRunner(
...
@@ -57,7 +56,6 @@ class EmbeddingModelRunner(
kv_cache_dtype
=
kv_cache_dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
,
is_driver_worker
=
is_driver_worker
,
prompt_adapter_config
=
prompt_adapter_config
,
prompt_adapter_config
=
prompt_adapter_config
,
multimodal_config
=
multimodal_config
,
observability_config
=
observability_config
)
observability_config
=
observability_config
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
vllm/worker/enc_dec_model_runner.py
View file @
bbf55c48
...
@@ -10,8 +10,8 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
...
@@ -10,8 +10,8 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend
,
get_global_forced_attn_backend
,
global_force_attn_backend
)
global_force_attn_backend
)
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
...
@@ -82,7 +82,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -82,7 +82,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
...
@@ -90,7 +89,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -90,7 +89,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
'''
'''
EncoderDecoderModelRunner constructor.
EncoderDecoderModelRunner constructor.
`lora_config`
, `multimodal_config`,
and prompt_adapter_config are
`lora_config` and
`
prompt_adapter_config
`
are
unused (since these features are not yet supported for encoder/decoder
unused (since these features are not yet supported for encoder/decoder
models) but these arguments are present here for compatibility with
models) but these arguments are present here for compatibility with
the base-class constructor.
the base-class constructor.
...
@@ -273,14 +272,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -273,14 +272,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# number of tokens equal to max_num_batched_tokens.
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
model_config
=
self
.
model_config
max_mm_tokens
=
self
.
mm_registry
.
get_max_multimodal_tokens
(
mm_config
=
self
.
multimodal_config
self
.
model_config
)
input_registry
=
self
.
input_registry
mm_registry
=
self
.
mm_registry
mm_registry
.
init_mm_limits_per_prompt
(
model_config
,
mm_config
)
max_mm_tokens
=
mm_registry
.
get_max_multimodal_tokens
(
model_config
)
if
max_mm_tokens
>
0
:
if
max_mm_tokens
>
0
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Multi-modal encoder-decoder models are not supported yet"
)
"Multi-modal encoder-decoder models are not supported yet"
)
...
@@ -291,8 +284,10 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -291,8 +284,10 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
batch_size
+=
seq_len
batch_size
+=
seq_len
seq_data
,
_
=
input_registry
\
seq_data
,
_
=
self
.
input_registry
\
.
dummy_data_for_profiling
(
model_config
,
seq_len
,
mm_registry
)
.
dummy_data_for_profiling
(
self
.
model_config
,
seq_len
,
self
.
mm_registry
)
# Having more tokens is over-conservative but otherwise fine
# Having more tokens is over-conservative but otherwise fine
assert
len
(
seq_data
.
prompt_token_ids
)
>=
seq_len
,
(
assert
len
(
seq_data
.
prompt_token_ids
)
>=
seq_len
,
(
...
...
vllm/worker/model_runner.py
View file @
bbf55c48
...
@@ -27,8 +27,8 @@ except ImportError:
...
@@ -27,8 +27,8 @@ except ImportError:
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
...
@@ -804,7 +804,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -804,7 +804,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
return_hidden_states
:
bool
=
False
,
return_hidden_states
:
bool
=
False
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
...
@@ -819,7 +818,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -819,7 +818,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
load_config
=
load_config
self
.
load_config
=
load_config
self
.
is_driver_worker
=
is_driver_worker
self
.
is_driver_worker
=
is_driver_worker
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
multimodal_config
=
multimodal_config
self
.
return_hidden_states
=
return_hidden_states
self
.
return_hidden_states
=
return_hidden_states
self
.
observability_config
=
observability_config
self
.
observability_config
=
observability_config
...
@@ -866,6 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -866,6 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
mm_registry
=
mm_registry
self
.
mm_registry
=
mm_registry
self
.
multi_modal_input_mapper
=
mm_registry
\
self
.
multi_modal_input_mapper
=
mm_registry
\
.
create_input_mapper
(
model_config
)
.
create_input_mapper
(
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
self
.
model_config
)
# Lazy initialization
# Lazy initialization
self
.
model
:
nn
.
Module
# Set after load_model
self
.
model
:
nn
.
Module
# Set after load_model
...
@@ -893,7 +892,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -893,7 +892,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
device_config
=
self
.
device_config
,
device_config
=
self
.
device_config
,
load_config
=
self
.
load_config
,
load_config
=
self
.
load_config
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
multimodal_config
=
self
.
multimodal_config
,
parallel_config
=
self
.
parallel_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
scheduler_config
=
self
.
scheduler_config
,
cache_config
=
self
.
cache_config
)
cache_config
=
self
.
cache_config
)
...
@@ -1056,14 +1054,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1056,14 +1054,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# To exercise the worst scenario for GPU memory consumption,
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
# of images processed.
model_config
=
self
.
model_config
mm_config
=
self
.
multimodal_config
input_registry
=
self
.
input_registry
max_mm_tokens
=
self
.
mm_registry
.
get_max_multimodal_tokens
(
mm_registry
=
self
.
mm_registry
self
.
model_config
)
mm_registry
.
init_mm_limits_per_prompt
(
model_config
,
mm_config
)
max_mm_tokens
=
mm_registry
.
get_max_multimodal_tokens
(
model_config
)
if
max_mm_tokens
>
0
:
if
max_mm_tokens
>
0
:
max_num_seqs_orig
=
max_num_seqs
max_num_seqs_orig
=
max_num_seqs
max_num_seqs
=
min
(
max_num_seqs
,
max_num_seqs
=
min
(
max_num_seqs
,
...
@@ -1082,8 +1075,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1082,8 +1075,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
batch_size
+=
seq_len
batch_size
+=
seq_len
seq_data
,
dummy_multi_modal_data
=
input_registry
\
seq_data
,
dummy_multi_modal_data
=
self
.
input_registry
\
.
dummy_data_for_profiling
(
model_config
,
seq_len
,
mm_registry
)
.
dummy_data_for_profiling
(
self
.
model_config
,
seq_len
,
self
.
mm_registry
)
seq
=
SequenceGroupMetadata
(
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
request_id
=
str
(
group_id
),
...
...
vllm/worker/tpu_model_runner.py
View file @
bbf55c48
...
@@ -11,7 +11,7 @@ import torch_xla.runtime as xr
...
@@ -11,7 +11,7 @@ import torch_xla.runtime as xr
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -89,7 +89,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -89,7 +89,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
load_config
:
LoadConfig
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
):
):
self
.
model_config
=
model_config
self
.
model_config
=
model_config
...
@@ -98,7 +97,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -98,7 +97,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
self
.
load_config
=
load_config
self
.
load_config
=
load_config
self
.
multimodal_config
=
multimodal_config
self
.
is_driver_worker
=
is_driver_worker
self
.
is_driver_worker
=
is_driver_worker
self
.
block_size
=
self
.
cache_config
.
block_size
self
.
block_size
=
self
.
cache_config
.
block_size
...
@@ -142,7 +140,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -142,7 +140,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
parallel_config
=
self
.
parallel_config
,
parallel_config
=
self
.
parallel_config
,
cache_config
=
self
.
cache_config
,
cache_config
=
self
.
cache_config
,
scheduler_config
=
self
.
scheduler_config
,
scheduler_config
=
self
.
scheduler_config
,
multimodal_config
=
self
.
multimodal_config
,
lora_config
=
None
,
lora_config
=
None
,
)
)
model
=
model
.
eval
()
model
=
model
.
eval
()
...
...
vllm/worker/tpu_worker.py
View file @
bbf55c48
...
@@ -7,7 +7,7 @@ import torch_xla.runtime as xr
...
@@ -7,7 +7,7 @@ import torch_xla.runtime as xr
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -31,7 +31,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -31,7 +31,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
load_config
:
LoadConfig
,
multimodal_config
:
Optional
[
MultiModalConfig
],
local_rank
:
int
,
local_rank
:
int
,
rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
distributed_init_method
:
str
,
...
@@ -44,7 +43,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -44,7 +43,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
self
.
load_config
=
load_config
self
.
load_config
=
load_config
self
.
multimodal_config
=
multimodal_config
self
.
local_rank
=
local_rank
self
.
local_rank
=
local_rank
self
.
rank
=
rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
distributed_init_method
=
distributed_init_method
...
@@ -64,7 +62,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -64,7 +62,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
device_config
,
device_config
,
cache_config
,
cache_config
,
load_config
,
load_config
,
multimodal_config
,
is_driver_worker
=
is_driver_worker
)
is_driver_worker
=
is_driver_worker
)
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
...
...
vllm/worker/utils.py
View file @
bbf55c48
...
@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
...
@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
raise
NotImplementedError
(
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_PP'
])
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_PP'
])
if
enc_dec_mr
.
multimodal_config
is
not
None
:
if
enc_dec_mr
.
model_config
.
multimodal_config
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_MM'
])
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_MM'
])
...
...
vllm/worker/worker.py
View file @
bbf55c48
...
@@ -7,8 +7,8 @@ import torch
...
@@ -7,8 +7,8 @@ import torch
import
torch.distributed
import
torch.distributed
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
SpeculativeConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
,
init_distributed_environment
,
...
@@ -46,7 +46,6 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -46,7 +46,6 @@ class Worker(LocalOrDistributedWorkerBase):
rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
distributed_init_method
:
str
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
speculative_config
:
Optional
[
SpeculativeConfig
]
=
None
,
speculative_config
:
Optional
[
SpeculativeConfig
]
=
None
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
...
@@ -73,7 +72,6 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -73,7 +72,6 @@ class Worker(LocalOrDistributedWorkerBase):
# note: lazy import to avoid importing torch before initializing
# note: lazy import to avoid importing torch before initializing
from
vllm.utils
import
init_cached_hf_modules
from
vllm.utils
import
init_cached_hf_modules
init_cached_hf_modules
()
init_cached_hf_modules
()
self
.
multimodal_config
=
multimodal_config
self
.
observability_config
=
observability_config
self
.
observability_config
=
observability_config
# Return hidden states from target model if the draft model is an
# Return hidden states from target model if the draft model is an
...
@@ -103,7 +101,6 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -103,7 +101,6 @@ class Worker(LocalOrDistributedWorkerBase):
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
is_driver_worker
,
is_driver_worker
=
is_driver_worker
,
prompt_adapter_config
=
prompt_adapter_config
,
prompt_adapter_config
=
prompt_adapter_config
,
multimodal_config
=
multimodal_config
,
observability_config
=
observability_config
,
observability_config
=
observability_config
,
**
speculative_args
,
**
speculative_args
,
)
)
...
...
vllm/worker/xpu_model_runner.py
View file @
bbf55c48
...
@@ -125,6 +125,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
...
@@ -125,6 +125,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self
.
mm_registry
=
mm_registry
self
.
mm_registry
=
mm_registry
self
.
multi_modal_input_mapper
=
mm_registry
\
self
.
multi_modal_input_mapper
=
mm_registry
\
.
create_input_mapper
(
model_config
)
.
create_input_mapper
(
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
self
.
model_config
)
# Lazy initialization.
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
model
:
nn
.
Module
# Set after init_Model
...
@@ -166,14 +167,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
...
@@ -166,14 +167,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# To exercise the worst scenario for GPU memory consumption,
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
# of images processed.
model_config
=
self
.
model_config
max_mm_tokens
=
self
.
mm_registry
.
get_max_multimodal_tokens
(
mm_config
=
self
.
multimodal_config
self
.
model_config
)
input_registry
=
self
.
input_registry
mm_registry
=
self
.
mm_registry
mm_registry
.
init_mm_limits_per_prompt
(
model_config
,
mm_config
)
max_mm_tokens
=
mm_registry
.
get_max_multimodal_tokens
(
model_config
)
if
max_mm_tokens
>
0
:
if
max_mm_tokens
>
0
:
max_num_seqs_orig
=
max_num_seqs
max_num_seqs_orig
=
max_num_seqs
max_num_seqs
=
min
(
max_num_seqs
,
max_num_seqs
=
min
(
max_num_seqs
,
...
@@ -190,8 +185,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
...
@@ -190,8 +185,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
seq_data
,
dummy_multi_modal_data
=
input_registry
\
seq_data
,
dummy_multi_modal_data
=
self
.
input_registry
\
.
dummy_data_for_profiling
(
model_config
,
seq_len
,
mm_registry
)
.
dummy_data_for_profiling
(
self
.
model_config
,
seq_len
,
self
.
mm_registry
)
seq
=
SequenceGroupMetadata
(
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
request_id
=
str
(
group_id
),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment