Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db7db4aa
Unverified
Commit
db7db4aa
authored
Nov 07, 2024
by
Cyrus Leung
Committed by
GitHub
Nov 07, 2024
Browse files
[Misc] Consolidate ModelConfig code related to HF config (#10104)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
1fa020c5
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
68 additions
and
43 deletions
+68
-43
docs/source/serving/compatibility_matrix.rst
docs/source/serving/compatibility_matrix.rst
+1
-1
tests/test_config.py
tests/test_config.py
+38
-0
vllm/config.py
vllm/config.py
+8
-6
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+1
-1
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+9
-0
vllm/utils.py
vllm/utils.py
+0
-4
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+1
-8
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+1
-4
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+8
-15
vllm/worker/worker.py
vllm/worker/worker.py
+1
-4
No files found.
docs/source/serving/compatibility_matrix.rst
View file @
db7db4aa
...
@@ -359,7 +359,7 @@ Feature x Hardware
...
@@ -359,7 +359,7 @@ Feature x Hardware
- ✅
- ✅
- ✅
- ✅
- ✅
- ✅
-
`✗ <https://github.com/vllm-project/vllm/blob/a84e598e2125960d3b4f716b78863f24ac562947/vllm/worker/cpu_model_runner.py#L125>`__
-
✅
- ✗
- ✗
* - :abbr:`logP (Logprobs)`
* - :abbr:`logP (Logprobs)`
- ✅
- ✅
...
...
tests/test_config.py
View file @
db7db4aa
...
@@ -165,3 +165,41 @@ def test_rope_customization():
...
@@ -165,3 +165,41 @@ def test_rope_customization():
assert
getattr
(
longchat_model_config
.
hf_config
,
"rope_scaling"
,
assert
getattr
(
longchat_model_config
.
hf_config
,
"rope_scaling"
,
None
)
==
TEST_ROPE_SCALING
None
)
==
TEST_ROPE_SCALING
assert
longchat_model_config
.
max_model_len
==
4096
assert
longchat_model_config
.
max_model_len
==
4096
@
pytest
.
mark
.
parametrize
((
"model_id"
,
"is_encoder_decoder"
),
[
(
"facebook/opt-125m"
,
False
),
(
"facebook/bart-base"
,
True
),
(
"meta-llama/Llama-3.2-1B"
,
False
),
(
"meta-llama/Llama-3.2-11B-Vision"
,
True
),
])
def
test_is_encoder_decoder
(
model_id
,
is_encoder_decoder
):
config
=
ModelConfig
(
model_id
,
task
=
"auto"
,
tokenizer
=
model_id
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
dtype
=
"float16"
,
seed
=
0
,
)
assert
config
.
is_encoder_decoder
==
is_encoder_decoder
@
pytest
.
mark
.
parametrize
((
"model_id"
,
"uses_mrope"
),
[
(
"facebook/opt-125m"
,
False
),
(
"Qwen/Qwen2-VL-2B-Instruct"
,
True
),
])
def
test_uses_mrope
(
model_id
,
uses_mrope
):
config
=
ModelConfig
(
model_id
,
task
=
"auto"
,
tokenizer
=
model_id
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
dtype
=
"float16"
,
seed
=
0
,
)
assert
config
.
uses_mrope
==
uses_mrope
vllm/config.py
View file @
db7db4aa
...
@@ -15,7 +15,8 @@ from vllm.platforms import current_platform
...
@@ -15,7 +15,8 @@ from vllm.platforms import current_platform
from
vllm.tracing
import
is_otel_available
,
otel_import_error_traceback
from
vllm.tracing
import
is_otel_available
,
otel_import_error_traceback
from
vllm.transformers_utils.config
import
(
ConfigFormat
,
get_config
,
from
vllm.transformers_utils.config
import
(
ConfigFormat
,
get_config
,
get_hf_image_processor_config
,
get_hf_image_processor_config
,
get_hf_text_config
)
get_hf_text_config
,
is_encoder_decoder
,
uses_mrope
)
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
print_warning_once
)
print_warning_once
)
...
@@ -667,12 +668,13 @@ class ModelConfig:
...
@@ -667,12 +668,13 @@ class ModelConfig:
return
self
.
multimodal_config
return
self
.
multimodal_config
@
property
@
property
def
is_encoder_decoder
_model
(
self
)
->
bool
:
def
is_encoder_decoder
(
self
)
->
bool
:
"""Extract the HF encoder/decoder model flag."""
"""Extract the HF encoder/decoder model flag."""
return
getattr
(
return
is_encoder_decoder
(
self
.
hf_config
)
self
.
hf_config
,
"is_encoder_decoder"
,
False
)
or
(
hasattr
(
self
.
hf_config
,
"text_config"
)
and
getattr
(
@
property
self
.
hf_config
.
text_config
,
"is_encoder_decoder"
,
False
))
def
uses_mrope
(
self
)
->
bool
:
return
uses_mrope
(
self
.
hf_config
)
@
property
@
property
def
is_multimodal_model
(
self
)
->
bool
:
def
is_multimodal_model
(
self
)
->
bool
:
...
...
vllm/inputs/preprocess.py
View file @
db7db4aa
...
@@ -580,4 +580,4 @@ class InputPreprocessor:
...
@@ -580,4 +580,4 @@ class InputPreprocessor:
)
)
def
is_encoder_decoder_model
(
self
):
def
is_encoder_decoder_model
(
self
):
return
self
.
model_config
.
is_encoder_decoder
_model
return
self
.
model_config
.
is_encoder_decoder
vllm/transformers_utils/config.py
View file @
db7db4aa
...
@@ -129,6 +129,15 @@ def uses_mrope(config: PretrainedConfig) -> bool:
...
@@ -129,6 +129,15 @@ def uses_mrope(config: PretrainedConfig) -> bool:
return
"mrope_section"
in
rope_scaling
return
"mrope_section"
in
rope_scaling
def
is_encoder_decoder
(
config
:
PretrainedConfig
)
->
bool
:
"""Detect if the model with this config is used as an encoder/decoder."""
text_config
=
getattr
(
config
,
"text_config"
,
None
)
if
text_config
is
not
None
:
return
is_encoder_decoder
(
text_config
)
return
getattr
(
config
,
"is_encoder_decoder"
,
False
)
def
get_config
(
def
get_config
(
model
:
Union
[
str
,
Path
],
model
:
Union
[
str
,
Path
],
trust_remote_code
:
bool
,
trust_remote_code
:
bool
,
...
...
vllm/utils.py
View file @
db7db4aa
...
@@ -88,9 +88,6 @@ STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
...
@@ -88,9 +88,6 @@ STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
"currently supported with encoder/"
"currently supported with encoder/"
"decoder models."
)
"decoder models."
)
STR_NOT_IMPL_ENC_DEC_CPU
=
(
"CPU is not currently supported with "
"encoder/decoder models."
)
# Efficiently import all enc/dec error strings
# Efficiently import all enc/dec error strings
# rather than having to import all of the above
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS
=
{
STR_NOT_IMPL_ENC_DEC_ERR_STRS
=
{
...
@@ -105,7 +102,6 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
...
@@ -105,7 +102,6 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC"
:
STR_NOT_IMPL_ENC_DEC_SPEC_DEC
,
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC"
:
STR_NOT_IMPL_ENC_DEC_SPEC_DEC
,
"STR_NOT_IMPL_ENC_DEC_BACKEND"
:
STR_NOT_IMPL_ENC_DEC_BACKEND
,
"STR_NOT_IMPL_ENC_DEC_BACKEND"
:
STR_NOT_IMPL_ENC_DEC_BACKEND
,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER"
:
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER
,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER"
:
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER
,
"STR_NOT_IMPL_ENC_DEC_CPU"
:
STR_NOT_IMPL_ENC_DEC_CPU
}
}
# Constants related to forcing the attention backend selection
# Constants related to forcing the attention backend selection
...
...
vllm/worker/cpu_model_runner.py
View file @
db7db4aa
...
@@ -18,7 +18,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
...
@@ -18,7 +18,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs
,
MultiModalPlaceholderMap
)
MultiModalInputs
,
MultiModalPlaceholderMap
)
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
...
@@ -163,7 +162,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
...
@@ -163,7 +162,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
# special processing for mrope position deltas.
# special processing for mrope position deltas.
mrope_positions
=
None
mrope_positions
=
None
if
self
.
runner
.
model_
i
s_mrope
:
if
self
.
runner
.
model_
config
.
use
s_mrope
:
image_grid_thw
=
mm_kwargs
.
get
(
"image_grid_thw"
,
None
)
image_grid_thw
=
mm_kwargs
.
get
(
"image_grid_thw"
,
None
)
video_grid_thw
=
mm_kwargs
.
get
(
"video_grid_thw"
,
None
)
video_grid_thw
=
mm_kwargs
.
get
(
"video_grid_thw"
,
None
)
assert
image_grid_thw
is
not
None
or
video_grid_thw
is
not
None
,
(
assert
image_grid_thw
is
not
None
or
video_grid_thw
is
not
None
,
(
...
@@ -446,12 +445,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
...
@@ -446,12 +445,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
# Lazy initialization.
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
model
:
nn
.
Module
# Set after init_Model
@
property
def
model_is_mrope
(
self
)
->
bool
:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
return
uses_mrope
(
self
.
model_config
.
hf_config
)
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
...
...
vllm/worker/cpu_worker.py
View file @
db7db4aa
...
@@ -151,7 +151,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -151,7 +151,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self
.
local_omp_cpuid
=
omp_cpuids
.
split
(
"|"
)[
rank
]
self
.
local_omp_cpuid
=
omp_cpuids
.
split
(
"|"
)[
rank
]
ModelRunnerClass
:
Type
[
CPUModelRunner
]
=
CPUModelRunner
ModelRunnerClass
:
Type
[
CPUModelRunner
]
=
CPUModelRunner
if
self
.
_
is_encoder_decoder
_model
()
:
if
self
.
model_config
.
is_encoder_decoder
:
ModelRunnerClass
=
CPUEncoderDecoderModelRunner
ModelRunnerClass
=
CPUEncoderDecoderModelRunner
self
.
model_runner
:
CPUModelRunner
=
ModelRunnerClass
(
self
.
model_runner
:
CPUModelRunner
=
ModelRunnerClass
(
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
...
@@ -188,9 +188,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -188,9 +188,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
raise
RuntimeError
(
"Profiler is not enabled."
)
raise
RuntimeError
(
"Profiler is not enabled."
)
self
.
profiler
.
stop
()
self
.
profiler
.
stop
()
def
_is_encoder_decoder_model
(
self
):
return
self
.
model_config
.
is_encoder_decoder_model
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
if
self
.
local_omp_cpuid
!=
"all"
:
if
self
.
local_omp_cpuid
!=
"all"
:
ret
=
torch
.
ops
.
_C_utils
.
init_cpu_threads_env
(
self
.
local_omp_cpuid
)
ret
=
torch
.
ops
.
_C_utils
.
init_cpu_threads_env
(
self
.
local_omp_cpuid
)
...
...
vllm/worker/model_runner.py
View file @
db7db4aa
...
@@ -47,7 +47,6 @@ from vllm.prompt_adapter.worker_manager import (
...
@@ -47,7 +47,6 @@ from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager
)
LRUCacheWorkerPromptAdapterManager
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.utils
import
(
DeviceMemoryProfiler
,
GiB_bytes
,
PyObjectCache
,
from
vllm.utils
import
(
DeviceMemoryProfiler
,
GiB_bytes
,
PyObjectCache
,
async_tensor_h2d
,
flatten_2d_lists
,
async_tensor_h2d
,
flatten_2d_lists
,
is_pin_memory_available
,
supports_dynamo
,
is_pin_memory_available
,
supports_dynamo
,
...
@@ -493,7 +492,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -493,7 +492,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
context_len
=
seq_data
.
get_num_computed_tokens
()
context_len
=
seq_data
.
get_num_computed_tokens
()
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
elif
self
.
runner
.
scheduler_config
.
is_multi_step
or
\
elif
self
.
runner
.
scheduler_config
.
is_multi_step
or
\
self
.
runner
.
model_config
.
is_encoder_decoder
_model
:
self
.
runner
.
model_config
.
is_encoder_decoder
:
context_len
=
seq_len
-
1
context_len
=
seq_len
-
1
else
:
else
:
context_len
=
seq_data
.
get_num_computed_tokens
()
context_len
=
seq_data
.
get_num_computed_tokens
()
...
@@ -666,7 +665,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -666,7 +665,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data
.
multi_modal_placeholder_maps
=
placeholder_maps
inter_data
.
multi_modal_placeholder_maps
=
placeholder_maps
# special processing for mrope position deltas.
# special processing for mrope position deltas.
if
self
.
runner
.
model_
i
s_mrope
:
if
self
.
runner
.
model_
config
.
use
s_mrope
:
image_grid_thw
=
mm_kwargs
.
get
(
"image_grid_thw"
,
None
)
image_grid_thw
=
mm_kwargs
.
get
(
"image_grid_thw"
,
None
)
video_grid_thw
=
mm_kwargs
.
get
(
"video_grid_thw"
,
None
)
video_grid_thw
=
mm_kwargs
.
get
(
"video_grid_thw"
,
None
)
assert
image_grid_thw
is
not
None
or
video_grid_thw
is
not
None
,
(
assert
image_grid_thw
is
not
None
or
video_grid_thw
is
not
None
,
(
...
@@ -711,7 +710,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -711,7 +710,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len
=
0
encoder_seq_len
=
0
if
self
.
runner
.
model_config
.
is_encoder_decoder
_model
:
if
self
.
runner
.
model_config
.
is_encoder_decoder
:
encoder_seq_len
=
seq_group_metadata
.
encoder_seq_data
.
get_len
()
encoder_seq_len
=
seq_group_metadata
.
encoder_seq_data
.
get_len
()
inter_data
=
self
.
init_cached_inter_data
(
inter_data
=
self
.
init_cached_inter_data
(
...
@@ -837,7 +836,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -837,7 +836,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if
not
inter_data
.
is_prompt
:
if
not
inter_data
.
is_prompt
:
max_decode_seq_len
=
max
(
max_decode_seq_len
,
max_decode_seq_len
=
max
(
max_decode_seq_len
,
max
(
inter_data
.
seq_lens
))
max
(
inter_data
.
seq_lens
))
if
self
.
runner
.
model_config
.
is_encoder_decoder
_model
:
if
self
.
runner
.
model_config
.
is_encoder_decoder
:
max_encoder_seq_len
=
max
(
max_encoder_seq_len
,
max_encoder_seq_len
=
max
(
max_encoder_seq_len
,
inter_data
.
encoder_seq_len
)
inter_data
.
encoder_seq_len
)
...
@@ -1375,12 +1374,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1375,12 +1374,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
raise
RuntimeError
(
"PromptAdapter is not enabled."
)
raise
RuntimeError
(
"PromptAdapter is not enabled."
)
return
self
.
prompt_adapter_manager
.
list_adapters
()
return
self
.
prompt_adapter_manager
.
list_adapters
()
@
property
def
model_is_mrope
(
self
)
->
bool
:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
return
uses_mrope
(
self
.
model_config
.
hf_config
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
capture_model
(
self
,
kv_caches
:
List
[
List
[
torch
.
Tensor
]])
->
None
:
def
capture_model
(
self
,
kv_caches
:
List
[
List
[
torch
.
Tensor
]])
->
None
:
"""Cuda graph capture a model.
"""Cuda graph capture a model.
...
@@ -1411,7 +1404,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1411,7 +1404,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_batch_size
=
self
.
max_batchsize_to_capture
max_batch_size
=
self
.
max_batchsize_to_capture
input_tokens
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_tokens
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
if
self
.
model_
i
s_mrope
:
if
self
.
model_
config
.
use
s_mrope
:
input_positions
=
torch
.
tile
(
input_positions
,
(
3
,
1
))
input_positions
=
torch
.
tile
(
input_positions
,
(
3
,
1
))
# Prepare dummy previous_hidden_states only if needed by the model.
# Prepare dummy previous_hidden_states only if needed by the model.
# This is used by draft models such as EAGLE.
# This is used by draft models such as EAGLE.
...
@@ -1447,7 +1440,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1447,7 +1440,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
attn_state
.
graph_capture_get_metadata_for_batch
(
self
.
attn_state
.
graph_capture_get_metadata_for_batch
(
batch_size
,
batch_size
,
is_encoder_decoder_model
=
self
.
model_config
.
is_encoder_decoder_model
=
self
.
model_config
.
is_encoder_decoder
_model
))
is_encoder_decoder
))
if
self
.
lora_config
:
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
lora_mapping
=
LoRAMapping
(
...
@@ -1466,7 +1459,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1466,7 +1459,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
graph_runner
=
CUDAGraphRunner
(
graph_runner
=
CUDAGraphRunner
(
self
.
model
,
self
.
attn_backend
.
get_name
(),
self
.
model
,
self
.
attn_backend
.
get_name
(),
self
.
attn_state
.
graph_clone
(
batch_size
),
self
.
attn_state
.
graph_clone
(
batch_size
),
self
.
model_config
.
is_encoder_decoder
_model
)
self
.
model_config
.
is_encoder_decoder
)
capture_inputs
=
{
capture_inputs
=
{
"input_ids"
:
"input_ids"
:
...
@@ -1497,7 +1490,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1497,7 +1490,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
model
.
get_seqlen_agnostic_capture_inputs
(
self
.
model
.
get_seqlen_agnostic_capture_inputs
(
batch_size
)
batch_size
)
})
})
if
self
.
model_config
.
is_encoder_decoder
_model
:
if
self
.
model_config
.
is_encoder_decoder
:
# add the additional inputs to capture for
# add the additional inputs to capture for
# encoder-decoder models.
# encoder-decoder models.
self
.
_update_inputs_to_capture_for_enc_dec_model
(
self
.
_update_inputs_to_capture_for_enc_dec_model
(
...
...
vllm/worker/worker.py
View file @
db7db4aa
...
@@ -77,7 +77,7 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -77,7 +77,7 @@ class Worker(LocalOrDistributedWorkerBase):
ModelRunnerClass
=
model_runner_cls
ModelRunnerClass
=
model_runner_cls
elif
model_config
.
task
==
"embedding"
:
elif
model_config
.
task
==
"embedding"
:
ModelRunnerClass
=
EmbeddingModelRunner
ModelRunnerClass
=
EmbeddingModelRunner
elif
self
.
_
is_encoder_decoder
_model
()
:
elif
self
.
model_config
.
is_encoder_decoder
:
ModelRunnerClass
=
EncoderDecoderModelRunner
ModelRunnerClass
=
EncoderDecoderModelRunner
self
.
model_runner
:
GPUModelRunnerBase
=
ModelRunnerClass
(
self
.
model_runner
:
GPUModelRunnerBase
=
ModelRunnerClass
(
vllm_config
=
self
.
vllm_config
,
vllm_config
=
self
.
vllm_config
,
...
@@ -119,9 +119,6 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -119,9 +119,6 @@ class Worker(LocalOrDistributedWorkerBase):
raise
RuntimeError
(
"Profiler is not enabled."
)
raise
RuntimeError
(
"Profiler is not enabled."
)
self
.
profiler
.
stop
()
self
.
profiler
.
stop
()
def
_is_encoder_decoder_model
(
self
):
return
self
.
model_config
.
is_encoder_decoder_model
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
# torch.distributed.all_reduce does not free the input tensor until
# torch.distributed.all_reduce does not free the input tensor until
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment