Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0640f227
Commit
0640f227
authored
Sep 09, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.0' into v0.6.0-dev
parents
82f1ffdf
32e7db25
Changes
335
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
542 additions
and
95 deletions
+542
-95
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+83
-26
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+2
-2
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+322
-33
vllm/worker/multi_step_worker.py
vllm/worker/multi_step_worker.py
+8
-1
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+11
-5
vllm/worker/neuron_worker.py
vllm/worker/neuron_worker.py
+27
-0
vllm/worker/openvino_model_runner.py
vllm/worker/openvino_model_runner.py
+2
-1
vllm/worker/openvino_worker.py
vllm/worker/openvino_worker.py
+2
-1
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+48
-14
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+2
-1
vllm/worker/utils.py
vllm/worker/utils.py
+1
-1
vllm/worker/worker.py
vllm/worker/worker.py
+2
-2
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+8
-3
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+18
-5
vllm/worker/xpu_worker.py
vllm/worker/xpu_worker.py
+6
-0
No files found.
vllm/worker/model_runner.py
View file @
0640f227
...
@@ -6,8 +6,8 @@ import time
...
@@ -6,8 +6,8 @@ import time
import
warnings
import
warnings
import
weakref
import
weakref
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
TypeVar
,
Union
)
Tuple
,
Type
,
TypeVar
,
Union
)
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
...
@@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
...
@@ -29,6 +30,7 @@ from vllm.lora.layers import LoRAMapping
...
@@ -29,6 +30,7 @@ from vllm.lora.layers import LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
,
SamplingMetadataCache
from
vllm.model_executor
import
SamplingMetadata
,
SamplingMetadataCache
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
...
@@ -41,10 +43,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
...
@@ -41,10 +43,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.prompt_adapter.worker_manager
import
(
from
vllm.prompt_adapter.worker_manager
import
(
LRUCacheWorkerPromptAdapterManager
)
LRUCacheWorkerPromptAdapterManager
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
PyObjectCache
,
async_tensor_h2d
,
from
vllm.utils
import
(
CudaMemoryProfiler
,
PyObjectCache
,
async_tensor_h2d
,
flatten_2d_lists
,
is_hip
,
is_pin_memory_available
)
flatten_2d_lists
,
is_hip
,
is_pin_memory_available
,
supports_dynamo
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_attn_metadata_broadcastable_dict
,
...
@@ -59,10 +61,14 @@ logger = init_logger(__name__)
...
@@ -59,10 +61,14 @@ logger = init_logger(__name__)
LORA_WARMUP_RANK
=
8
LORA_WARMUP_RANK
=
8
_BATCH_SIZE_ALIGNMENT
=
8
_BATCH_SIZE_ALIGNMENT
=
8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
33
)
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
1025
)
]
]
_NUM_WARMUP_ITERS
=
2
_NUM_WARMUP_ITERS
=
2
...
@@ -90,6 +96,9 @@ class ModelInputForGPU(ModelRunnerInputBase):
...
@@ -90,6 +96,9 @@ class ModelInputForGPU(ModelRunnerInputBase):
request_ids_to_seq_ids
:
Optional
[
Dict
[
str
,
List
[
int
]]]
=
None
request_ids_to_seq_ids
:
Optional
[
Dict
[
str
,
List
[
int
]]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
virtual_engine
:
int
=
0
virtual_engine
:
int
=
0
async_callback
:
Optional
[
Callable
]
=
None
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
tensor_dict
=
{
...
@@ -499,23 +508,48 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -499,23 +508,48 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
and
self
.
sliding_window
is
None
and
self
.
sliding_window
is
None
and
inter_data
.
is_prompt
)
and
inter_data
.
is_prompt
)
inter_data
.
prefix_cache_hit
=
prefix_cache_hit
inter_data
.
prefix_cache_hit
=
prefix_cache_hit
if
self
.
chunked_prefill_enabled
and
prefix_cache_hit
:
raise
RuntimeError
(
if
not
prefix_cache_hit
:
"chunked prefill cannot be used with prefix caching now."
)
return
# If prefix cache is hit, advance context length to bypass
assert
computed_block_nums
is
not
None
# hit blocks. Accordingly, input tokens, position and query length
# The cache hit prompt tokens in this sequence. Note that
# have to be updated.
# this may be larger than the sequence length if chunked
if
prefix_cache_hit
:
# prefill is enabled.
assert
computed_block_nums
is
not
None
prefix_cache_len
=
len
(
computed_block_nums
)
*
self
.
block_size
context_len
=
len
(
computed_block_nums
)
*
self
.
block_size
# The number of so far computed prompt tokens in this sequence.
context_len
=
inter_data
.
context_lens
[
seq_idx
]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len
=
inter_data
.
seq_lens
[
seq_idx
]
if
prefix_cache_len
<=
context_len
:
# We already passed the cache hit region,
# so do normal computation.
pass
elif
context_len
<
prefix_cache_len
<
seq_len
:
# Partial hit. Compute the missing part.
uncomputed_start
=
prefix_cache_len
-
context_len
inter_data
.
input_tokens
[
seq_idx
]
=
inter_data
.
input_tokens
[
inter_data
.
input_tokens
[
seq_idx
]
=
inter_data
.
input_tokens
[
seq_idx
][
context_len
:]
seq_idx
][
uncomputed_start
:]
inter_data
.
input_positions
[
seq_idx
]
=
inter_data
.
input_positions
[
inter_data
.
input_positions
[
seq_idx
]
=
inter_data
.
input_positions
[
seq_idx
][
context_len
:]
seq_idx
][
uncomputed_start
:]
context_len
=
prefix_cache_len
inter_data
.
context_lens
[
seq_idx
]
=
context_len
inter_data
.
context_lens
[
seq_idx
]
=
context_len
inter_data
.
query_lens
[
inter_data
.
query_lens
[
seq_idx
]
=
inter_data
.
seq_lens
[
seq_idx
]
-
context_len
seq_idx
]
=
inter_data
.
seq_lens
[
seq_idx
]
-
context_len
elif
seq_len
<=
prefix_cache_len
:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
inter_data
.
input_tokens
[
seq_idx
]
=
inter_data
.
input_tokens
[
seq_idx
][
-
1
:]
inter_data
.
input_positions
[
seq_idx
]
=
inter_data
.
input_positions
[
seq_idx
][
-
1
:]
inter_data
.
query_lens
[
seq_idx
]
=
1
inter_data
.
context_lens
[
seq_idx
]
=
inter_data
.
seq_lens
[
seq_idx
]
-
1
def
_compute_for_sliding_window
(
self
,
inter_data
:
InterDataForSeqGroup
,
def
_compute_for_sliding_window
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_idx
:
int
,
...
@@ -632,7 +666,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -632,7 +666,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def
_use_captured_graph
(
self
,
batch_size
:
int
,
def
_use_captured_graph
(
self
,
batch_size
:
int
,
max_decode_seq_len
:
int
)
->
bool
:
max_decode_seq_len
:
int
)
->
bool
:
return
(
self
.
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
return
(
self
.
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
batch_size
<=
self
.
runner
.
max_batchsize_to_capture
and
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
)
and
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
)
def
build
(
self
)
->
ModelInputForGPU
:
def
build
(
self
)
->
ModelInputForGPU
:
...
@@ -818,6 +852,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -818,6 +852,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
block_size
=
cache_config
.
block_size
self
.
max_seq_len_to_capture
=
self
.
model_config
.
max_seq_len_to_capture
self
.
max_seq_len_to_capture
=
self
.
model_config
.
max_seq_len_to_capture
self
.
max_batchsize_to_capture
=
_get_max_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
self
.
graph_runners
:
List
[
Dict
[
int
,
CUDAGraphRunner
]]
=
[
self
.
graph_runners
:
List
[
Dict
[
int
,
CUDAGraphRunner
]]
=
[
{}
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
{}
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
...
@@ -835,7 +871,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -835,7 +871,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# The shape of the cached block table will be
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
# (max batch size to capture, max context len to capture / block size).
self
.
graph_block_tables
=
np
.
zeros
(
self
.
graph_block_tables
=
np
.
zeros
(
(
max
(
_BATCH_SIZES_TO_CAPTURE
)
,
self
.
get_max_block_per_batch
()),
(
self
.
max_batchsize_to_capture
,
self
.
get_max_block_per_batch
()),
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
num_attn_heads
=
self
.
model_config
.
get_num_attention_heads
(
num_attn_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
)
self
.
parallel_config
)
...
@@ -945,7 +981,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -945,7 +981,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!"
)
"This may lead to less accurate results!"
)
if
envs
.
VLLM_TEST_DYNAMO_GRAPH_CAPTURE
:
if
envs
.
VLLM_TEST_DYNAMO_GRAPH_CAPTURE
and
supports_dynamo
()
:
self
.
model
=
torch
.
compile
(
self
.
model
,
self
.
model
=
torch
.
compile
(
self
.
model
,
fullgraph
=
True
,
fullgraph
=
True
,
backend
=
"eager"
)
backend
=
"eager"
)
...
@@ -1220,7 +1256,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1220,7 +1256,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
start_time
=
time
.
perf_counter
()
start_time
=
time
.
perf_counter
()
# Prepare dummy inputs. These will be reused for all batch sizes.
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size
=
max
(
_BATCH_SIZES_TO_CAPTURE
)
max_batch_size
=
self
.
max_batchsize_to_capture
input_tokens
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_tokens
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
...
@@ -1248,8 +1284,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1248,8 +1284,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
None
None
]
*
self
.
parallel_config
.
pipeline_parallel_size
]
*
self
.
parallel_config
.
pipeline_parallel_size
graph_batch_size
=
_get_graph_batch_size
(
graph_batch_size
=
self
.
max_batchsize_to_capture
self
.
scheduler_config
.
max_num_seqs
)
batch_size_capture_list
=
[
batch_size_capture_list
=
[
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
]
]
...
@@ -1357,7 +1392,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1357,7 +1392,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
)
->
ModelInputForGPUWithSamplingMetadata
:
)
->
ModelInputForGPUWithSamplingMetadata
:
"""Prepare the model input based on a given sequence group, including
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
metadata for the sampling step.
...
@@ -1481,6 +1516,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1481,6 +1516,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if
not
self
.
is_driver_worker
:
if
not
self
.
is_driver_worker
:
return
[]
return
[]
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Sample the next token.
# Sample the next token.
output
:
SamplerOutput
=
self
.
model
.
sample
(
output
:
SamplerOutput
=
self
.
model
.
sample
(
logits
=
logits
,
logits
=
logits
,
...
@@ -1672,3 +1710,22 @@ def _get_graph_batch_size(batch_size: int) -> int:
...
@@ -1672,3 +1710,22 @@ def _get_graph_batch_size(batch_size: int) -> int:
else
:
else
:
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
def
_get_max_graph_batch_size
(
max_num_seqs
:
int
)
->
int
:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
pad the max_num_seqs if necessary by calling _get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.
if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
if not, it means the padded size is larger than the largest size in
_BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
"""
padded_size
=
_get_graph_batch_size
(
max_num_seqs
)
if
padded_size
in
_BATCH_SIZES_TO_CAPTURE
:
return
padded_size
assert
padded_size
>
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
return
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
vllm/worker/model_runner_base.py
View file @
0640f227
...
@@ -5,9 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
...
@@ -5,9 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import
torch
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
SequenceGroupMetadata
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
...
...
vllm/worker/multi_step_model_runner.py
View file @
0640f227
import
dataclasses
import
functools
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
)
try
:
try
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
...
@@ -13,9 +16,13 @@ import torch
...
@@ -13,9 +16,13 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
(
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SamplingMetadata
,
get_logprobs
,
get_pythonized_sample_results
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
Logprob
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
from
vllm.utils
import
PyObjectCache
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
...
@@ -31,6 +38,29 @@ if TYPE_CHECKING:
...
@@ -31,6 +38,29 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
seq_output_builder
():
return
SequenceOutput
(
0
,
0
,
{
0
:
Logprob
(
logprob
=
float
(
'inf'
),
rank
=
None
,
decoded_token
=
None
)})
def
completion_seq_group_output_builder
():
return
CompletionSequenceGroupOutput
([],
None
)
# Used by pythonization to reduce python object allocations
class
PythonizationCache
:
def
__init__
(
self
):
self
.
cached_seq_output
=
PyObjectCache
(
seq_output_builder
)
self
.
cached_completion_seq_group_output
=
PyObjectCache
(
completion_seq_group_output_builder
)
def
reset
(
self
):
self
.
cached_seq_output
.
reset
()
self
.
cached_completion_seq_group_output
.
reset
()
@
dataclass
@
dataclass
class
ModelOutput
:
class
ModelOutput
:
"""The output of a single model forward pass.
"""The output of a single model forward pass.
...
@@ -51,6 +81,9 @@ class ModelOutput:
...
@@ -51,6 +81,9 @@ class ModelOutput:
sampler_output_ready_event
:
torch
.
cuda
.
Event
sampler_output_ready_event
:
torch
.
cuda
.
Event
sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
pythonized
:
bool
=
False
pythonized
:
bool
=
False
# On-device tensor containing the logprobs of each token.
logprobs
:
Optional
[
"torch.Tensor"
]
=
None
pythonization_cache
:
Optional
[
PythonizationCache
]
=
None
def
pythonize
(
self
,
input_metadata
:
"StatefulModelInput"
,
def
pythonize
(
self
,
input_metadata
:
"StatefulModelInput"
,
copy_stream
:
torch
.
cuda
.
Stream
,
copy_stream
:
torch
.
cuda
.
Stream
,
...
@@ -76,7 +109,9 @@ class ModelOutput:
...
@@ -76,7 +109,9 @@ class ModelOutput:
blocking
:
bool
)
->
bool
:
blocking
:
bool
)
->
bool
:
"""
"""
If blocking is set, will block until the forward pass for the output is
If blocking is set, will block until the forward pass for the output is
ready and pythonize the output.
ready and pythonize the output. Upon completing Pythonization, erases
self.logprobs (note that a non-blocking call that is performed when
the sampler output is not yet ready, will not erase self.logprobs.)
"""
"""
assert
self
.
sampled_token_ids
is
not
None
assert
self
.
sampled_token_ids
is
not
None
if
not
blocking
and
not
self
.
sampler_output_ready_event
.
query
():
if
not
blocking
and
not
self
.
sampler_output_ready_event
.
query
():
...
@@ -87,7 +122,16 @@ class ModelOutput:
...
@@ -87,7 +122,16 @@ class ModelOutput:
with
torch
.
cuda
.
stream
(
copy_stream
):
with
torch
.
cuda
.
stream
(
copy_stream
):
_pythonize_sampler_output
(
input_metadata
,
self
.
sampler_output
,
_pythonize_sampler_output
(
input_metadata
,
self
.
sampler_output
,
pinned_sampled_token_buffer
,
pinned_sampled_token_buffer
,
self
.
sampled_token_ids
)
self
.
sampled_token_ids
,
self
.
logprobs
,
self
.
pythonization_cache
)
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
# own CUDA stream, nonetheless _pythonize_sampler_output()
# cannot return until Pythonization is complete; therefore
# we know that by the time the CPU reaches this point,
# `self.logprobs` is no longer needed.
self
.
logprobs
=
None
return
True
return
True
...
@@ -191,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -191,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
pinned_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
self
.
pinned_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
self
.
pythonization_cache
=
PythonizationCache
()
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
StatefulModelInput
:
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
StatefulModelInput
:
model_input
=
(
StatefulModelInput
.
from_broadcasted_tensor_dict
(
model_input
=
(
StatefulModelInput
.
from_broadcasted_tensor_dict
(
...
@@ -215,6 +261,79 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -215,6 +261,79 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
)
)
return
model_input
return
model_input
def
_async_process_outputs
(
self
,
model_input
:
StatefulModelInput
,
output_proc_callback
:
Callable
):
# Proceed with pythonization and output_proc in order.
# Stop on the first one that fails to pythonize
output_proc_callback
()
cont
=
True
for
model_output
in
model_input
.
cached_outputs
:
if
not
model_output
.
pythonized
:
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
if
model_output
.
pythonized
:
ctx
=
output_proc_callback
.
keywords
[
"ctx"
]
is_async
=
False
is_last_step
=
False
ctx
.
output_queue
.
append
(
([
model_output
.
sampler_output
],
ctx
.
seq_group_metadata_list
,
ctx
.
scheduler_outputs
,
is_async
,
is_last_step
))
output_proc_callback
()
else
:
cont
=
False
if
not
cont
:
break
def
_final_process_outputs
(
self
,
model_input
:
StatefulModelInput
,
output_proc_callback
:
Optional
[
Callable
]):
assert
model_input
.
frozen_model_input
is
not
None
has_async_callback
=
output_proc_callback
is
not
None
outputs
=
[]
for
output_id
in
range
(
len
(
model_input
.
cached_outputs
)):
output
=
model_input
.
cached_outputs
[
output_id
]
is_last_step
=
output_id
==
len
(
model_input
.
cached_outputs
)
-
1
# For non-async case:
# -- We simply add the outputs
# For async case:
# -- Invoke callback, pythonize, add to callback queue and repeat
# -- For last output, just add to callback queue
if
has_async_callback
:
assert
output_proc_callback
is
not
None
# Invoke callback before pythonize (to overlap with GPU)
output_proc_callback
()
# Pythonize
if
not
output
.
pythonized
:
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
# For non last step, add to callback queue to chain
# callbacks=>pythonize pairs (for GPU overlap)
if
not
is_last_step
:
ctx
=
output_proc_callback
.
keywords
[
# type: ignore
"ctx"
]
# type: ignore
is_async
=
False
is_last_step
=
False
ctx
.
output_queue
.
append
(
([
output
.
sampler_output
],
ctx
.
seq_group_metadata_list
,
ctx
.
scheduler_outputs
,
is_async
,
is_last_step
))
else
:
outputs
.
append
(
output
.
sampler_output
)
else
:
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
outputs
.
append
(
output
.
sampler_output
)
return
outputs
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
...
@@ -271,6 +390,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -271,6 +390,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input
=
self
.
_advance_step
(
model_input
=
self
.
_advance_step
(
model_input
,
model_input
.
cached_outputs
[
-
1
].
sampler_output
)
model_input
,
model_input
.
cached_outputs
[
-
1
].
sampler_output
)
output_proc_callback
=
None
if
frozen_model_input
.
async_callback
is
not
None
:
output_proc_callback
=
frozen_model_input
.
async_callback
assert
output_proc_callback
is
not
None
async_callback
=
functools
.
partial
(
self
.
_async_process_outputs
,
model_input
=
model_input
,
output_proc_callback
=
output_proc_callback
)
frozen_model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
.
frozen_model_input
,
async_callback
=
async_callback
)
assert
frozen_model_input
is
not
None
# Execute the model
# Execute the model
output
=
self
.
_base_model_runner
.
execute_model
(
frozen_model_input
,
output
=
self
.
_base_model_runner
.
execute_model
(
frozen_model_input
,
kv_caches
,
kv_caches
,
...
@@ -294,16 +427,23 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -294,16 +427,23 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
0
].
sampled_token_ids
.
cpu
()
0
].
sampled_token_ids
.
cpu
()
model_input
.
cached_outputs
.
append
(
model_input
.
cached_outputs
.
append
(
ModelOutput
(
output
[
0
],
output_ready_event
,
ModelOutput
(
output
[
0
],
output_ready_event
,
output
[
0
].
sampled_token_ids
,
False
))
output
[
0
].
sampled_token_ids
,
False
,
# make sure we dont try to serialize any GPU tensors
output
[
0
].
logprobs
,
self
.
pythonization_cache
))
# These GPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
# transferred to CPU
output
[
0
].
sampled_token_ids
=
None
output
[
0
].
sampled_token_ids
=
None
output
[
0
].
sampled_token_probs
=
None
output
[
0
].
sampled_token_probs
=
None
output
[
0
].
logprobs
=
None
output
[
0
].
logprobs
=
None
# Pythonize the output if CPU is ahead and the previous step is
# Pythonize the output if CPU is ahead and the previous step is
# ready.
# ready.
for
model_output
in
model_input
.
cached_outputs
:
if
frozen_model_input
.
async_callback
is
None
:
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
for
model_output
in
model_input
.
cached_outputs
:
self
.
pinned_sampled_token_ids
)
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
model_input
.
current_step
+=
1
model_input
.
current_step
+=
1
...
@@ -316,11 +456,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -316,11 +456,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Pythonize the output and block if needed since it is the last step
# Pythonize the output and block if needed since it is the last step
if
model_input
.
is_last_step
:
if
model_input
.
is_last_step
:
outputs
=
[]
outputs
=
self
.
_final_process_outputs
(
model_input
,
for
output
in
model_input
.
cached_outputs
:
output_proc_callback
)
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pythonization_cache
.
reset
()
self
.
pinned_sampled_token_ids
)
outputs
.
append
(
output
.
sampler_output
)
return
outputs
return
outputs
# should be [SamplerOutput]
# should be [SamplerOutput]
...
@@ -409,12 +547,76 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -409,12 +547,76 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
return
self
.
_base_model_runner
.
vocab_size
return
self
.
_base_model_runner
.
vocab_size
def
_pythonize_sampler_output
(
model_input
:
StatefulModelInput
,
DeferredLogprobsReturnType
=
Tuple
[
Optional
[
List
[
Optional
[
PromptLogprobs
]]],
output
:
SamplerOutput
,
Optional
[
List
[
SampleLogprobs
]]]
pinned_sampled_token_buffer
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
)
->
None
:
def
deferred_pythonize_logprobs
(
output
:
SamplerOutput
,
sampling_metadata
:
SamplingMetadata
,
logprobs_tensor
:
Optional
[
torch
.
Tensor
],
)
->
DeferredLogprobsReturnType
:
"""Perform deferred logprob Pythonization.
1. Pythonize GPU-side sampler result tensors into CPU-side sampler result.
2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists,
utilizing the Pythonized sampler result computed in step 1.
These deferred computations are not required for single-step scheduling
or the `profile_run()` phase of multi-step scheduling.
Args:
output: sampler output (under deferred Pythonization)
sampling_metadata
Returns:
prompt_logprobs (CPU), sample_logprobs (CPU)
"""
# - Deferred pythonization of sample result
sampler_result
=
get_pythonized_sample_results
(
output
.
deferred_sample_results_args
)
# - Erase the GPU-side deferred sample_result
# computation args to ensure it is never
# pythonized or transferred to CPU
output
.
deferred_sample_results_args
=
None
# - Deferred pythonization of logprobs
(
prompt_logprobs
,
sample_logprobs
,
)
=
get_logprobs
(
logprobs_tensor
,
sampling_metadata
,
sampler_result
)
assert
len
(
prompt_logprobs
)
==
len
(
sampling_metadata
.
seq_groups
)
assert
len
(
sample_logprobs
)
==
len
(
sampling_metadata
.
seq_groups
)
return
prompt_logprobs
,
sample_logprobs
def
_pythonize_sampler_output
(
model_input
:
StatefulModelInput
,
output
:
SamplerOutput
,
pinned_sampled_token_buffer
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
logprobs_tensor
:
Optional
[
torch
.
Tensor
],
cache
:
Optional
[
PythonizationCache
],
)
->
None
:
""" This function is only called when the output tensors are ready.
""" This function is only called when the output tensors are ready.
See ModelOutput
See :class:`ModelOutput`.
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
adding a Pythonized output data structure
(:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
Args:
model_input
output: sampler output
pinned_sampled_token_token_buffer: CPU-side pinned memory
(receives copy of
GPU-side token buffer.)
sampled_token_ids: GPU-side token buffer
logprobs_tensor: GPU-side tensor containing
logprobs computed during sampling
"""
"""
assert
model_input
.
frozen_model_input
is
not
None
assert
model_input
.
frozen_model_input
is
not
None
...
@@ -434,20 +636,107 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
...
@@ -434,20 +636,107 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
sampling_metadata
=
frozen_model_input
.
sampling_metadata
sampling_metadata
=
frozen_model_input
.
sampling_metadata
for
(
seq_group
,
sample_result
)
in
zip
(
sampling_metadata
.
seq_groups
,
skip_sampler_cpu_output
=
(
samples_list
):
frozen_model_input
.
sampling_metadata
.
skip_sampler_cpu_output
)
seq_ids
=
seq_group
.
seq_ids
next_token_ids
=
sample_result
# We are guaranteed output tensors are ready, so it is safe to
parent_ids
=
[
0
]
# pythonize the sampler output & obtain CPU-side logprobs.
seq_outputs
:
List
[
SequenceOutput
]
=
[]
#
# However this computation may be skipped entirely
# if no pythonization was deferred.
seq_groups
=
sampling_metadata
.
seq_groups
logprobs_are_requested
=
any
([
sg
.
sampling_params
.
logprobs
is
not
None
or
sg
.
sampling_params
.
prompt_logprobs
is
not
None
for
sg
in
seq_groups
])
do_pythonize_logprobs
=
(
skip_sampler_cpu_output
and
logprobs_are_requested
)
(
prompt_logprobs
,
sample_logprobs
,
)
=
(
deferred_pythonize_logprobs
(
output
,
sampling_metadata
,
logprobs_tensor
)
if
do_pythonize_logprobs
else
(
None
,
None
))
for
sgdx
,
(
seq_group
,
sample_result
)
in
enumerate
(
zip
(
seq_groups
,
samples_list
)):
if
seq_group
.
sampling_params
.
logits_processors
:
if
seq_group
.
sampling_params
.
logits_processors
:
assert
len
(
seq_group
.
sampling_params
.
logits_processors
)
==
0
,
(
assert
len
(
seq_group
.
sampling_params
.
logits_processors
)
==
0
,
(
"Logits Processors are not supported in multi-step decoding"
)
"Logits Processors are not supported in multi-step decoding"
)
for
parent_id
,
next_token_id
in
zip
(
parent_ids
,
next_token_ids
):
# TODO(will): support logprobs
if
do_pythonize_logprobs
:
# Hard coded logprob
assert
prompt_logprobs
is
not
None
seq_outputs
.
append
(
assert
sample_logprobs
is
not
None
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
{
next_token_id
:
Logprob
(
logprob
=-
1
)}))
(
output
.
outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
group_prompt_logprobs
,
group_sample_logprobs
,
)
=
(
# Utilize deferred pythonization results
prompt_logprobs
[
sgdx
],
sample_logprobs
[
sgdx
],
)
elif
logprobs_are_requested
:
(
group_prompt_logprobs
,
group_sample_logprobs
,
)
=
(
# profile_run: use already-computed logprobs
output
.
outputs
[
sgdx
].
prompt_logprobs
,
[
sample
.
logprobs
for
sample
in
output
.
outputs
[
sgdx
].
samples
])
seq_ids
=
seq_group
.
seq_ids
next_token_ids
=
sample_result
parent_ids
=
[
0
]
if
cache
is
not
None
:
completion_seq_group_output
:
CompletionSequenceGroupOutput
=
\
cache
.
cached_completion_seq_group_output
.
get_object
()
completion_seq_group_output
.
samples
.
clear
()
seq_outputs
:
List
[
SequenceOutput
]
=
completion_seq_group_output
.
samples
else
:
seq_outputs
=
[]
for
tdx
,
(
parent_id
,
next_token_id
)
in
enumerate
(
zip
(
parent_ids
,
next_token_ids
)):
if
cache
is
not
None
:
seq_output
:
SequenceOutput
=
cache
.
cached_seq_output
.
get_object
(
)
seq_output
.
parent_seq_id
=
seq_ids
[
parent_id
]
seq_output
.
output_token
=
next_token_id
if
logprobs_are_requested
:
seq_output
.
logprobs
=
group_sample_logprobs
[
tdx
]
else
:
logprobs
=
next
(
iter
(
seq_output
.
logprobs
.
values
()))
seq_output
.
logprobs
.
clear
()
logprobs
.
logprob
=
float
(
'inf'
)
logprobs
.
rank
=
None
logprobs
.
decoded_token
=
None
seq_output
.
logprobs
[
next_token_id
]
=
logprobs
seq_outputs
.
append
(
seq_output
)
else
:
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
(
group_sample_logprobs
[
tdx
]
if
logprobs_are_requested
else
{
next_token_id
:
Logprob
(
logprob
=
float
(
'inf'
),
rank
=
None
,
decoded_token
=
None
)
})))
if
cache
is
not
None
:
completion_seq_group_output
.
prompt_logprobs
=
\
group_prompt_logprobs
if
logprobs_are_requested
else
None
output
.
outputs
.
append
(
completion_seq_group_output
)
else
:
output
.
outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
(
group_prompt_logprobs
if
logprobs_are_requested
else
None
)))
assert
len
(
output
.
outputs
)
>
0
assert
len
(
output
.
outputs
)
>
0
vllm/worker/multi_step_worker.py
View file @
0640f227
import
dataclasses
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.model_runner_base
import
BroadcastableModelInput
from
vllm.worker.model_runner_base
import
BroadcastableModelInput
from
vllm.worker.multi_step_model_runner
import
(
MultiStepModelRunner
,
from
vllm.worker.multi_step_model_runner
import
(
MultiStepModelRunner
,
StatefulModelInput
)
StatefulModelInput
)
...
@@ -61,6 +63,11 @@ class MultiStepWorker(Worker):
...
@@ -61,6 +63,11 @@ class MultiStepWorker(Worker):
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
execute_model_req
.
finished_requests_ids
))
if
execute_model_req
.
async_callback
:
model_input
.
frozen_model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
.
frozen_model_input
,
async_callback
=
execute_model_req
.
async_callback
)
else
:
else
:
# on subsequent steps we reuse the worker input and model input
# on subsequent steps we reuse the worker input and model input
multi_step_state
=
self
.
multi_step_states
[
virtual_engine
]
multi_step_state
=
self
.
multi_step_states
[
virtual_engine
]
...
...
vllm/worker/neuron_model_runner.py
View file @
0640f227
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -8,11 +9,11 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
...
@@ -8,11 +9,11 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig
)
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.neuron
import
get_neuron_model
from
vllm.model_executor.model_loader.neuron
import
get_neuron_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
SequenceGroupMetadata
)
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
ModelRunnerBase
,
ModelRunnerInputBase
from
vllm.worker.model_runner_base
import
ModelRunnerBase
,
ModelRunnerInputBase
...
@@ -76,9 +77,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -76,9 +77,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self
.
model
:
nn
.
Module
# initialize after load_model.
self
.
model
:
nn
.
Module
# initialize after load_model.
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model
=
get_neuron_model
(
self
.
model_config
,
if
find_spec
(
"transformers_neuronx"
)
is
not
None
:
parallel_config
=
self
.
parallel_config
,
self
.
model
=
get_neuron_model
(
scheduler_config
=
self
.
scheduler_config
)
self
.
model_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
)
else
:
raise
NotImplementedError
(
"Supports only Transformer-NeuronX based models."
)
def
_prepare_prompt
(
def
_prepare_prompt
(
self
,
self
,
...
...
vllm/worker/neuron_worker.py
View file @
0640f227
...
@@ -6,6 +6,8 @@ import torch.distributed
...
@@ -6,6 +6,8 @@ import torch.distributed
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.neuron_model_runner
import
NeuronModelRunner
from
vllm.worker.neuron_model_runner
import
NeuronModelRunner
...
@@ -24,12 +26,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -24,12 +26,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
cache_config
:
CacheConfig
,
local_rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
)
->
None
:
)
->
None
:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
self
.
local_rank
=
local_rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
if
self
.
model_config
.
trust_remote_code
:
if
self
.
model_config
.
trust_remote_code
:
# note: lazy import to avoid importing torch before initializing
# note: lazy import to avoid importing torch before initializing
from
vllm.utils
import
init_cached_hf_modules
from
vllm.utils
import
init_cached_hf_modules
...
@@ -40,6 +48,8 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -40,6 +48,8 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self
.
is_driver_worker
=
True
self
.
is_driver_worker
=
True
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
self
.
init_distributed_environment
()
# Set random seed.
# Set random seed.
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
...
@@ -98,3 +108,20 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -98,3 +108,20 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
This is required for speculative decoding; it is not yet implemented.
This is required for speculative decoding; it is not yet implemented.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
init_distributed_environment
(
self
):
"""Neuron uses transformers-neuronx for tensor parallelism.
vLLM still needs the environment inited when TP/PP > 1
"""
init_distributed_environment
(
world_size
=
1
,
rank
=
self
.
rank
,
local_rank
=
self
.
local_rank
,
distributed_init_method
=
self
.
distributed_init_method
,
backend
=
"gloo"
,
)
ensure_model_parallel_initialized
(
1
,
1
,
)
vllm/worker/openvino_model_runner.py
View file @
0640f227
...
@@ -11,10 +11,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -11,10 +11,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig
)
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.openvino
import
get_model
from
vllm.model_executor.model_loader.openvino
import
get_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SequenceGroupMetadata
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/worker/openvino_worker.py
View file @
0640f227
...
@@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict,
...
@@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.openvino_model_runner
import
OpenVINOModelRunner
from
vllm.worker.openvino_model_runner
import
OpenVINOModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
...
...
vllm/worker/tpu_model_runner.py
View file @
0640f227
import
time
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
)
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
numpy
as
np
import
numpy
as
np
...
@@ -10,14 +11,15 @@ import torch_xla.core.xla_model as xm
...
@@ -10,14 +11,15 @@ import torch_xla.core.xla_model as xm
import
torch_xla.runtime
as
xr
import
torch_xla.runtime
as
xr
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispacther
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
Logprob
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
_add_attn_metadata_broadcastable_dict
,
_add_attn_metadata_broadcastable_dict
,
...
@@ -50,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
...
@@ -50,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
best_of
:
List
[
int
]
best_of
:
List
[
int
]
seq_groups
:
List
[
List
[
int
]]
seq_groups
:
List
[
List
[
int
]]
virtual_engine
:
int
=
0
virtual_engine
:
int
=
0
async_callback
:
Optional
[
Callable
]
=
None
def
as_broadcastable_tensor_dict
(
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
...
@@ -144,11 +147,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -144,11 +147,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
)
)
model
=
model
.
eval
()
model
=
model
.
eval
()
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
model
=
ModelWrapper
(
model
)
self
.
model
=
ModelWrapper
(
model
)
self
.
model
=
torch
.
compile
(
model
,
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
_dummy_run
(
def
_dummy_run
(
self
,
self
,
...
@@ -235,8 +234,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -235,8 +234,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
torch
.
_dynamo
.
mark_dynamic
(
t
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
t
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
p
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
p
,
0
)
# Dummy run.
# Dummy run.
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
self
.
model
(
token_ids
,
num_samples
,
kv_caches
)
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
num_samples
,
kv_caches
,
is_prompt
=
is_prompt
)
def
warmup_model
(
def
warmup_model
(
self
,
self
,
...
@@ -530,7 +536,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -530,7 +536,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
if
getattr
(
arg
,
"context_lens"
,
None
)
is
not
None
:
if
getattr
(
arg
,
"context_lens"
,
None
)
is
not
None
:
arg
.
context_lens
=
arg
.
context_lens
.
to
(
self
.
device
)
arg
.
context_lens
=
arg
.
context_lens
.
to
(
self
.
device
)
new_args
.
append
(
arg
)
new_args
.
append
(
arg
)
return
self
.
model
(
*
new_args
)
return
self
.
model
(
*
new_args
,
is_prompt
=
is_prompt
)
num_prefills
=
model_input
.
attn_metadata
.
num_prefills
num_prefills
=
model_input
.
attn_metadata
.
num_prefills
is_prompt
=
num_prefills
>
0
is_prompt
=
num_prefills
>
0
...
@@ -558,6 +564,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -558,6 +564,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model_input
.
attn_metadata
,
model_input
.
input_lens
[
i
:
i
+
1
],
model_input
.
attn_metadata
,
model_input
.
input_lens
[
i
:
i
+
1
],
model_input
.
t
[
i
:
i
+
1
],
model_input
.
p
[
i
:
i
+
1
],
model_input
.
t
[
i
:
i
+
1
],
model_input
.
p
[
i
:
i
+
1
],
model_input
.
num_samples
,
kv_caches
)
model_input
.
num_samples
,
kv_caches
)
if
i
==
0
and
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Retrieve the outputs to CPU.
# Retrieve the outputs to CPU.
next_token_ids
+=
output_token_ids
.
cpu
().
tolist
()
next_token_ids
+=
output_token_ids
.
cpu
().
tolist
()
start_idx
=
end_idx
start_idx
=
end_idx
...
@@ -568,6 +576,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -568,6 +576,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model_input
.
attn_metadata
,
model_input
.
input_lens
,
model_input
.
attn_metadata
,
model_input
.
input_lens
,
model_input
.
t
,
model_input
.
p
,
model_input
.
num_samples
,
model_input
.
t
,
model_input
.
p
,
model_input
.
num_samples
,
kv_caches
)
kv_caches
)
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Retrieve the outputs to CPU.
# Retrieve the outputs to CPU.
next_token_ids
=
output_token_ids
.
cpu
().
tolist
()
next_token_ids
=
output_token_ids
.
cpu
().
tolist
()
...
@@ -591,7 +601,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -591,7 +601,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
batch_idx
+=
1
batch_idx
+=
1
else
:
else
:
for
seq_id
in
seq_ids
:
for
seq_id
in
seq_ids
:
next_token_id
=
next_token_ids
[
batch_idx
]
[
0
]
next_token_id
=
next_token_ids
[
batch_idx
]
seq_outputs
.
append
(
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
SequenceOutput
(
seq_id
,
next_token_id
,
{
next_token_id
:
zero_logprob
}))
{
next_token_id
:
zero_logprob
}))
...
@@ -601,11 +611,32 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -601,11 +611,32 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
return
[
SamplerOutput
(
sampler_outputs
)]
return
[
SamplerOutput
(
sampler_outputs
)]
class
ModelWrapper
(
nn
.
Module
):
class
ModelWrapper
(
TorchCompileWrapperWithCustomDispacther
):
def
__init__
(
self
,
model
:
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
):
super
().
__init__
()
self
.
model
=
model
self
.
model
=
model
compiled_callable
=
torch
.
compile
(
self
.
forward
,
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
super
().
__init__
(
compiled_callable
)
def
__call__
(
self
,
*
args
,
is_prompt
:
bool
,
**
kwargs
):
if
len
(
self
.
compiled_codes
)
<
3
or
not
self
.
use_custom_dispatcher
:
# not fully compiled yet, or not using the custom dispatcher,
# let PyTorch handle it
return
self
.
compiled_callable
(
*
args
,
**
kwargs
)
# the 3 compiled codes are:
# 0: for profiling
# 1: for prompt
# 2: for decode
# dispatch to the compiled code directly, skip PyTorch
if
is_prompt
:
with
self
.
dispatch_to_code
(
1
):
return
self
.
forward
(
*
args
,
**
kwargs
)
else
:
with
self
.
dispatch_to_code
(
2
):
return
self
.
forward
(
*
args
,
**
kwargs
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -691,6 +722,9 @@ class ModelWrapper(nn.Module):
...
@@ -691,6 +722,9 @@ class ModelWrapper(nn.Module):
sampled_token_ids
=
torch
.
multinomial
(
probs
,
sampled_token_ids
=
torch
.
multinomial
(
probs
,
num_samples
,
num_samples
,
replacement
=
True
)
replacement
=
True
)
if
num_samples
==
1
:
argmax_token_ids
=
argmax_token_ids
.
squeeze
(
dim
=-
1
)
sampled_token_ids
=
sampled_token_ids
.
squeeze
(
dim
=-
1
)
next_token_ids
=
torch
.
where
(
t
!=
0
,
sampled_token_ids
,
next_token_ids
=
torch
.
where
(
t
!=
0
,
sampled_token_ids
,
argmax_token_ids
)
argmax_token_ids
)
return
next_token_ids
return
next_token_ids
...
...
vllm/worker/tpu_worker.py
View file @
0640f227
...
@@ -102,8 +102,9 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -102,8 +102,9 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
# NOTE(woosuk): Set per-rank cache path since different ranks
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
# can have slightly different XLA graphs.
world_size
=
self
.
parallel_config
.
world_size
world_size
=
self
.
parallel_config
.
world_size
rank
=
xr
.
global_ordinal
()
per_rank_path
=
os
.
path
.
join
(
envs
.
VLLM_XLA_CACHE_PATH
,
per_rank_path
=
os
.
path
.
join
(
envs
.
VLLM_XLA_CACHE_PATH
,
f
"tp
{
world_size
}
_rank
{
self
.
rank
}
"
)
f
"tp
{
world_size
}
_rank
{
rank
}
"
)
xr
.
initialize_cache
(
per_rank_path
,
readonly
=
False
)
xr
.
initialize_cache
(
per_rank_path
,
readonly
=
False
)
def
load_model
(
self
):
def
load_model
(
self
):
...
...
vllm/worker/utils.py
View file @
0640f227
...
@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
...
@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
raise
NotImplementedError
(
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_PP'
])
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_PP'
])
if
enc_dec_mr
.
model_config
.
multimodal_
config
is
not
None
:
if
enc_dec_mr
.
model_config
.
is_
multimodal_
model
:
raise
NotImplementedError
(
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_MM'
])
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_MM'
])
...
...
vllm/worker/worker.py
View file @
0640f227
...
@@ -17,12 +17,12 @@ from vllm.distributed import (ensure_model_parallel_initialized,
...
@@ -17,12 +17,12 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
)
SequenceGroupMetadataDelta
)
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.embedding_model_runner
import
EmbeddingModelRunner
from
vllm.worker.embedding_model_runner
import
EmbeddingModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
...
...
vllm/worker/worker_base.py
View file @
0640f227
...
@@ -11,9 +11,9 @@ from vllm.config import ObservabilityConfig
...
@@ -11,9 +11,9 @@ from vllm.config import ObservabilityConfig
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
,
get_tp_group
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
,
get_tp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
SamplerOutput
)
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
update_environment_variables
)
update_environment_variables
)
from
vllm.worker.model_runner_base
import
(
BroadcastableModelInput
,
from
vllm.worker.model_runner_base
import
(
BroadcastableModelInput
,
...
@@ -263,6 +263,11 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -263,6 +263,11 @@ class LocalOrDistributedWorkerBase(WorkerBase):
broadcast_data
.
update
(
kwargs
)
broadcast_data
.
update
(
kwargs
)
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
if
execute_model_req
.
async_callback
:
model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
,
async_callback
=
execute_model_req
.
async_callback
)
return
model_input
,
worker_input
,
kwargs
return
model_input
,
worker_input
,
kwargs
def
prepare_input
(
def
prepare_input
(
...
@@ -289,7 +294,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -289,7 +294,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
def
execute_model
(
def
execute_model
(
self
,
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
)
->
Optional
[
List
[
SamplerOutput
]]:
)
->
Optional
[
List
[
SamplerOutput
]]:
"""Executes at least one model step on the given sequences, unless no
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
sequences are provided."""
...
...
vllm/worker/xpu_model_runner.py
View file @
0640f227
...
@@ -12,14 +12,15 @@ from vllm.attention import get_attn_backend
...
@@ -12,14 +12,15 @@ from vllm.attention import get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.distributed
import
get_pp_group
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
,
MultiModalRegistry
)
MultiModalInputs
,
MultiModalRegistry
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
SequenceGroupMetadata
)
from
vllm.utils
import
CudaMemoryProfiler
,
make_tensor_with_pad
from
vllm.utils
import
CudaMemoryProfiler
,
make_tensor_with_pad
from
vllm.worker.model_runner
import
AttentionMetadata
,
SamplingMetadata
from
vllm.worker.model_runner
import
AttentionMetadata
,
SamplingMetadata
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
...
@@ -439,9 +440,11 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
...
@@ -439,9 +440,11 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
"Setting it to the minimum value of 1."
,
expr
)
"Setting it to the minimum value of 1."
,
expr
)
max_num_seqs
=
1
max_num_seqs
=
1
batch_size
=
0
for
group_id
in
range
(
max_num_seqs
):
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
batch_size
+=
seq_len
seq_data
,
dummy_multi_modal_data
=
self
.
input_registry
\
seq_data
,
dummy_multi_modal_data
=
self
.
input_registry
\
.
dummy_data_for_profiling
(
self
.
model_config
,
.
dummy_data_for_profiling
(
self
.
model_config
,
...
@@ -465,7 +468,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
...
@@ -465,7 +468,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
model_input
=
self
.
prepare_model_input
(
model_input
=
self
.
prepare_model_input
(
seqs
,
finished_requests_ids
=
finished_requests_ids
)
seqs
,
finished_requests_ids
=
finished_requests_ids
)
self
.
execute_model
(
model_input
,
kv_caches
)
intermediate_tensors
=
None
if
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
batch_size
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
self
.
execute_model
(
model_input
,
kv_caches
,
intermediate_tensors
)
torch
.
xpu
.
synchronize
()
torch
.
xpu
.
synchronize
()
return
return
...
@@ -537,7 +546,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
...
@@ -537,7 +546,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
and
self
.
observability_config
.
collect_model_forward_time
):
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_start_time
=
time
.
time
()
model_forward_start_time
=
time
.
time
()
hidden_states
=
model_executable
(
hidden_
or_intermediate_
states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
...
@@ -545,12 +554,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
...
@@ -545,12 +554,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalInputs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
**
MultiModalInputs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
device
=
self
.
device
))
device
=
self
.
device
))
# Compute the logits in the last pipeline stage.
if
not
get_pp_group
().
is_last_rank
:
return
hidden_or_intermediate_states
if
(
self
.
observability_config
is
not
None
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_end_time
=
time
.
time
()
model_forward_end_time
=
time
.
time
()
# Compute the logits.
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
logits
=
self
.
model
.
compute_logits
(
hidden_
or_intermediate_
states
,
model_input
.
sampling_metadata
)
model_input
.
sampling_metadata
)
# Only perform sampling in the driver worker.
# Only perform sampling in the driver worker.
...
...
vllm/worker/xpu_worker.py
View file @
0640f227
...
@@ -14,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -14,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SpeculativeConfig
)
SpeculativeConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.utils
import
is_xpu
from
vllm.utils
import
is_xpu
...
@@ -198,3 +199,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
...
@@ -198,3 +199,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
ensure_model_parallel_initialized
(
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
pipeline_parallel_size
)
if
parallel_config
.
pipeline_parallel_size
>
1
:
# torch-ccl xpu need a collective API warm up
# before calling send/recv API
get_pp_group
().
all_reduce
(
torch
.
zeros
(
1
).
xpu
())
Prev
1
…
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment