Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
249 additions
and
395 deletions
+249
-395
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+13
-12
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+188
-228
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+23
-118
vllm/worker/worker.py
vllm/worker/worker.py
+18
-18
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+7
-19
No files found.
vllm/worker/cpu_worker.py
View file @
1591c68f
...
...
@@ -13,7 +13,7 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
...
...
@@ -256,22 +256,24 @@ class CPUWorker(LoraNotSupportedWorkerBase):
@
torch
.
inference_mode
()
def
execute_model
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
,
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]]
=
None
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
)
->
List
[
SamplerOutput
]:
if
execute_model_req
is
None
:
seq_group_metadata_list
=
None
else
:
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
num_seq_groups
:
int
=
len
(
seq_group_metadata_list
)
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
assert
len
(
blocks_to_swap_in
)
==
0
assert
len
(
blocks_to_swap_out
)
==
0
assert
execute_model_req
is
not
None
blocks_to_copy
=
execute_model_req
.
blocks_to_copy
assert
len
(
execute_model_req
.
blocks_to_swap_in
)
==
0
assert
len
(
execute_model_req
.
blocks_to_swap_out
)
==
0
data
:
Dict
[
str
,
Any
]
=
{
"num_seq_groups"
:
num_seq_groups
,
"blocks_to_copy"
:
blocks_to_copy
,
"blocks_to_copy"
:
execute_model_req
.
blocks_to_copy
,
}
broadcast_tensor_dict
(
data
,
src
=
0
)
else
:
...
...
@@ -279,7 +281,6 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups
=
data
[
"num_seq_groups"
]
blocks_to_copy
=
data
[
"blocks_to_copy"
]
assert
blocks_to_copy
is
not
None
self
.
cache_copy
(
blocks_to_copy
)
# If there is no input, we don't need to execute the model.
...
...
vllm/worker/model_runner.py
View file @
1591c68f
...
...
@@ -9,6 +9,7 @@ import torch.nn as nn
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataPerStage
,
get_attn_backend
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
,
with_pynccl_for_all_reduce
...
...
@@ -20,12 +21,11 @@ from vllm.lora.request import LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
async_tensor_h2d
,
is_hip
,
is_pin_memory_available
,
make_tensor_with_pad
,
maybe_expand_dim
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
,
make_tensor_with_pad
)
logger
=
init_logger
(
__name__
)
...
...
@@ -43,8 +43,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens
:
List
[
int
]
input_positions
:
List
[
int
]
attn_metadata
:
Optional
[
AttentionMetadataPerStage
]
prompt
_lens
:
List
[
int
]
sub
query_lens
:
List
[
int
]
seq
_lens
:
List
[
int
]
query_lens
:
List
[
int
]
lora_index_mapping
:
List
[
int
]
lora_prompt_mapping
:
List
[
int
]
lora_requests
:
Set
[
LoRARequest
]
...
...
@@ -57,8 +57,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens
=
[],
input_positions
=
[],
attn_metadata
=
None
,
prompt
_lens
=
[],
sub
query_lens
=
[],
seq
_lens
=
[],
query_lens
=
[],
lora_index_mapping
=
[],
lora_prompt_mapping
=
[],
lora_requests
=
set
(),
...
...
@@ -135,9 +135,8 @@ class ModelRunner:
self
.
graph_memory_pool
:
Optional
[
Tuple
[
int
,
int
]]
=
None
# Set during graph capture.
self
.
max_context_len_to_capture
=
(
self
.
model_config
.
max_context_len_to_capture
if
self
.
model_config
is
not
None
else
0
)
self
.
max_seq_len_to_capture
=
(
self
.
model_config
.
max_seq_len_to_capture
if
self
.
model_config
is
not
None
else
0
)
self
.
pin_memory
=
is_pin_memory_available
()
self
.
kv_cache_dtype
=
kv_cache_dtype
...
...
@@ -150,13 +149,16 @@ class ModelRunner:
self
.
model
:
torch
.
nn
.
Module
# Set after load_model
self
.
block_size
:
int
# Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_
context
_len_to_capture. However, creating the block table in
# max_
seq
_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self
.
graph_block_tables
:
torch
.
Tensor
# Set after initial profiling.
# Set if the backend is flashinfer.
self
.
flashinfer_workspace_buffer
:
torch
.
Tensor
def
load_model
(
self
)
->
None
:
with
CudaMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
...
...
@@ -170,8 +172,8 @@ class ModelRunner:
)
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
f
"Loading model weights took
"
f
"
{
self
.
model_memory_usage
/
float
(
2
**
30
)
:.
4
f
}
GB"
)
logger
.
info
(
"Loading model weights took
%.4f GB"
,
self
.
model_memory_usage
/
float
(
2
**
30
))
if
self
.
lora_config
:
assert
hasattr
(
self
.
model
,
"supported_lora_modules"
...
...
@@ -196,18 +198,19 @@ class ModelRunner:
self
.
model
.
load_kv_cache_scales
(
self
.
model_config
.
quantization_param_path
)
else
:
raise
RuntimeError
(
"Using FP8 KV cache and scaling "
"
factors provided but
model
"
f
"
{
self
.
model
.
__class__
}
does not "
"support loading scaling factors."
)
raise
RuntimeError
(
"Using FP8 KV cache and scaling
factors provided but "
"model %s does not support loading scaling factors."
,
self
.
model
.
__class__
)
else
:
logger
.
warn
(
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!"
)
logger
.
warning
(
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!"
)
elif
self
.
model_config
.
quantization_param_path
is
not
None
:
logger
.
warn
(
"KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used."
)
logger
.
warn
ing
(
"KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used."
)
def
set_block_size
(
self
,
block_size
:
int
)
->
None
:
self
.
block_size
=
block_size
...
...
@@ -218,7 +221,7 @@ class ModelRunner:
def
get_max_block_per_batch
(
self
)
->
int
:
block_size
=
self
.
block_size
return
(
self
.
max_
context
_len_to_capture
+
block_size
-
1
)
//
block_size
return
(
self
.
max_
seq
_len_to_capture
+
block_size
-
1
)
//
block_size
def
_prepare_prompt
(
self
,
...
...
@@ -231,9 +234,9 @@ class ModelRunner:
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
prompt
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
sub
query_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
...
...
@@ -257,21 +260,19 @@ class ModelRunner:
token_chunk_size
=
seq_group_metadata
.
token_chunk_size
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
co
mputed
_len
=
seq_data
.
get_num_computed_tokens
()
co
ntext
_len
=
seq_data
.
get_num_computed_tokens
()
# We should use get_len here because in case of preemption
# it contains output tokens.
prefill_end
=
min
(
seq_data
.
get_len
(),
computed_len
+
token_chunk_size
)
prompt_tokens
=
seq_data
.
get_token_ids
()[
computed_len
:
prefill_end
]
prompt_len
=
prefill_end
prompt_lens
.
append
(
prompt_len
)
seq_len
=
min
(
seq_data
.
get_len
(),
context_len
+
token_chunk_size
)
prompt_tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
seq_lens
.
append
(
seq_len
)
# NOTE: This only works for oooooooxxx style attention.
if
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
:
# Prefix is not supported with sliding_window
co
mputed
_len
=
len
(
computed_block_nums
)
*
self
.
block_size
prompt_tokens
=
prompt_tokens
[
co
mputed
_len
:]
co
ntext
_len
=
len
(
computed_block_nums
)
*
self
.
block_size
prompt_tokens
=
prompt_tokens
[
co
ntext
_len
:]
prefix_block_tables
.
append
(
computed_block_nums
)
elif
self
.
scheduler_config
.
chunked_prefill_enabled
:
if
seq_group_metadata
.
block_tables
is
not
None
:
...
...
@@ -285,25 +286,25 @@ class ModelRunner:
prefix_block_tables
.
append
([])
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
assert
co
mputed
_len
==
0
assert
co
ntext
_len
==
0
# actual prompt lens
context_lens
.
append
(
co
mputed
_len
)
sub
query_lens
.
append
(
prompt
_len
-
co
mputed
_len
)
context_lens
.
append
(
co
ntext
_len
)
query_lens
.
append
(
seq
_len
-
co
ntext
_len
)
input_tokens
.
extend
(
prompt_tokens
)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
co
mputed_len
,
prefill_
en
d
)))
input_positions
.
extend
(
list
(
range
(
co
ntext_len
,
seq_l
en
)))
lora_id
=
seq_group_metadata
.
lora_int_id
if
lora_id
>
0
:
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_index_mapping
+=
[
lora_id
]
*
(
prompt
_len
-
co
mputed
_len
)
lora_index_mapping
+=
[
lora_id
]
*
(
seq
_len
-
co
ntext
_len
)
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
(
prompt
_len
-
co
mputed
_len
(
seq
_len
-
co
ntext
_len
if
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
if
seq_group_metadata
.
multi_modal_data
:
...
...
@@ -313,24 +314,25 @@ class ModelRunner:
if
seq_group_metadata
.
block_tables
is
None
:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
prompt
_len
)
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq
_len
)
continue
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0,
prompt
_len - sliding_window).
# where start_idx is max(0,
seq
_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
assert
co
mputed
_len
==
0
,
(
assert
co
ntext
_len
==
0
,
(
"Prefix caching is currently not supported with "
"sliding window attention"
)
start_idx
=
max
(
0
,
prompt
_len
-
self
.
sliding_window
)
start_idx
=
max
(
0
,
seq
_len
-
self
.
sliding_window
)
for
i
in
range
(
co
mputed_len
,
prefill_
en
d
):
for
i
in
range
(
co
ntext_len
,
seq_l
en
):
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
...
...
@@ -340,9 +342,9 @@ class ModelRunner:
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
max_
sub
query_len
=
max
(
sub
query_lens
)
max_
prompt
_len
=
max
(
prompt
_lens
)
assert
max_
sub
query_len
>
0
max_query_len
=
max
(
query_lens
)
max_
seq
_len
=
max
(
seq
_lens
)
assert
max_query_len
>
0
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
...
...
@@ -369,50 +371,57 @@ class ModelRunner:
# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
sub
query_lens_tensor
=
torch
.
tensor
(
sub
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
subquery_start_loc
=
torch
.
zeros
(
sub
query_lens_tensor
.
shape
[
0
]
+
1
,
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
subquery_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
prompt
_lens_tensor
=
torch
.
tensor
(
prompt
_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
seq_start_loc
=
torch
.
zeros
(
prompt
_lens_tensor
.
shape
[
0
]
+
1
,
seq
_lens_tensor
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
seq_start_loc
=
torch
.
zeros
(
seq
_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
sub
query_lens_tensor
,
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
subquery_start_loc
.
dtype
,
out
=
subquery_start_loc
[
1
:])
torch
.
cumsum
(
prompt
_lens_tensor
,
torch
.
cumsum
(
seq
_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
prompt_lens
=
prompt_lens
,
prompt_lens_tensor
=
prompt_lens_tensor
,
max_subquery_len
=
max_subquery_len
,
max_context_len
=
None
,
max_prompt_len
=
max_prompt_len
,
subquery_start_loc
=
subquery_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
)
if
self
.
attn_backend
is
FlashInferBackend
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
use_cuda_graph
=
False
,
seq_start_loc
=
seq_start_loc
,
max_seq_len
=
max_seq_len
,
block_tables
=
block_tables
)
else
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
subquery_start_loc
=
subquery_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
)
return
PreparePromptMetadata
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
prompt_lens
=
prompt
_lens
,
sub
query_lens
=
sub
query_lens
,
seq_lens
=
seq
_lens
,
query_lens
=
query_lens
,
lora_index_mapping
=
lora_index_mapping
,
lora_prompt_mapping
=
lora_prompt_mapping
,
lora_requests
=
lora_requests
,
...
...
@@ -427,12 +436,30 @@ class ModelRunner:
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
context
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
lora_index_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices
:
List
[
int
]
=
[]
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
paged_kv_indptr
:
List
[
int
]
=
[
0
]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len
:
List
[
int
]
=
[]
if
len
(
seq_group_metadata_list
)
==
0
:
return
PrepareDecodeMetadata
.
empty
()
...
...
@@ -455,9 +482,9 @@ class ModelRunner:
position
=
seq_len
-
1
input_positions
.
append
(
position
)
context
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq_len
,
self
.
sliding_window
)
context
_lens
.
append
(
context
_len
)
seq
_lens
.
append
(
seq
_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_number
=
block_table
[
position
//
self
.
block_size
]
...
...
@@ -473,15 +500,21 @@ class ModelRunner:
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
paged_kv_indices
.
extend
(
block_table
)
paged_kv_indptr
.
append
(
paged_kv_indptr
[
-
1
]
+
len
(
block_table
))
last_page_len
=
seq_data
.
get_len
()
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
paged_kv_last_page_len
.
append
(
last_page_len
)
# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
batch_size
=
len
(
input_tokens
)
max_context_len
=
max
(
context_lens
)
use_captured_graph
=
(
not
self
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_context_len
<=
self
.
max_context_len_to_capture
)
max_seq_len
=
max
(
seq_lens
)
use_captured_graph
=
(
not
self
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_seq_len
<=
self
.
max_seq_len_to_capture
)
if
use_captured_graph
:
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
...
...
@@ -489,21 +522,21 @@ class ModelRunner:
input_tokens
.
append
(
0
)
input_positions
.
append
(
0
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
context
_lens
.
append
(
1
)
seq
_lens
.
append
(
1
)
block_tables
.
append
([])
lora_index_mapping
.
append
(
0
)
batch_size
=
graph_batch_size
context
_lens_tensor
=
torch
.
tensor
(
context
_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
seq
_lens_tensor
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
if
use_captured_graph
:
# When using cuda-graph all these tensors should be
# padded.
assert
context
_lens_tensor
.
shape
[
0
]
==
len
(
input_tokens
)
assert
context
_lens_tensor
.
shape
[
0
]
==
len
(
input_positions
)
assert
context
_lens_tensor
.
shape
[
0
]
==
len
(
slot_mapping
)
assert
seq
_lens_tensor
.
shape
[
0
]
==
len
(
input_tokens
)
assert
seq
_lens_tensor
.
shape
[
0
]
==
len
(
input_positions
)
assert
seq
_lens_tensor
.
shape
[
0
]
==
len
(
slot_mapping
)
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
...
...
@@ -523,19 +556,51 @@ class ModelRunner:
device
=
self
.
device
,
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
prompt_lens
=
None
,
prompt_lens_tensor
=
None
,
max_subquery_len
=
None
,
max_context_len
=
max_context_len
,
max_prompt_len
=
None
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
if
self
.
attn_backend
is
FlashInferBackend
:
if
not
hasattr
(
self
,
"flashinfer_workspace_buffer"
):
# Allocate 16MB workspace buffer
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
self
.
flashinfer_workspace_buffer
=
torch
.
empty
(
16
*
1024
*
1024
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
paged_kv_indptr
=
torch
.
tensor
(
paged_kv_indptr
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
paged_kv_indices
=
torch
.
tensor
(
paged_kv_indices
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
paged_kv_last_page_len
=
torch
.
tensor
(
paged_kv_last_page_len
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
kv_cache_dtype
,
self
.
model_config
.
dtype
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
use_cuda_graph
=
False
,
workspace_buffer
=
self
.
flashinfer_workspace_buffer
,
paged_kv_indptr
=
paged_kv_indptr
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len
=
paged_kv_last_page_len
,
num_qo_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
head_dim
=
self
.
model_config
.
get_head_size
(),
page_size
=
self
.
block_size
,
data_type
=
kv_cache_dtype
)
else
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
None
,
max_seq_len
=
max_seq_len
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
return
PrepareDecodeMetadata
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
...
...
@@ -546,108 +611,6 @@ class ModelRunner:
slot_mapping
=
slot_mapping
,
)
def
_prepare_sample
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt_lens
:
List
[
int
],
subquery_lens
:
Optional
[
List
[
int
]],
)
->
SamplingMetadata
:
seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
selected_token_start_idx
=
0
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
sampling_params
=
seq_group_metadata
.
sampling_params
seq_groups
.
append
((
seq_ids
,
sampling_params
))
if
seq_group_metadata
.
is_prompt
:
assert
len
(
seq_ids
)
==
1
assert
subquery_lens
is
not
None
subquery_len
=
subquery_lens
[
i
]
if
sampling_params
.
prompt_logprobs
is
not
None
:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx
+=
subquery_len
-
1
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
(
(
categorized_sample_indices_start_idx
,
categorized_sampled_token_indices_start_idx
))
categorized_sample_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
if
sampling_params
.
prompt_logprobs
is
not
None
:
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
subquery_len
-
1
))
selected_token_indices
.
append
(
selected_token_start_idx
+
subquery_len
-
1
)
selected_token_start_idx
+=
subquery_len
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
device
=
self
.
device
).
manual_seed
(
sampling_params
.
seed
)
else
:
num_seqs
=
len
(
seq_ids
)
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
num_seqs
))
selected_token_start_idx
+=
num_seqs
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
list
(
zip
(
range
(
categorized_sample_indices_start_idx
,
categorized_sample_indices_start_idx
+
num_seqs
),
range
(
categorized_sampled_token_indices_start_idx
,
categorized_sampled_token_indices_start_idx
+
num_seqs
))))
categorized_sample_indices_start_idx
+=
num_seqs
categorized_sampled_token_indices_start_idx
+=
num_seqs
if
sampling_params
.
seed
is
not
None
:
generators
.
append
(
seq_group_metadata
.
state
.
generator
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
)
categorized_sample_indices
=
{
t
:
maybe_expand_dim
(
async_tensor_h2d
(
seq_ids
,
dtype
=
torch
.
int
,
target_device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
),
2
,
2
)
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
}
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
.
update
(
seq_group_metadata
.
seq_data
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups
,
seq_data
=
seq_data
,
prompt_lens
=
prompt_lens
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
categorized_sample_indices
,
generators
=
generators
,
)
return
sampling_metadata
def
prepare_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
...
@@ -667,8 +630,8 @@ class ModelRunner:
input_tokens
,
input_positions
,
prefill_attn_metadata
,
prompt
_lens
,
sub
query_lens
,
seq
_lens
,
query_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
,
...
...
@@ -684,14 +647,14 @@ class ModelRunner:
decode_lora_requests
,
decode_slot_mapping
,
)
=
self
.
_prepare_decode
(
decode_reqs
)
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
s
ubquery_lens
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
s
elf
.
pin_memory
)
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
:
assert
(
len
(
prefill_reqs
)
and
len
(
decode_reqs
))
==
0
num_prefills
=
len
(
prompt
_lens
)
num_prefills
=
len
(
seq
_lens
)
num_prefill_tokens
=
len
(
input_tokens
)
num_decode_tokens
=
len
(
decode_input_tokens
)
...
...
@@ -787,12 +750,9 @@ class ModelRunner:
**
metadata_dict
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_data
=
None
,
prompt_lens
=
None
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
None
,
generators
=
None
,
perform_sampling
=
False
,
num_prompts
=
0
,
)
# if it is a mixed batch, decode attn_metadata is broadcasted
...
...
@@ -851,7 +811,7 @@ class ModelRunner:
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
# Only perform sampling in the driver worker.
if
not
s
ampling_metadata
.
perform_sampling
:
if
not
s
elf
.
is_driver_worker
:
return
None
# Sample the next token.
...
...
@@ -859,6 +819,7 @@ class ModelRunner:
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
return
output
@
torch
.
inference_mode
()
...
...
@@ -928,10 +889,10 @@ class ModelRunner:
torch
.
cuda
.
synchronize
()
return
def
remove_all_loras
(
self
)
->
bool
:
def
remove_all_loras
(
self
):
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
remove_all_loras
()
self
.
lora_manager
.
remove_all_loras
()
def
set_active_loras
(
self
,
lora_requests
:
Set
[
LoRARequest
],
lora_mapping
:
LoRAMapping
)
->
None
:
...
...
@@ -990,7 +951,7 @@ class ModelRunner:
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
=
torch
.
empty
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
.
fill_
(
_PAD_SLOT_ID
)
context
_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
seq
_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
block_tables
=
torch
.
from_numpy
(
self
.
graph_block_tables
).
cuda
()
graph_batch_size
=
_get_graph_batch_size
(
...
...
@@ -1012,14 +973,13 @@ class ModelRunner:
# Create dummy attn_metadata.
decode_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
prompt_lens
=
None
,
prompt_lens_tensor
=
None
,
max_subquery_len
=
None
,
max_context_len
=
self
.
max_context_len_to_capture
,
max_prompt_len
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens
[:
batch_size
],
max_query_len
=
None
,
max_seq_len
=
self
.
max_seq_len_to_capture
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens
[:
batch_size
]
,
context_lens
_tensor
=
None
,
block_tables
=
block_tables
[:
batch_size
],
use_cuda_graph
=
True
,
)
...
...
@@ -1054,7 +1014,7 @@ class ModelRunner:
end_time
=
time
.
perf_counter
()
elapsed_time
=
end_time
-
start_time
# This usually takes < 10 seconds.
logger
.
info
(
f
"Graph capturing finished in
{
elapsed_time
:.
0
f
}
secs."
)
logger
.
info
(
"Graph capturing finished in
%.0f secs."
,
elapsed_time
)
def
__del__
(
self
)
->
None
:
# Delete the CUDA graphs before deleting the pynccl communicator.
...
...
@@ -1129,7 +1089,7 @@ class CUDAGraphRunner:
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"
context_lens
"
:
attn_metadata
.
decode_metadata
.
context_lens
,
"
seq_lens_tensor
"
:
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
...
...
@@ -1151,8 +1111,8 @@ class CUDAGraphRunner:
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
non_blocking
=
True
)
self
.
input_buffers
[
"
context_lens
"
].
copy_
(
attn_metadata
.
decode_metadata
.
context_lens
,
non_blocking
=
True
)
self
.
input_buffers
[
"
seq_lens_tensor
"
].
copy_
(
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
non_blocking
=
True
)
self
.
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
# Run the graph.
...
...
vllm/worker/neuron_model_runner.py
View file @
1591c68f
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -8,10 +8,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader.neuron
import
get_neuron_model
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
maybe_expand_dim
)
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
logger
=
init_logger
(
__name__
)
...
...
@@ -54,7 +52,7 @@ class NeuronModelRunner:
input_positions
:
List
[
List
[
int
]]
=
[]
input_block_ids
:
List
[
int
]
=
[]
prompt
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
...
...
@@ -63,26 +61,26 @@ class NeuronModelRunner:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
prompt
_len
=
len
(
prompt_tokens
)
prompt
_lens
.
append
(
prompt
_len
)
seq
_len
=
len
(
prompt_tokens
)
seq
_lens
.
append
(
seq
_len
)
input_tokens
.
append
(
prompt_tokens
)
input_positions
.
append
(
list
(
range
(
prompt
_len
)))
input_positions
.
append
(
list
(
range
(
seq
_len
)))
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
assert
len
(
block_table
)
==
1
input_block_ids
.
append
(
block_table
[
0
])
max_
prompt
_len
=
max
(
prompt
_lens
)
assert
max_
prompt
_len
>
0
max_
seq
_len
=
max
(
seq
_lens
)
assert
max_
seq
_len
>
0
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_
prompt
_len
,
max_
seq
_len
,
pad
=
0
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
max_
prompt
_len
,
max_
seq
_len
,
pad
=
0
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
...
...
@@ -90,7 +88,7 @@ class NeuronModelRunner:
dtype
=
torch
.
long
,
device
=
self
.
device
)
return
input_tokens
,
input_positions
,
input_block_ids
,
prompt
_lens
return
input_tokens
,
input_positions
,
input_block_ids
,
seq
_lens
def
_prepare_decode
(
self
,
...
...
@@ -141,106 +139,6 @@ class NeuronModelRunner:
return
input_tokens
,
input_positions
,
input_block_ids
def
_prepare_sample
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt_lens
:
List
[
int
],
)
->
SamplingMetadata
:
seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
selected_token_start_idx
=
0
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
sampling_params
=
seq_group_metadata
.
sampling_params
seq_groups
.
append
((
seq_ids
,
sampling_params
))
if
seq_group_metadata
.
is_prompt
:
assert
len
(
seq_ids
)
==
1
assert
prompt_lens
is
not
None
prompt_len
=
prompt_lens
[
i
]
if
sampling_params
.
prompt_logprobs
is
not
None
:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx
+=
prompt_len
-
1
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
(
(
categorized_sample_indices_start_idx
,
categorized_sampled_token_indices_start_idx
))
categorized_sample_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
if
sampling_params
.
prompt_logprobs
is
not
None
:
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
prompt_len
-
1
))
selected_token_indices
.
append
(
selected_token_start_idx
+
prompt_len
-
1
)
selected_token_start_idx
+=
prompt_len
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
device
=
self
.
device
).
manual_seed
(
sampling_params
.
seed
)
else
:
num_seqs
=
len
(
seq_ids
)
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
num_seqs
))
selected_token_start_idx
+=
num_seqs
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
zip
(
range
(
categorized_sample_indices_start_idx
,
categorized_sample_indices_start_idx
+
num_seqs
),
range
(
categorized_sampled_token_indices_start_idx
,
categorized_sampled_token_indices_start_idx
+
num_seqs
)))
categorized_sample_indices_start_idx
+=
num_seqs
categorized_sampled_token_indices_start_idx
+=
num_seqs
if
sampling_params
.
seed
is
not
None
:
generators
.
append
(
seq_group_metadata
.
state
.
generator
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
)
categorized_sample_indices
=
{
t
:
maybe_expand_dim
(
async_tensor_h2d
(
seq_ids
,
dtype
=
torch
.
int
,
target_device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
),
2
,
2
)
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
}
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
.
update
(
seq_group_metadata
.
seq_data
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups
,
seq_data
=
seq_data
,
prompt_lens
=
prompt_lens
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
categorized_sample_indices
,
generators
=
generators
,
)
return
sampling_metadata
def
prepare_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
...
@@ -251,13 +149,20 @@ class NeuronModelRunner:
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
input_block_ids
,
prompt
_lens
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
seq
_lens
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
input_block_ids
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt_lens
=
[]
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
)
seq_lens
=
[]
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
# query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens
,
self
.
device
,
self
.
pin_memory
)
return
(
input_tokens
,
input_positions
,
input_block_ids
,
sampling_metadata
)
...
...
vllm/worker/worker.py
View file @
1591c68f
...
...
@@ -11,13 +11,14 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
VisionLanguageConfig
)
from
vllm.distributed
import
(
broadcast_tensor_dict
,
ensure_model_parallel_initialized
,
get_tensor_model_parallel_cpu_group
,
init_distributed_environment
)
from
vllm.distributed.device_communicators
import
pynccl_utils
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
init_custom_ar
)
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.worker_base
import
WorkerBase
...
...
@@ -210,19 +211,21 @@ class Worker(WorkerBase):
@
torch
.
inference_mode
()
def
execute_model
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
,
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]]
=
None
,
num_lookahead_slots
:
int
=
0
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
if
execute_model_req
is
None
:
seq_group_metadata_list
=
None
else
:
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
assert
execute_model_req
is
not
None
num_seq_groups
=
len
(
seq_group_metadata_list
)
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
blocks_to_swap_in
=
execute_model_req
.
blocks_to_swap_in
blocks_to_swap_out
=
execute_model_req
.
blocks_to_swap_out
blocks_to_copy
=
execute_model_req
.
blocks_to_copy
data
:
Dict
[
str
,
Any
]
=
{
"num_seq_groups"
:
num_seq_groups
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
...
...
@@ -237,9 +240,6 @@ class Worker(WorkerBase):
blocks_to_swap_out
=
data
[
"blocks_to_swap_out"
]
blocks_to_copy
=
data
[
"blocks_to_copy"
]
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
self
.
cache_swap
(
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
# If there is no input, we don't need to execute the model.
...
...
@@ -288,6 +288,9 @@ def init_worker_distributed_environment(
init_distributed_environment
(
parallel_config
.
world_size
,
rank
,
distributed_init_method
,
local_rank
)
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
if
pynccl_utils
.
is_initialized
():
pynccl_world_size
=
pynccl_utils
.
get_world_size
()
if
pynccl_world_size
!=
parallel_config
.
world_size
:
...
...
@@ -298,12 +301,9 @@ def init_worker_distributed_environment(
elif
parallel_config
.
world_size
>
1
:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
# NOTE(kaichao): By default, pynccl will use information inside
# `parallel_state` for initialization.
pynccl_utils
.
init_process_group
()
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
# NOTE(kaichao): By default, pynccl is initialized for tp group.
pynccl_utils
.
init_process_group
(
group
=
get_tensor_model_parallel_cpu_group
())
# Initialize a custom fast all-reduce implementation.
if
not
parallel_config
.
disable_custom_all_reduce
:
...
...
vllm/worker/worker_base.py
View file @
1591c68f
import
datetime
import
importlib
import
os
import
tempfile
import
threading
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
get_vllm_instance_id
,
update_environment_variables
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
update_environment_variables
)
logger
=
init_logger
(
__name__
)
...
...
@@ -50,10 +48,8 @@ class WorkerBase(ABC):
@
abstractmethod
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
List
[
SamplerOutput
]:
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
raise
NotImplementedError
...
...
@@ -128,15 +124,7 @@ class WorkerWrapperBase:
function tracing if required.
Arguments are passed to the worker class constructor.
"""
if
int
(
os
.
getenv
(
"VLLM_TRACE_FUNCTION"
,
"0"
)):
tmp_dir
=
tempfile
.
gettempdir
()
filename
=
(
f
"VLLM_TRACE_FUNCTION_for_process_
{
os
.
getpid
()
}
"
f
"_thread_
{
threading
.
get_ident
()
}
_"
f
"at_
{
datetime
.
datetime
.
now
()
}
.log"
).
replace
(
" "
,
"_"
)
log_path
=
os
.
path
.
join
(
tmp_dir
,
"vllm"
,
get_vllm_instance_id
(),
filename
)
os
.
makedirs
(
os
.
path
.
dirname
(
log_path
),
exist_ok
=
True
)
enable_trace_function_call
(
log_path
)
enable_trace_function_call_for_thread
()
mod
=
importlib
.
import_module
(
self
.
worker_module_name
)
worker_class
=
getattr
(
mod
,
self
.
worker_class_name
)
...
...
Prev
1
…
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment