Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3521ba4f
Unverified
Commit
3521ba4f
authored
May 04, 2024
by
SangBin Cho
Committed by
GitHub
May 03, 2024
Browse files
[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
parent
2d7bce9c
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
181 additions
and
164 deletions
+181
-164
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+12
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+6
-1
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+3
-3
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+36
-27
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+29
-29
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+80
-87
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+15
-15
No files found.
vllm/engine/arg_utils.py
View file @
3521ba4f
...
...
@@ -44,7 +44,8 @@ class EngineArgs:
tokenizer_revision
:
Optional
[
str
]
=
None
quantization
:
Optional
[
str
]
=
None
enforce_eager
:
bool
=
False
max_context_len_to_capture
:
int
=
8192
max_context_len_to_capture
:
Optional
[
int
]
=
None
max_seq_len_to_capture
:
int
=
8192
disable_custom_all_reduce
:
bool
=
False
tokenizer_pool_size
:
int
=
0
tokenizer_pool_type
:
str
=
"ray"
...
...
@@ -322,6 +323,14 @@ class EngineArgs:
default
=
EngineArgs
.
max_context_len_to_capture
,
help
=
'Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead'
')'
)
parser
.
add_argument
(
'--max-seq_len-to-capture'
,
type
=
int
,
default
=
EngineArgs
.
max_seq_len_to_capture
,
help
=
'Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.'
)
parser
.
add_argument
(
'--disable-custom-all-reduce'
,
action
=
'store_true'
,
...
...
@@ -492,7 +501,8 @@ class EngineArgs:
self
.
code_revision
,
self
.
tokenizer_revision
,
self
.
max_model_len
,
self
.
quantization
,
self
.
quantization_param_path
,
self
.
enforce_eager
,
self
.
max_context_len_to_capture
,
self
.
max_logprobs
,
self
.
skip_tokenizer_init
)
self
.
max_seq_len_to_capture
,
self
.
max_logprobs
,
self
.
skip_tokenizer_init
)
cache_config
=
CacheConfig
(
self
.
block_size
,
self
.
gpu_memory_utilization
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
...
...
vllm/entrypoints/llm.py
View file @
3521ba4f
...
...
@@ -69,6 +69,9 @@ class LLM:
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has a sequence length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
...
...
@@ -90,7 +93,8 @@ class LLM:
gpu_memory_utilization
:
float
=
0.9
,
swap_space
:
int
=
4
,
enforce_eager
:
bool
=
False
,
max_context_len_to_capture
:
int
=
8192
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
**
kwargs
,
)
->
None
:
...
...
@@ -112,6 +116,7 @@ class LLM:
swap_space
=
swap_space
,
enforce_eager
=
enforce_eager
,
max_context_len_to_capture
=
max_context_len_to_capture
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
**
kwargs
,
)
...
...
vllm/model_executor/layers/sampler.py
View file @
3521ba4f
...
...
@@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
assert
seq_group
.
is_prompt
,
(
"Caller should ensure the sequence group is in a prefill stage."
)
seq_ids
=
seq_group
.
seq_ids
sub
query_len
=
seq_group
.
sub
query_len
assert
sub
query_len
is
not
None
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
# prompt has only 1 seq id.
assert
len
(
seq_ids
)
==
1
seq_data
=
seq_group
.
seq_data
[
seq_ids
[
0
]]
...
...
@@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
prompt_tokens
=
seq_data
.
prompt_token_ids
# +1 because we are looking for a next prompt token.
next_token_index_start
=
computed_len
+
1
next_token_index_end
=
min
(
computed_len
+
sub
query_len
+
1
,
next_token_index_end
=
min
(
computed_len
+
query_len
+
1
,
len
(
prompt_tokens
))
next_prompt_tokens
=
prompt_tokens
[
next_token_index_start
:
next_token_index_end
]
...
...
vllm/model_executor/sampling_metadata.py
View file @
3521ba4f
...
...
@@ -16,17 +16,26 @@ _SEED_0_REPLACEMENT = 3403598558
@
dataclass
class
SequenceGroupToSample
:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Sequence ids for the sequence group in a previous step.
seq_ids
:
List
[
int
]
sampling_params
:
SamplingParams
# seq_id -> sequence data.
seq_data
:
Dict
[
int
,
SequenceData
]
# The length of the prompt of the sequence group. None if it is in a decode
# The length of the sequence (all tokens seen in the past + new token to
# compute attention) of the sequence group. None if it is in a decode
# stage.
prompt_len
:
Optional
[
int
]
# The length of the query tokens to compute in the current step. None if it
# is in a decode stage. The length of subquery_len <= prompt_len.
subquery_len
:
Optional
[
int
]
seq_len
:
Optional
[
int
]
# The number of new query tokens to compute in the current step. None if it
# is in a decode stage. query_len <= seq_len; it is smaller than seq_len
# when part of the sequence was already computed (e.g., chunked prefill).
query_len
:
Optional
[
int
]
# A random number generator for sampling.
generator
:
Optional
[
torch
.
Generator
]
# True if the sequence group is in prefill stage. False if it is in a
...
...
@@ -46,8 +55,8 @@ class SequenceGroupToSample:
if
len
(
self
.
prompt_logprob_indices
)
>
0
:
assert
self
.
sampling_params
.
prompt_logprobs
is
not
None
if
self
.
is_prompt
:
assert
self
.
prompt
_len
is
not
None
assert
self
.
sub
query_len
is
not
None
assert
self
.
seq
_len
is
not
None
assert
self
.
query_len
is
not
None
class
SamplingMetadata
:
...
...
@@ -94,8 +103,8 @@ class SamplingMetadata:
@
staticmethod
def
prepare
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt
_lens
:
List
[
int
],
sub
query_lens
:
Optional
[
List
[
int
]],
seq
_lens
:
List
[
int
],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
pin_memory
:
bool
,
)
->
"SamplingMetadata"
:
...
...
@@ -104,8 +113,8 @@ class SamplingMetadata:
selected_token_indices
,
categorized_sample_indices
,
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
prompt
_lens
,
subquery_lens
,
device
)
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query
_lens
,
device
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
device
,
...
...
@@ -137,8 +146,8 @@ class SamplingMetadata:
def
_prepare_seq_groups
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt
_lens
:
List
[
int
],
sub
query_lens
:
Optional
[
List
[
int
]],
seq
_lens
:
List
[
int
],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
...
...
@@ -146,9 +155,9 @@ def _prepare_seq_groups(
Args:
seq_group_metadata_list: A list of sequence group to batch.
prompt
_lens: A list of
prompt
lens per sequence group.
seq
_lens: A list of
sequence
lens per sequence group.
The index of each sequence len should match seq_group_metadata_list.
sub
query_lens: A list of query lengths. Prompt lens include the length
query_lens: A list of query lengths. Sequence lens include the length
of the entire sequence, so a query len could be shorter.
device: A device to use for random number generator,
`SequenceGroupToSample.generator`.
...
...
@@ -189,8 +198,8 @@ def _prepare_seq_groups(
is_prompt
=
seq_group_metadata
.
is_prompt
generator
:
Optional
[
torch
.
Generator
]
=
None
# If the current seq group is in decode stage, it is None.
prompt
_len
:
Optional
[
int
]
=
None
sub
query_len
:
Optional
[
int
]
=
None
seq
_len
:
Optional
[
int
]
=
None
query_len
:
Optional
[
int
]
=
None
prompt_logprob_indices
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
do_sample
=
seq_group_metadata
.
do_sample
...
...
@@ -203,12 +212,12 @@ def _prepare_seq_groups(
num_prompts
+=
1
num_prefill_sample
=
len
(
seq_ids
)
assert
num_prefill_sample
==
1
assert
sub
query_lens
is
not
None
and
prompt
_lens
is
not
None
sub
query_len
,
prompt
_len
=
sub
query_lens
[
i
],
prompt
_lens
[
i
]
assert
query_lens
is
not
None
and
seq
_lens
is
not
None
query_len
,
seq
_len
=
query_lens
[
i
],
seq
_lens
[
i
]
# If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob.
prompt_logprob_len
=
(
sub
query_len
-
num_prefill_sample
if
do_sample
else
sub
query_len
)
prompt_logprob_len
=
(
query_len
-
num_prefill_sample
if
do_sample
else
query_len
)
sample_len
=
num_prefill_sample
if
do_sample
else
0
else
:
# Decode
...
...
@@ -267,8 +276,8 @@ def _prepare_seq_groups(
seq_ids
=
seq_ids
,
sampling_params
=
sampling_params
,
seq_data
=
seq_group_metadata
.
seq_data
,
prompt_len
=
prompt
_len
,
sub
query_len
=
sub
query_len
,
seq_len
=
seq
_len
,
query_len
=
query_len
,
generator
=
generator
,
is_prompt
=
is_prompt
,
prompt_logprob_indices
=
list
(
prompt_logprob_indices
),
...
...
@@ -367,8 +376,8 @@ class SamplingTensors:
and
sampling_params
.
prompt_logprobs
is
not
None
):
# For tokens in the prompt that we only need to get
# their logprobs
sub
query_len
=
seq_group
.
sub
query_len
assert
sub
query_len
is
not
None
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
temperatures
+=
[
temperature
]
*
prefill_len
top_ps
+=
[
top_p
]
*
prefill_len
...
...
@@ -397,8 +406,8 @@ class SamplingTensors:
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
sub
query_len
=
seq_group
.
sub
query_len
assert
sub
query_len
is
not
None
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
...
...
vllm/worker/cpu_model_runner.py
View file @
3521ba4f
...
...
@@ -80,7 +80,7 @@ class CPUModelRunner:
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
prompt
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
...
...
@@ -92,15 +92,15 @@ class CPUModelRunner:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
computed_len
=
seq_data
.
get_num_computed_tokens
()
prompt
_len
=
len
(
prompt_tokens
)
seq
_len
=
len
(
prompt_tokens
)
prompt
_lens
.
append
(
prompt
_len
)
# Prompt token num
seq
_lens
.
append
(
seq
_len
)
# Prompt token num
input_tokens
.
extend
(
prompt_tokens
)
# Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
prompt
_len
)))
input_positions
.
extend
(
list
(
range
(
computed_len
,
seq
_len
)))
if
seq_group_metadata
.
multi_modal_data
:
multi_modal_input_list
.
append
(
...
...
@@ -109,15 +109,15 @@ class CPUModelRunner:
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0,
prompt
_len - sliding_window).
# where start_idx is max(0,
seq
_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
start_idx
=
max
(
0
,
prompt
_len
-
self
.
sliding_window
)
start_idx
=
max
(
0
,
seq
_len
-
self
.
sliding_window
)
for
i
in
range
(
computed_len
,
prompt
_len
):
for
i
in
range
(
computed_len
,
seq
_len
):
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
...
...
@@ -151,19 +151,19 @@ class CPUModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
prompt_lens
=
prompt_lens
,
num_prefills
=
len
(
prompt_lens
),
seq_lens
=
seq_lens
,
seq_lens_tensor
=
None
,
max_seq_len
=
None
,
num_prefills
=
len
(
seq_lens
),
num_prefill_tokens
=
num_prompt_tokens
,
num_decode_tokens
=
0
,
prefill_metadata
=
None
,
decode_metadata
=
None
,
max_context_len
=
None
,
context_lens
=
None
,
block_tables
=
torch
.
tensor
([]),
slot_mapping
=
slot_mapping
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt
_lens
,
return
(
input_tokens
,
input_positions
,
attn_metadata
,
seq
_lens
,
multi_modal_input
)
def
_prepare_decode
(
...
...
@@ -174,7 +174,7 @@ class CPUModelRunner:
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
context
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
...
...
@@ -192,9 +192,9 @@ class CPUModelRunner:
position
=
seq_len
-
1
input_positions
.
append
(
position
)
context
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq_len
,
self
.
sliding_window
)
context
_lens
.
append
(
context
_len
)
seq
_lens
.
append
(
seq
_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_number
=
block_table
[
position
//
self
.
block_size
]
...
...
@@ -208,7 +208,7 @@ class CPUModelRunner:
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
max_
context
_len
=
max
(
context
_lens
)
max_
seq
_len
=
max
(
seq
_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
...
...
@@ -219,9 +219,9 @@ class CPUModelRunner:
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
context_lens
=
torch
.
tensor
(
context
_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
block_tables
)
...
...
@@ -236,14 +236,14 @@ class CPUModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
None
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_seq_len
=
max_seq_len
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
len
(
input_tokens
),
max_context_len
=
max_context_len
,
num_prefills
=
0
,
prefill_metadata
=
None
,
decode_metadata
=
None
,
context_lens
=
context_lens
,
block_tables
=
block_tables
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
...
...
@@ -265,20 +265,20 @@ class CPUModelRunner:
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt
_lens
,
(
input_tokens
,
input_positions
,
attn_metadata
,
seq
_lens
,
multi_modal_input
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt
_lens
=
[]
seq
_lens
=
[]
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt
_lens
,
#
sub
query_lens is not needed if chunked prefill is not
seq
_lens
,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use
prompt
_lens instead.
prompt
_lens
,
# just use
seq
_lens instead.
seq
_lens
,
self
.
device
,
pin_memory
=
False
)
# Broadcast the metadata.
...
...
@@ -300,7 +300,7 @@ class CPUModelRunner:
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_data
=
None
,
prompt
_lens
=
None
,
seq
_lens
=
None
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
None
,
generators
=
None
,
...
...
vllm/worker/model_runner.py
View file @
3521ba4f
...
...
@@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens
:
List
[
int
]
input_positions
:
List
[
int
]
attn_metadata
:
Optional
[
AttentionMetadataPerStage
]
prompt
_lens
:
List
[
int
]
sub
query_lens
:
List
[
int
]
seq
_lens
:
List
[
int
]
query_lens
:
List
[
int
]
lora_index_mapping
:
List
[
int
]
lora_prompt_mapping
:
List
[
int
]
lora_requests
:
Set
[
LoRARequest
]
...
...
@@ -56,8 +56,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens
=
[],
input_positions
=
[],
attn_metadata
=
None
,
prompt
_lens
=
[],
sub
query_lens
=
[],
seq
_lens
=
[],
query_lens
=
[],
lora_index_mapping
=
[],
lora_prompt_mapping
=
[],
lora_requests
=
set
(),
...
...
@@ -134,9 +134,8 @@ class ModelRunner:
self
.
graph_memory_pool
:
Optional
[
Tuple
[
int
,
int
]]
=
None
# Set during graph capture.
self
.
max_context_len_to_capture
=
(
self
.
model_config
.
max_context_len_to_capture
if
self
.
model_config
is
not
None
else
0
)
self
.
max_seq_len_to_capture
=
(
self
.
model_config
.
max_seq_len_to_capture
if
self
.
model_config
is
not
None
else
0
)
self
.
pin_memory
=
is_pin_memory_available
()
self
.
kv_cache_dtype
=
kv_cache_dtype
...
...
@@ -149,7 +148,7 @@ class ModelRunner:
self
.
model
:
torch
.
nn
.
Module
# Set after load_model
self
.
block_size
:
int
# Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_
context
_len_to_capture. However, creating the block table in
# max_
seq
_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
...
...
@@ -218,7 +217,7 @@ class ModelRunner:
def
get_max_block_per_batch
(
self
)
->
int
:
block_size
=
self
.
block_size
return
(
self
.
max_
context
_len_to_capture
+
block_size
-
1
)
//
block_size
return
(
self
.
max_
seq
_len_to_capture
+
block_size
-
1
)
//
block_size
def
_prepare_prompt
(
self
,
...
...
@@ -231,9 +230,9 @@ class ModelRunner:
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
prompt
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
sub
query_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
...
...
@@ -257,21 +256,19 @@ class ModelRunner:
token_chunk_size
=
seq_group_metadata
.
token_chunk_size
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
co
mputed
_len
=
seq_data
.
get_num_computed_tokens
()
co
ntext
_len
=
seq_data
.
get_num_computed_tokens
()
# We should use get_len here because in case of preemption
# it contains output tokens.
prefill_end
=
min
(
seq_data
.
get_len
(),
computed_len
+
token_chunk_size
)
prompt_tokens
=
seq_data
.
get_token_ids
()[
computed_len
:
prefill_end
]
prompt_len
=
prefill_end
prompt_lens
.
append
(
prompt_len
)
seq_len
=
min
(
seq_data
.
get_len
(),
context_len
+
token_chunk_size
)
prompt_tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
seq_lens
.
append
(
seq_len
)
# NOTE: This only works for oooooooxxx style attention.
if
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
:
# Prefix is not supported with sliding_window
co
mputed
_len
=
len
(
computed_block_nums
)
*
self
.
block_size
prompt_tokens
=
prompt_tokens
[
co
mputed
_len
:]
co
ntext
_len
=
len
(
computed_block_nums
)
*
self
.
block_size
prompt_tokens
=
prompt_tokens
[
co
ntext
_len
:]
prefix_block_tables
.
append
(
computed_block_nums
)
elif
self
.
scheduler_config
.
chunked_prefill_enabled
:
if
seq_group_metadata
.
block_tables
is
not
None
:
...
...
@@ -285,25 +282,25 @@ class ModelRunner:
prefix_block_tables
.
append
([])
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
assert
co
mputed
_len
==
0
assert
co
ntext
_len
==
0
# actual prompt lens
context_lens
.
append
(
co
mputed
_len
)
sub
query_lens
.
append
(
prompt
_len
-
co
mputed
_len
)
context_lens
.
append
(
co
ntext
_len
)
query_lens
.
append
(
seq
_len
-
co
ntext
_len
)
input_tokens
.
extend
(
prompt_tokens
)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
co
mputed_len
,
prefill_
en
d
)))
input_positions
.
extend
(
list
(
range
(
co
ntext_len
,
seq_l
en
)))
lora_id
=
seq_group_metadata
.
lora_int_id
if
lora_id
>
0
:
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_index_mapping
+=
[
lora_id
]
*
(
prompt
_len
-
co
mputed
_len
)
lora_index_mapping
+=
[
lora_id
]
*
(
seq
_len
-
co
ntext
_len
)
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
(
prompt
_len
-
co
mputed
_len
(
seq
_len
-
co
ntext
_len
if
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
if
seq_group_metadata
.
multi_modal_data
:
...
...
@@ -313,24 +310,24 @@ class ModelRunner:
if
seq_group_metadata
.
block_tables
is
None
:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
prompt
_len
)
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq
_len
)
continue
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0,
prompt
_len - sliding_window).
# where start_idx is max(0,
seq
_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
assert
co
mputed
_len
==
0
,
(
assert
co
ntext
_len
==
0
,
(
"Prefix caching is currently not supported with "
"sliding window attention"
)
start_idx
=
max
(
0
,
prompt
_len
-
self
.
sliding_window
)
start_idx
=
max
(
0
,
seq
_len
-
self
.
sliding_window
)
for
i
in
range
(
co
mputed_len
,
prefill_
en
d
):
for
i
in
range
(
co
ntext_len
,
seq_l
en
):
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
...
...
@@ -340,9 +337,9 @@ class ModelRunner:
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
max_
sub
query_len
=
max
(
sub
query_lens
)
max_
prompt
_len
=
max
(
prompt
_lens
)
assert
max_
sub
query_len
>
0
max_query_len
=
max
(
query_lens
)
max_
seq
_len
=
max
(
seq
_lens
)
assert
max_query_len
>
0
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
...
...
@@ -369,40 +366,39 @@ class ModelRunner:
# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
sub
query_lens_tensor
=
torch
.
tensor
(
sub
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
subquery_start_loc
=
torch
.
zeros
(
sub
query_lens_tensor
.
shape
[
0
]
+
1
,
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
subquery_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
prompt
_lens_tensor
=
torch
.
tensor
(
prompt
_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
seq_start_loc
=
torch
.
zeros
(
prompt
_lens_tensor
.
shape
[
0
]
+
1
,
seq
_lens_tensor
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
seq_start_loc
=
torch
.
zeros
(
seq
_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
sub
query_lens_tensor
,
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
subquery_start_loc
.
dtype
,
out
=
subquery_start_loc
[
1
:])
torch
.
cumsum
(
prompt
_lens_tensor
,
torch
.
cumsum
(
seq
_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
prompt_lens
=
prompt_lens
,
prompt_lens_tensor
=
prompt_lens_tensor
,
max_subquery_len
=
max_subquery_len
,
max_context_len
=
None
,
max_prompt_len
=
max_prompt_len
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
subquery_start_loc
=
subquery_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens
=
context_lens_tensor
,
context_lens
_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
)
...
...
@@ -411,8 +407,8 @@ class ModelRunner:
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
prompt_lens
=
prompt
_lens
,
sub
query_lens
=
sub
query_lens
,
seq_lens
=
seq
_lens
,
query_lens
=
query_lens
,
lora_index_mapping
=
lora_index_mapping
,
lora_prompt_mapping
=
lora_prompt_mapping
,
lora_requests
=
lora_requests
,
...
...
@@ -427,7 +423,7 @@ class ModelRunner:
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
context
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
lora_index_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
...
...
@@ -455,9 +451,9 @@ class ModelRunner:
position
=
seq_len
-
1
input_positions
.
append
(
position
)
context
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq_len
,
self
.
sliding_window
)
context
_lens
.
append
(
context
_len
)
seq
_lens
.
append
(
seq
_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_number
=
block_table
[
position
//
self
.
block_size
]
...
...
@@ -477,11 +473,10 @@ class ModelRunner:
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
batch_size
=
len
(
input_tokens
)
max_context_len
=
max
(
context_lens
)
use_captured_graph
=
(
not
self
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_context_len
<=
self
.
max_context_len_to_capture
)
max_seq_len
=
max
(
seq_lens
)
use_captured_graph
=
(
not
self
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_seq_len
<=
self
.
max_seq_len_to_capture
)
if
use_captured_graph
:
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
...
...
@@ -489,21 +484,21 @@ class ModelRunner:
input_tokens
.
append
(
0
)
input_positions
.
append
(
0
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
context
_lens
.
append
(
1
)
seq
_lens
.
append
(
1
)
block_tables
.
append
([])
lora_index_mapping
.
append
(
0
)
batch_size
=
graph_batch_size
context
_lens_tensor
=
torch
.
tensor
(
context
_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
seq
_lens_tensor
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
if
use_captured_graph
:
# When using cuda-graph all these tensors should be
# padded.
assert
context
_lens_tensor
.
shape
[
0
]
==
len
(
input_tokens
)
assert
context
_lens_tensor
.
shape
[
0
]
==
len
(
input_positions
)
assert
context
_lens_tensor
.
shape
[
0
]
==
len
(
slot_mapping
)
assert
seq
_lens_tensor
.
shape
[
0
]
==
len
(
input_tokens
)
assert
seq
_lens_tensor
.
shape
[
0
]
==
len
(
input_positions
)
assert
seq
_lens_tensor
.
shape
[
0
]
==
len
(
slot_mapping
)
# The shape of graph_block_tables is
# [max batch size, max seq len // block size].
...
...
@@ -525,14 +520,13 @@ class ModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
prompt_lens
=
None
,
prompt_lens_tensor
=
None
,
max_subquery_len
=
None
,
max_context_len
=
max_context_len
,
max_prompt_len
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
None
,
max_seq_len
=
max_seq_len
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens_tensor
,
context_lens_tensor
=
None
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
...
...
@@ -565,8 +559,8 @@ class ModelRunner:
input_tokens
,
input_positions
,
prefill_attn_metadata
,
prompt
_lens
,
sub
query_lens
,
seq
_lens
,
query_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
,
...
...
@@ -583,13 +577,13 @@ class ModelRunner:
decode_slot_mapping
,
)
=
self
.
_prepare_decode
(
decode_reqs
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt
_lens
,
sub
query_lens
,
self
.
device
,
self
.
pin_memory
)
seq_group_metadata_list
,
seq
_lens
,
query_lens
,
self
.
device
,
self
.
pin_memory
)
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
:
assert
(
len
(
prefill_reqs
)
and
len
(
decode_reqs
))
==
0
num_prefills
=
len
(
prompt
_lens
)
num_prefills
=
len
(
seq
_lens
)
num_prefill_tokens
=
len
(
input_tokens
)
num_decode_tokens
=
len
(
decode_input_tokens
)
...
...
@@ -886,7 +880,7 @@ class ModelRunner:
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
=
torch
.
empty
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
.
fill_
(
_PAD_SLOT_ID
)
context
_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
seq
_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
block_tables
=
torch
.
from_numpy
(
self
.
graph_block_tables
).
cuda
()
graph_batch_size
=
_get_graph_batch_size
(
...
...
@@ -908,14 +902,13 @@ class ModelRunner:
# Create dummy attn_metadata.
decode_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
prompt_lens
=
None
,
prompt_lens_tensor
=
None
,
max_subquery_len
=
None
,
max_context_len
=
self
.
max_context_len_to_capture
,
max_prompt_len
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens
[:
batch_size
],
max_query_len
=
None
,
max_seq_len
=
self
.
max_seq_len_to_capture
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens
[:
batch_size
]
,
context_lens
_tensor
=
None
,
block_tables
=
block_tables
[:
batch_size
],
use_cuda_graph
=
True
,
)
...
...
@@ -1025,7 +1018,7 @@ class CUDAGraphRunner:
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"
context_lens
"
:
attn_metadata
.
decode_metadata
.
context_lens
,
"
seq_lens_tensor
"
:
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
...
...
@@ -1047,8 +1040,8 @@ class CUDAGraphRunner:
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
non_blocking
=
True
)
self
.
input_buffers
[
"
context_lens
"
].
copy_
(
attn_metadata
.
decode_metadata
.
context_lens
,
non_blocking
=
True
)
self
.
input_buffers
[
"
seq_lens_tensor
"
].
copy_
(
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
non_blocking
=
True
)
self
.
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
# Run the graph.
...
...
vllm/worker/neuron_model_runner.py
View file @
3521ba4f
...
...
@@ -52,7 +52,7 @@ class NeuronModelRunner:
input_positions
:
List
[
List
[
int
]]
=
[]
input_block_ids
:
List
[
int
]
=
[]
prompt
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
...
...
@@ -61,26 +61,26 @@ class NeuronModelRunner:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
prompt
_len
=
len
(
prompt_tokens
)
prompt
_lens
.
append
(
prompt
_len
)
seq
_len
=
len
(
prompt_tokens
)
seq
_lens
.
append
(
seq
_len
)
input_tokens
.
append
(
prompt_tokens
)
input_positions
.
append
(
list
(
range
(
prompt
_len
)))
input_positions
.
append
(
list
(
range
(
seq
_len
)))
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
assert
len
(
block_table
)
==
1
input_block_ids
.
append
(
block_table
[
0
])
max_
prompt
_len
=
max
(
prompt
_lens
)
assert
max_
prompt
_len
>
0
max_
seq
_len
=
max
(
seq
_lens
)
assert
max_
seq
_len
>
0
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_
prompt
_len
,
max_
seq
_len
,
pad
=
0
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
max_
prompt
_len
,
max_
seq
_len
,
pad
=
0
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
...
...
@@ -88,7 +88,7 @@ class NeuronModelRunner:
dtype
=
torch
.
long
,
device
=
self
.
device
)
return
input_tokens
,
input_positions
,
input_block_ids
,
prompt
_lens
return
input_tokens
,
input_positions
,
input_block_ids
,
seq
_lens
def
_prepare_decode
(
self
,
...
...
@@ -149,18 +149,18 @@ class NeuronModelRunner:
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
input_block_ids
,
prompt
_lens
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
seq
_lens
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
input_block_ids
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt
_lens
=
[]
seq
_lens
=
[]
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt
_lens
,
#
sub
query_lens is not needed if chunked prefill is not
seq
_lens
,
# query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use
prompt
_lens instead.
prompt
_lens
,
# just use
seq
_lens instead.
seq
_lens
,
self
.
device
,
self
.
pin_memory
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment