Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3521ba4f
Unverified
Commit
3521ba4f
authored
May 04, 2024
by
SangBin Cho
Committed by
GitHub
May 03, 2024
Browse files
[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
parent
2d7bce9c
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
181 additions
and
164 deletions
+181
-164
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+12
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+6
-1
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+3
-3
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+36
-27
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+29
-29
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+80
-87
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+15
-15
No files found.
vllm/engine/arg_utils.py
View file @
3521ba4f
...
@@ -44,7 +44,8 @@ class EngineArgs:
...
@@ -44,7 +44,8 @@ class EngineArgs:
tokenizer_revision
:
Optional
[
str
]
=
None
tokenizer_revision
:
Optional
[
str
]
=
None
quantization
:
Optional
[
str
]
=
None
quantization
:
Optional
[
str
]
=
None
enforce_eager
:
bool
=
False
enforce_eager
:
bool
=
False
max_context_len_to_capture
:
int
=
8192
max_context_len_to_capture
:
Optional
[
int
]
=
None
max_seq_len_to_capture
:
int
=
8192
disable_custom_all_reduce
:
bool
=
False
disable_custom_all_reduce
:
bool
=
False
tokenizer_pool_size
:
int
=
0
tokenizer_pool_size
:
int
=
0
tokenizer_pool_type
:
str
=
"ray"
tokenizer_pool_type
:
str
=
"ray"
...
@@ -322,6 +323,14 @@ class EngineArgs:
...
@@ -322,6 +323,14 @@ class EngineArgs:
default
=
EngineArgs
.
max_context_len_to_capture
,
default
=
EngineArgs
.
max_context_len_to_capture
,
help
=
'Maximum context length covered by CUDA '
help
=
'Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead'
')'
)
parser
.
add_argument
(
'--max-seq_len-to-capture'
,
type
=
int
,
default
=
EngineArgs
.
max_seq_len_to_capture
,
help
=
'Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.'
)
'larger than this, we fall back to eager mode.'
)
parser
.
add_argument
(
'--disable-custom-all-reduce'
,
parser
.
add_argument
(
'--disable-custom-all-reduce'
,
action
=
'store_true'
,
action
=
'store_true'
,
...
@@ -492,7 +501,8 @@ class EngineArgs:
...
@@ -492,7 +501,8 @@ class EngineArgs:
self
.
code_revision
,
self
.
tokenizer_revision
,
self
.
max_model_len
,
self
.
code_revision
,
self
.
tokenizer_revision
,
self
.
max_model_len
,
self
.
quantization
,
self
.
quantization_param_path
,
self
.
quantization
,
self
.
quantization_param_path
,
self
.
enforce_eager
,
self
.
max_context_len_to_capture
,
self
.
enforce_eager
,
self
.
max_context_len_to_capture
,
self
.
max_logprobs
,
self
.
skip_tokenizer_init
)
self
.
max_seq_len_to_capture
,
self
.
max_logprobs
,
self
.
skip_tokenizer_init
)
cache_config
=
CacheConfig
(
self
.
block_size
,
cache_config
=
CacheConfig
(
self
.
block_size
,
self
.
gpu_memory_utilization
,
self
.
gpu_memory_utilization
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
...
...
vllm/entrypoints/llm.py
View file @
3521ba4f
...
@@ -69,6 +69,9 @@ class LLM:
...
@@ -69,6 +69,9 @@ class LLM:
disable CUDA graph and always execute the model in eager mode.
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode.
disable_custom_all_reduce: See ParallelConfig
disable_custom_all_reduce: See ParallelConfig
...
@@ -90,7 +93,8 @@ class LLM:
...
@@ -90,7 +93,8 @@ class LLM:
gpu_memory_utilization
:
float
=
0.9
,
gpu_memory_utilization
:
float
=
0.9
,
swap_space
:
int
=
4
,
swap_space
:
int
=
4
,
enforce_eager
:
bool
=
False
,
enforce_eager
:
bool
=
False
,
max_context_len_to_capture
:
int
=
8192
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
disable_custom_all_reduce
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
...
@@ -112,6 +116,7 @@ class LLM:
...
@@ -112,6 +116,7 @@ class LLM:
swap_space
=
swap_space
,
swap_space
=
swap_space
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
enforce_eager
,
max_context_len_to_capture
=
max_context_len_to_capture
,
max_context_len_to_capture
=
max_context_len_to_capture
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
**
kwargs
,
**
kwargs
,
)
)
...
...
vllm/model_executor/layers/sampler.py
View file @
3521ba4f
...
@@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
...
@@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
assert
seq_group
.
is_prompt
,
(
assert
seq_group
.
is_prompt
,
(
"Caller should ensure the sequence group is in a prefill stage."
)
"Caller should ensure the sequence group is in a prefill stage."
)
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
sub
query_len
=
seq_group
.
sub
query_len
query_len
=
seq_group
.
query_len
assert
sub
query_len
is
not
None
assert
query_len
is
not
None
# prompt has only 1 seq id.
# prompt has only 1 seq id.
assert
len
(
seq_ids
)
==
1
assert
len
(
seq_ids
)
==
1
seq_data
=
seq_group
.
seq_data
[
seq_ids
[
0
]]
seq_data
=
seq_group
.
seq_data
[
seq_ids
[
0
]]
...
@@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
...
@@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
prompt_tokens
=
seq_data
.
prompt_token_ids
prompt_tokens
=
seq_data
.
prompt_token_ids
# +1 because we are looking for a next prompt token.
# +1 because we are looking for a next prompt token.
next_token_index_start
=
computed_len
+
1
next_token_index_start
=
computed_len
+
1
next_token_index_end
=
min
(
computed_len
+
sub
query_len
+
1
,
next_token_index_end
=
min
(
computed_len
+
query_len
+
1
,
len
(
prompt_tokens
))
len
(
prompt_tokens
))
next_prompt_tokens
=
prompt_tokens
[
next_prompt_tokens
=
prompt_tokens
[
next_token_index_start
:
next_token_index_end
]
next_token_index_start
:
next_token_index_end
]
...
...
vllm/model_executor/sampling_metadata.py
View file @
3521ba4f
...
@@ -16,17 +16,26 @@ _SEED_0_REPLACEMENT = 3403598558
...
@@ -16,17 +16,26 @@ _SEED_0_REPLACEMENT = 3403598558
@
dataclass
@
dataclass
class
SequenceGroupToSample
:
class
SequenceGroupToSample
:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Sequence ids for the sequence group in a previous step.
# Sequence ids for the sequence group in a previous step.
seq_ids
:
List
[
int
]
seq_ids
:
List
[
int
]
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
# seq_id -> sequence data.
# seq_id -> sequence data.
seq_data
:
Dict
[
int
,
SequenceData
]
seq_data
:
Dict
[
int
,
SequenceData
]
# The length of the prompt of the sequence group. None if it is in a decode
# The length of the sequence (all tokens seen in the past + new token to
# compute attention) of the sequence group. None if it is in a decode
# stage.
# stage.
prompt_len
:
Optional
[
int
]
seq_len
:
Optional
[
int
]
# The length of the query tokens to compute in the current step. None if it
# The length of new query tokens to compute in the current step. None if it
# is in a decode stage. The length of subquery_len <= prompt_len.
# is in a decode stage. The length of query_len <= seq_len if chunked
subquery_len
:
Optional
[
int
]
# prefill is enabled.
query_len
:
Optional
[
int
]
# A random number generator for sampling.
# A random number generator for sampling.
generator
:
Optional
[
torch
.
Generator
]
generator
:
Optional
[
torch
.
Generator
]
# True if the sequence group is in prefill stage. False if it is in a
# True if the sequence group is in prefill stage. False if it is in a
...
@@ -46,8 +55,8 @@ class SequenceGroupToSample:
...
@@ -46,8 +55,8 @@ class SequenceGroupToSample:
if
len
(
self
.
prompt_logprob_indices
)
>
0
:
if
len
(
self
.
prompt_logprob_indices
)
>
0
:
assert
self
.
sampling_params
.
prompt_logprobs
is
not
None
assert
self
.
sampling_params
.
prompt_logprobs
is
not
None
if
self
.
is_prompt
:
if
self
.
is_prompt
:
assert
self
.
prompt
_len
is
not
None
assert
self
.
seq
_len
is
not
None
assert
self
.
sub
query_len
is
not
None
assert
self
.
query_len
is
not
None
class
SamplingMetadata
:
class
SamplingMetadata
:
...
@@ -94,8 +103,8 @@ class SamplingMetadata:
...
@@ -94,8 +103,8 @@ class SamplingMetadata:
@
staticmethod
@
staticmethod
def
prepare
(
def
prepare
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
sub
query_lens
:
Optional
[
List
[
int
]],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
device
:
str
,
pin_memory
:
bool
,
pin_memory
:
bool
,
)
->
"SamplingMetadata"
:
)
->
"SamplingMetadata"
:
...
@@ -104,8 +113,8 @@ class SamplingMetadata:
...
@@ -104,8 +113,8 @@ class SamplingMetadata:
selected_token_indices
,
selected_token_indices
,
categorized_sample_indices
,
categorized_sample_indices
,
num_prompts
,
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
prompt
_lens
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query
_lens
,
subquery_lens
,
device
)
device
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
target_device
=
device
,
target_device
=
device
,
...
@@ -137,8 +146,8 @@ class SamplingMetadata:
...
@@ -137,8 +146,8 @@ class SamplingMetadata:
def
_prepare_seq_groups
(
def
_prepare_seq_groups
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
sub
query_lens
:
Optional
[
List
[
int
]],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
device
:
str
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
...
@@ -146,9 +155,9 @@ def _prepare_seq_groups(
...
@@ -146,9 +155,9 @@ def _prepare_seq_groups(
Args:
Args:
seq_group_metadata_list: A list of sequence group to batch.
seq_group_metadata_list: A list of sequence group to batch.
prompt_lens: A list of prompt lens per sequence group.
seq_lens: A list of sequence lens per sequence group.
Index of prompt len should match with seq_group_metadata_list.
Index of prompt len should match with seq_group_metadata_list.
subquery_lens: A list of query lengths. Prompt lens include the length
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
device: A device to use for random number generator,
`SequenceGroupToSample.generator`.
`SequenceGroupToSample.generator`.
...
@@ -189,8 +198,8 @@ def _prepare_seq_groups(
...
@@ -189,8 +198,8 @@ def _prepare_seq_groups(
is_prompt
=
seq_group_metadata
.
is_prompt
is_prompt
=
seq_group_metadata
.
is_prompt
generator
:
Optional
[
torch
.
Generator
]
=
None
generator
:
Optional
[
torch
.
Generator
]
=
None
# If the current seq group is in decode stage, it is None.
# If the current seq group is in decode stage, it is None.
prompt
_len
:
Optional
[
int
]
=
None
seq
_len
:
Optional
[
int
]
=
None
sub
query_len
:
Optional
[
int
]
=
None
query_len
:
Optional
[
int
]
=
None
prompt_logprob_indices
:
List
[
int
]
=
[]
prompt_logprob_indices
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
do_sample
=
seq_group_metadata
.
do_sample
do_sample
=
seq_group_metadata
.
do_sample
...
@@ -203,12 +212,12 @@ def _prepare_seq_groups(
...
@@ -203,12 +212,12 @@ def _prepare_seq_groups(
num_prompts
+=
1
num_prompts
+=
1
num_prefill_sample
=
len
(
seq_ids
)
num_prefill_sample
=
len
(
seq_ids
)
assert
num_prefill_sample
==
1
assert
num_prefill_sample
==
1
assert
sub
query_lens
is
not
None
and
prompt
_lens
is
not
None
assert
query_lens
is
not
None
and
seq
_lens
is
not
None
sub
query_len
,
prompt
_len
=
sub
query_lens
[
i
],
prompt
_lens
[
i
]
query_len
,
seq
_len
=
query_lens
[
i
],
seq
_lens
[
i
]
# If we need sampling, exclude num_prefill_sample tokens from
# If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob.
# prompt logprob.
prompt_logprob_len
=
(
sub
query_len
-
num_prefill_sample
prompt_logprob_len
=
(
query_len
-
num_prefill_sample
if
do_sample
else
sub
query_len
)
if
do_sample
else
query_len
)
sample_len
=
num_prefill_sample
if
do_sample
else
0
sample_len
=
num_prefill_sample
if
do_sample
else
0
else
:
else
:
# Decode
# Decode
...
@@ -267,8 +276,8 @@ def _prepare_seq_groups(
...
@@ -267,8 +276,8 @@ def _prepare_seq_groups(
seq_ids
=
seq_ids
,
seq_ids
=
seq_ids
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
seq_data
=
seq_group_metadata
.
seq_data
,
seq_data
=
seq_group_metadata
.
seq_data
,
prompt_len
=
prompt
_len
,
seq_len
=
seq
_len
,
sub
query_len
=
sub
query_len
,
query_len
=
query_len
,
generator
=
generator
,
generator
=
generator
,
is_prompt
=
is_prompt
,
is_prompt
=
is_prompt
,
prompt_logprob_indices
=
list
(
prompt_logprob_indices
),
prompt_logprob_indices
=
list
(
prompt_logprob_indices
),
...
@@ -367,8 +376,8 @@ class SamplingTensors:
...
@@ -367,8 +376,8 @@ class SamplingTensors:
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
# For tokens in the prompt that we only need to get
# For tokens in the prompt that we only need to get
# their logprobs
# their logprobs
sub
query_len
=
seq_group
.
sub
query_len
query_len
=
seq_group
.
query_len
assert
sub
query_len
is
not
None
assert
query_len
is
not
None
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
temperatures
+=
[
temperature
]
*
prefill_len
temperatures
+=
[
temperature
]
*
prefill_len
top_ps
+=
[
top_p
]
*
prefill_len
top_ps
+=
[
top_p
]
*
prefill_len
...
@@ -397,8 +406,8 @@ class SamplingTensors:
...
@@ -397,8 +406,8 @@ class SamplingTensors:
if
is_prompt
:
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
prompt_best_of
.
append
(
sampling_params
.
best_of
)
sub
query_len
=
seq_group
.
sub
query_len
query_len
=
seq_group
.
query_len
assert
sub
query_len
is
not
None
assert
query_len
is
not
None
for
seq_id
in
seq_ids
:
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
seq_data
=
seq_group
.
seq_data
[
seq_id
]
...
...
vllm/worker/cpu_model_runner.py
View file @
3521ba4f
...
@@ -80,7 +80,7 @@ class CPUModelRunner:
...
@@ -80,7 +80,7 @@ class CPUModelRunner:
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
prompt
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
...
@@ -92,15 +92,15 @@ class CPUModelRunner:
...
@@ -92,15 +92,15 @@ class CPUModelRunner:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
prompt_tokens
=
seq_data
.
get_token_ids
()
computed_len
=
seq_data
.
get_num_computed_tokens
()
computed_len
=
seq_data
.
get_num_computed_tokens
()
prompt
_len
=
len
(
prompt_tokens
)
seq
_len
=
len
(
prompt_tokens
)
prompt
_lens
.
append
(
prompt
_len
)
# Prompt token num
seq
_lens
.
append
(
seq
_len
)
# Prompt token num
input_tokens
.
extend
(
prompt_tokens
)
# Token ids
input_tokens
.
extend
(
prompt_tokens
)
# Token ids
# Token position ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
prompt
_len
)))
input_positions
.
extend
(
list
(
range
(
computed_len
,
seq
_len
)))
if
seq_group_metadata
.
multi_modal_data
:
if
seq_group_metadata
.
multi_modal_data
:
multi_modal_input_list
.
append
(
multi_modal_input_list
.
append
(
...
@@ -109,15 +109,15 @@ class CPUModelRunner:
...
@@ -109,15 +109,15 @@ class CPUModelRunner:
# Compute the slot mapping.
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
if
self
.
sliding_window
is
not
None
:
start_idx
=
max
(
0
,
prompt
_len
-
self
.
sliding_window
)
start_idx
=
max
(
0
,
seq
_len
-
self
.
sliding_window
)
for
i
in
range
(
computed_len
,
prompt
_len
):
for
i
in
range
(
computed_len
,
seq
_len
):
if
i
<
start_idx
:
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
continue
...
@@ -151,19 +151,19 @@ class CPUModelRunner:
...
@@ -151,19 +151,19 @@ class CPUModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
is_prompt
=
True
,
prompt_lens
=
prompt_lens
,
seq_lens
=
seq_lens
,
num_prefills
=
len
(
prompt_lens
),
seq_lens_tensor
=
None
,
max_seq_len
=
None
,
num_prefills
=
len
(
seq_lens
),
num_prefill_tokens
=
num_prompt_tokens
,
num_prefill_tokens
=
num_prompt_tokens
,
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
prefill_metadata
=
None
,
prefill_metadata
=
None
,
decode_metadata
=
None
,
decode_metadata
=
None
,
max_context_len
=
None
,
context_lens
=
None
,
block_tables
=
torch
.
tensor
([]),
block_tables
=
torch
.
tensor
([]),
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt
_lens
,
return
(
input_tokens
,
input_positions
,
attn_metadata
,
seq
_lens
,
multi_modal_input
)
multi_modal_input
)
def
_prepare_decode
(
def
_prepare_decode
(
...
@@ -174,7 +174,7 @@ class CPUModelRunner:
...
@@ -174,7 +174,7 @@ class CPUModelRunner:
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
context
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
...
@@ -192,9 +192,9 @@ class CPUModelRunner:
...
@@ -192,9 +192,9 @@ class CPUModelRunner:
position
=
seq_len
-
1
position
=
seq_len
-
1
input_positions
.
append
(
position
)
input_positions
.
append
(
position
)
context
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq_len
,
self
.
sliding_window
)
seq_len
,
self
.
sliding_window
)
context
_lens
.
append
(
context
_len
)
seq
_lens
.
append
(
seq
_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_number
=
block_table
[
position
//
self
.
block_size
]
block_number
=
block_table
[
position
//
self
.
block_size
]
...
@@ -208,7 +208,7 @@ class CPUModelRunner:
...
@@ -208,7 +208,7 @@ class CPUModelRunner:
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
block_tables
.
append
(
block_table
)
max_
context
_len
=
max
(
context
_lens
)
max_
seq
_len
=
max
(
seq
_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
...
@@ -219,7 +219,7 @@ class CPUModelRunner:
...
@@ -219,7 +219,7 @@ class CPUModelRunner:
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
context_lens
=
torch
.
tensor
(
context
_lens
,
seq_lens_tensor
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
device
=
self
.
device
)
...
@@ -236,14 +236,14 @@ class CPUModelRunner:
...
@@ -236,14 +236,14 @@ class CPUModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
is_prompt
=
False
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
None
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_seq_len
=
max_seq_len
,
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
len
(
input_tokens
),
num_decode_tokens
=
len
(
input_tokens
),
max_context_len
=
max_context_len
,
num_prefills
=
0
,
num_prefills
=
0
,
prefill_metadata
=
None
,
prefill_metadata
=
None
,
decode_metadata
=
None
,
decode_metadata
=
None
,
context_lens
=
context_lens
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
...
@@ -265,20 +265,20 @@ class CPUModelRunner:
...
@@ -265,20 +265,20 @@ class CPUModelRunner:
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
# Prepare input tensors.
if
is_prompt
:
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt
_lens
,
(
input_tokens
,
input_positions
,
attn_metadata
,
seq
_lens
,
multi_modal_input
multi_modal_input
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
else
:
(
input_tokens
,
input_positions
,
(
input_tokens
,
input_positions
,
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt
_lens
=
[]
seq
_lens
=
[]
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_group_metadata_list
,
prompt
_lens
,
seq
_lens
,
# subquery_lens is not needed if chunked prefill is not
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# supported. Since CPU worker doesn't support chunked prefill
# just use prompt_lens instead.
# just use seq_lens instead.
prompt
_lens
,
seq
_lens
,
self
.
device
,
self
.
device
,
pin_memory
=
False
)
pin_memory
=
False
)
# Broadcast the metadata.
# Broadcast the metadata.
...
@@ -300,7 +300,7 @@ class CPUModelRunner:
...
@@ -300,7 +300,7 @@ class CPUModelRunner:
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_groups
=
None
,
seq_data
=
None
,
seq_data
=
None
,
prompt
_lens
=
None
,
seq
_lens
=
None
,
selected_token_indices
=
selected_token_indices
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
None
,
categorized_sample_indices
=
None
,
generators
=
None
,
generators
=
None
,
...
...
vllm/worker/model_runner.py
View file @
3521ba4f
...
@@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple):
...
@@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens
:
List
[
int
]
input_tokens
:
List
[
int
]
input_positions
:
List
[
int
]
input_positions
:
List
[
int
]
attn_metadata
:
Optional
[
AttentionMetadataPerStage
]
attn_metadata
:
Optional
[
AttentionMetadataPerStage
]
prompt
_lens
:
List
[
int
]
seq
_lens
:
List
[
int
]
sub
query_lens
:
List
[
int
]
query_lens
:
List
[
int
]
lora_index_mapping
:
List
[
int
]
lora_index_mapping
:
List
[
int
]
lora_prompt_mapping
:
List
[
int
]
lora_prompt_mapping
:
List
[
int
]
lora_requests
:
Set
[
LoRARequest
]
lora_requests
:
Set
[
LoRARequest
]
...
@@ -56,8 +56,8 @@ class PreparePromptMetadata(NamedTuple):
...
@@ -56,8 +56,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens
=
[],
input_tokens
=
[],
input_positions
=
[],
input_positions
=
[],
attn_metadata
=
None
,
attn_metadata
=
None
,
prompt
_lens
=
[],
seq
_lens
=
[],
sub
query_lens
=
[],
query_lens
=
[],
lora_index_mapping
=
[],
lora_index_mapping
=
[],
lora_prompt_mapping
=
[],
lora_prompt_mapping
=
[],
lora_requests
=
set
(),
lora_requests
=
set
(),
...
@@ -134,8 +134,7 @@ class ModelRunner:
...
@@ -134,8 +134,7 @@ class ModelRunner:
self
.
graph_memory_pool
:
Optional
[
Tuple
[
self
.
graph_memory_pool
:
Optional
[
Tuple
[
int
,
int
]]
=
None
# Set during graph capture.
int
,
int
]]
=
None
# Set during graph capture.
self
.
max_context_len_to_capture
=
(
self
.
max_seq_len_to_capture
=
(
self
.
model_config
.
max_seq_len_to_capture
self
.
model_config
.
max_context_len_to_capture
if
self
.
model_config
is
not
None
else
0
)
if
self
.
model_config
is
not
None
else
0
)
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
...
@@ -149,7 +148,7 @@ class ModelRunner:
...
@@ -149,7 +148,7 @@ class ModelRunner:
self
.
model
:
torch
.
nn
.
Module
# Set after load_model
self
.
model
:
torch
.
nn
.
Module
# Set after load_model
self
.
block_size
:
int
# Set after initial profiling.
self
.
block_size
:
int
# Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# The shape of the cached block table will be
...
@@ -218,7 +217,7 @@ class ModelRunner:
...
@@ -218,7 +217,7 @@ class ModelRunner:
def
get_max_block_per_batch
(
self
)
->
int
:
def
get_max_block_per_batch
(
self
)
->
int
:
block_size
=
self
.
block_size
block_size
=
self
.
block_size
return
(
self
.
max_
context
_len_to_capture
+
block_size
-
1
)
//
block_size
return
(
self
.
max_
seq
_len_to_capture
+
block_size
-
1
)
//
block_size
def
_prepare_prompt
(
def
_prepare_prompt
(
self
,
self
,
...
@@ -231,9 +230,9 @@ class ModelRunner:
...
@@ -231,9 +230,9 @@ class ModelRunner:
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
lora_requests
:
Set
[
LoRARequest
]
=
set
()
prompt
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
sub
query_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
...
@@ -257,21 +256,19 @@ class ModelRunner:
...
@@ -257,21 +256,19 @@ class ModelRunner:
token_chunk_size
=
seq_group_metadata
.
token_chunk_size
token_chunk_size
=
seq_group_metadata
.
token_chunk_size
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
co
mputed
_len
=
seq_data
.
get_num_computed_tokens
()
co
ntext
_len
=
seq_data
.
get_num_computed_tokens
()
# We should use get_len here because in case of preemption
# We should use get_len here because in case of preemption
# it contains output tokens.
# it contains output tokens.
prefill_end
=
min
(
seq_data
.
get_len
(),
seq_len
=
min
(
seq_data
.
get_len
(),
context_len
+
token_chunk_size
)
computed_len
+
token_chunk_size
)
prompt_tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
prompt_tokens
=
seq_data
.
get_token_ids
()[
computed_len
:
prefill_end
]
seq_lens
.
append
(
seq_len
)
prompt_len
=
prefill_end
prompt_lens
.
append
(
prompt_len
)
# NOTE: This only works for oooooooxxx style attention.
# NOTE: This only works for oooooooxxx style attention.
if
computed_block_nums
is
not
None
and
len
(
if
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
:
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
:
# Prefix is not supported with sliding_window
# Prefix is not supported with sliding_window
co
mputed
_len
=
len
(
computed_block_nums
)
*
self
.
block_size
co
ntext
_len
=
len
(
computed_block_nums
)
*
self
.
block_size
prompt_tokens
=
prompt_tokens
[
co
mputed
_len
:]
prompt_tokens
=
prompt_tokens
[
co
ntext
_len
:]
prefix_block_tables
.
append
(
computed_block_nums
)
prefix_block_tables
.
append
(
computed_block_nums
)
elif
self
.
scheduler_config
.
chunked_prefill_enabled
:
elif
self
.
scheduler_config
.
chunked_prefill_enabled
:
if
seq_group_metadata
.
block_tables
is
not
None
:
if
seq_group_metadata
.
block_tables
is
not
None
:
...
@@ -285,25 +282,25 @@ class ModelRunner:
...
@@ -285,25 +282,25 @@ class ModelRunner:
prefix_block_tables
.
append
([])
prefix_block_tables
.
append
([])
# Right now, prefill start is always 0. However, this
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
# assumption can be changed once chunked prefill is introduced.
assert
co
mputed
_len
==
0
assert
co
ntext
_len
==
0
# actual prompt lens
# actual prompt lens
context_lens
.
append
(
co
mputed
_len
)
context_lens
.
append
(
co
ntext
_len
)
sub
query_lens
.
append
(
prompt
_len
-
co
mputed
_len
)
query_lens
.
append
(
seq
_len
-
co
ntext
_len
)
input_tokens
.
extend
(
prompt_tokens
)
input_tokens
.
extend
(
prompt_tokens
)
# NOTE(woosuk): Here we assume that the first token in the prompt
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
co
mputed_len
,
prefill_
en
d
)))
input_positions
.
extend
(
list
(
range
(
co
ntext_len
,
seq_l
en
)))
lora_id
=
seq_group_metadata
.
lora_int_id
lora_id
=
seq_group_metadata
.
lora_int_id
if
lora_id
>
0
:
if
lora_id
>
0
:
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_index_mapping
+=
[
lora_id
]
*
(
prompt
_len
-
co
mputed
_len
)
lora_index_mapping
+=
[
lora_id
]
*
(
seq
_len
-
co
ntext
_len
)
lora_prompt_mapping
.
extend
(
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
[
lora_id
]
*
(
prompt
_len
-
co
mputed
_len
(
seq
_len
-
co
ntext
_len
if
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
if
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
if
seq_group_metadata
.
multi_modal_data
:
if
seq_group_metadata
.
multi_modal_data
:
...
@@ -313,24 +310,24 @@ class ModelRunner:
...
@@ -313,24 +310,24 @@ class ModelRunner:
if
seq_group_metadata
.
block_tables
is
None
:
if
seq_group_metadata
.
block_tables
is
None
:
# During memory profiling, the block tables are not initialized
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
# yet. In this case, we just use a dummy slot mapping.
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
prompt
_len
)
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq
_len
)
continue
continue
# Compute the slot mapping.
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0,
prompt
_len - sliding_window).
# where start_idx is max(0,
seq
_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
if
self
.
sliding_window
is
not
None
:
assert
co
mputed
_len
==
0
,
(
assert
co
ntext
_len
==
0
,
(
"Prefix caching is currently not supported with "
"Prefix caching is currently not supported with "
"sliding window attention"
)
"sliding window attention"
)
start_idx
=
max
(
0
,
prompt
_len
-
self
.
sliding_window
)
start_idx
=
max
(
0
,
seq
_len
-
self
.
sliding_window
)
for
i
in
range
(
co
mputed_len
,
prefill_
en
d
):
for
i
in
range
(
co
ntext_len
,
seq_l
en
):
if
i
<
start_idx
:
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
continue
...
@@ -340,9 +337,9 @@ class ModelRunner:
...
@@ -340,9 +337,9 @@ class ModelRunner:
slot
=
block_number
*
self
.
block_size
+
block_offset
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
slot_mapping
.
append
(
slot
)
max_
sub
query_len
=
max
(
sub
query_lens
)
max_query_len
=
max
(
query_lens
)
max_
prompt
_len
=
max
(
prompt
_lens
)
max_
seq
_len
=
max
(
seq
_lens
)
assert
max_
sub
query_len
>
0
assert
max_query_len
>
0
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
...
@@ -369,40 +366,39 @@ class ModelRunner:
...
@@ -369,40 +366,39 @@ class ModelRunner:
# Query length can be shorter than key (i.e., prompt) when prefill
# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
# is chunked or prefix cached.
sub
query_lens_tensor
=
torch
.
tensor
(
sub
query_lens
,
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
subquery_start_loc
=
torch
.
zeros
(
sub
query_lens_tensor
.
shape
[
0
]
+
1
,
subquery_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
self
.
device
)
prompt
_lens_tensor
=
torch
.
tensor
(
prompt
_lens
,
seq
_lens_tensor
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
device
=
self
.
device
)
seq_start_loc
=
torch
.
zeros
(
prompt
_lens_tensor
.
shape
[
0
]
+
1
,
seq_start_loc
=
torch
.
zeros
(
seq
_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
self
.
device
)
torch
.
cumsum
(
sub
query_lens_tensor
,
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dim
=
0
,
dtype
=
subquery_start_loc
.
dtype
,
dtype
=
subquery_start_loc
.
dtype
,
out
=
subquery_start_loc
[
1
:])
out
=
subquery_start_loc
[
1
:])
torch
.
cumsum
(
prompt
_lens_tensor
,
torch
.
cumsum
(
seq
_lens_tensor
,
dim
=
0
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
out
=
seq_start_loc
[
1
:])
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
is_prompt
=
True
,
prompt_lens
=
prompt_lens
,
seq_lens
=
seq_lens
,
prompt_lens_tensor
=
prompt_lens_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
max_subquery_len
=
max_subquery_len
,
max_query_len
=
max_query_len
,
max_context_len
=
None
,
max_seq_len
=
max_seq_len
,
max_prompt_len
=
max_prompt_len
,
subquery_start_loc
=
subquery_start_loc
,
subquery_start_loc
=
subquery_start_loc
,
seq_start_loc
=
seq_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens
=
context_lens_tensor
,
context_lens
_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
)
)
...
@@ -411,8 +407,8 @@ class ModelRunner:
...
@@ -411,8 +407,8 @@ class ModelRunner:
input_tokens
=
input_tokens
,
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
prompt_lens
=
prompt
_lens
,
seq_lens
=
seq
_lens
,
sub
query_lens
=
sub
query_lens
,
query_lens
=
query_lens
,
lora_index_mapping
=
lora_index_mapping
,
lora_index_mapping
=
lora_index_mapping
,
lora_prompt_mapping
=
lora_prompt_mapping
,
lora_prompt_mapping
=
lora_prompt_mapping
,
lora_requests
=
lora_requests
,
lora_requests
=
lora_requests
,
...
@@ -427,7 +423,7 @@ class ModelRunner:
...
@@ -427,7 +423,7 @@ class ModelRunner:
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
context
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
lora_index_mapping
:
List
[
int
]
=
[]
lora_index_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
...
@@ -455,9 +451,9 @@ class ModelRunner:
...
@@ -455,9 +451,9 @@ class ModelRunner:
position
=
seq_len
-
1
position
=
seq_len
-
1
input_positions
.
append
(
position
)
input_positions
.
append
(
position
)
context
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq_len
,
self
.
sliding_window
)
seq_len
,
self
.
sliding_window
)
context
_lens
.
append
(
context
_len
)
seq
_lens
.
append
(
seq
_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_number
=
block_table
[
position
//
self
.
block_size
]
block_number
=
block_table
[
position
//
self
.
block_size
]
...
@@ -477,11 +473,10 @@ class ModelRunner:
...
@@ -477,11 +473,10 @@ class ModelRunner:
# See `capture_model` API for more details.
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
# For decoding requests, batch_size == input_tokens.
batch_size
=
len
(
input_tokens
)
batch_size
=
len
(
input_tokens
)
max_context_len
=
max
(
context_lens
)
max_seq_len
=
max
(
seq_lens
)
use_captured_graph
=
(
use_captured_graph
=
(
not
self
.
model_config
.
enforce_eager
not
self
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_
context
_len
<=
self
.
max_
context
_len_to_capture
)
and
max_
seq
_len
<=
self
.
max_
seq
_len_to_capture
)
if
use_captured_graph
:
if
use_captured_graph
:
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
assert
graph_batch_size
>=
batch_size
...
@@ -489,21 +484,21 @@ class ModelRunner:
...
@@ -489,21 +484,21 @@ class ModelRunner:
input_tokens
.
append
(
0
)
input_tokens
.
append
(
0
)
input_positions
.
append
(
0
)
input_positions
.
append
(
0
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
context
_lens
.
append
(
1
)
seq
_lens
.
append
(
1
)
block_tables
.
append
([])
block_tables
.
append
([])
lora_index_mapping
.
append
(
0
)
lora_index_mapping
.
append
(
0
)
batch_size
=
graph_batch_size
batch_size
=
graph_batch_size
context
_lens_tensor
=
torch
.
tensor
(
context
_lens
,
seq
_lens_tensor
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
device
=
self
.
device
)
if
use_captured_graph
:
if
use_captured_graph
:
# When using cuda-graph all these tensors should be
# When using cuda-graph all these tensors should be
# padded.
# padded.
assert
context
_lens_tensor
.
shape
[
0
]
==
len
(
input_tokens
)
assert
seq
_lens_tensor
.
shape
[
0
]
==
len
(
input_tokens
)
assert
context
_lens_tensor
.
shape
[
0
]
==
len
(
input_positions
)
assert
seq
_lens_tensor
.
shape
[
0
]
==
len
(
input_positions
)
assert
context
_lens_tensor
.
shape
[
0
]
==
len
(
slot_mapping
)
assert
seq
_lens_tensor
.
shape
[
0
]
==
len
(
slot_mapping
)
# The shape of graph_block_tables is
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# [max batch size, max context len // block size].
...
@@ -525,14 +520,13 @@ class ModelRunner:
...
@@ -525,14 +520,13 @@ class ModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
is_prompt
=
False
,
prompt_lens
=
None
,
seq_lens
=
None
,
prompt_lens_tensor
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
max_subquery_len
=
None
,
max_query_len
=
None
,
max_context_len
=
max_context_len
,
max_seq_len
=
max_seq_len
,
max_prompt_len
=
None
,
subquery_start_loc
=
None
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens_tensor
,
context_lens_tensor
=
None
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
)
)
...
@@ -565,8 +559,8 @@ class ModelRunner:
...
@@ -565,8 +559,8 @@ class ModelRunner:
input_tokens
,
input_tokens
,
input_positions
,
input_positions
,
prefill_attn_metadata
,
prefill_attn_metadata
,
prompt
_lens
,
seq
_lens
,
sub
query_lens
,
query_lens
,
lora_index_mapping
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_prompt_mapping
,
lora_requests
,
lora_requests
,
...
@@ -583,13 +577,13 @@ class ModelRunner:
...
@@ -583,13 +577,13 @@ class ModelRunner:
decode_slot_mapping
,
decode_slot_mapping
,
)
=
self
.
_prepare_decode
(
decode_reqs
)
)
=
self
.
_prepare_decode
(
decode_reqs
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt
_lens
,
sub
query_lens
,
seq_group_metadata_list
,
seq
_lens
,
query_lens
,
self
.
device
,
self
.
device
,
self
.
pin_memory
)
self
.
pin_memory
)
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
:
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
:
assert
(
len
(
prefill_reqs
)
and
len
(
decode_reqs
))
==
0
assert
(
len
(
prefill_reqs
)
and
len
(
decode_reqs
))
==
0
num_prefills
=
len
(
prompt
_lens
)
num_prefills
=
len
(
seq
_lens
)
num_prefill_tokens
=
len
(
input_tokens
)
num_prefill_tokens
=
len
(
input_tokens
)
num_decode_tokens
=
len
(
decode_input_tokens
)
num_decode_tokens
=
len
(
decode_input_tokens
)
...
@@ -886,7 +880,7 @@ class ModelRunner:
...
@@ -886,7 +880,7 @@ class ModelRunner:
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
=
torch
.
empty
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
=
torch
.
empty
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
.
fill_
(
_PAD_SLOT_ID
)
slot_mapping
.
fill_
(
_PAD_SLOT_ID
)
context
_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
seq
_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
block_tables
=
torch
.
from_numpy
(
self
.
graph_block_tables
).
cuda
()
block_tables
=
torch
.
from_numpy
(
self
.
graph_block_tables
).
cuda
()
graph_batch_size
=
_get_graph_batch_size
(
graph_batch_size
=
_get_graph_batch_size
(
...
@@ -908,14 +902,13 @@ class ModelRunner:
...
@@ -908,14 +902,13 @@ class ModelRunner:
# Create dummy attn_metadata.
# Create dummy attn_metadata.
decode_metadata
=
self
.
attn_backend
.
make_metadata
(
decode_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
is_prompt
=
False
,
prompt_lens
=
None
,
seq_lens
=
None
,
prompt_lens_tensor
=
None
,
seq_lens_tensor
=
seq_lens
[:
batch_size
],
max_subquery_len
=
None
,
max_query_len
=
None
,
max_context_len
=
self
.
max_context_len_to_capture
,
max_seq_len
=
self
.
max_seq_len_to_capture
,
max_prompt_len
=
None
,
subquery_start_loc
=
None
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens
[:
batch_size
]
,
context_lens
_tensor
=
None
,
block_tables
=
block_tables
[:
batch_size
],
block_tables
=
block_tables
[:
batch_size
],
use_cuda_graph
=
True
,
use_cuda_graph
=
True
,
)
)
...
@@ -1025,7 +1018,7 @@ class CUDAGraphRunner:
...
@@ -1025,7 +1018,7 @@ class CUDAGraphRunner:
"positions"
:
positions
,
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
"kv_caches"
:
kv_caches
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"
context_lens
"
:
attn_metadata
.
decode_metadata
.
context_lens
,
"
seq_lens_tensor
"
:
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
...
@@ -1047,8 +1040,8 @@ class CUDAGraphRunner:
...
@@ -1047,8 +1040,8 @@ class CUDAGraphRunner:
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
non_blocking
=
True
)
non_blocking
=
True
)
self
.
input_buffers
[
"
context_lens
"
].
copy_
(
self
.
input_buffers
[
"
seq_lens_tensor
"
].
copy_
(
attn_metadata
.
decode_metadata
.
context_lens
,
non_blocking
=
True
)
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
non_blocking
=
True
)
self
.
input_buffers
[
"block_tables"
].
copy_
(
self
.
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
# Run the graph.
# Run the graph.
...
...
vllm/worker/neuron_model_runner.py
View file @
3521ba4f
...
@@ -52,7 +52,7 @@ class NeuronModelRunner:
...
@@ -52,7 +52,7 @@ class NeuronModelRunner:
input_positions
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
input_block_ids
:
List
[
int
]
=
[]
input_block_ids
:
List
[
int
]
=
[]
prompt
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
...
@@ -61,26 +61,26 @@ class NeuronModelRunner:
...
@@ -61,26 +61,26 @@ class NeuronModelRunner:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
prompt_tokens
=
seq_data
.
get_token_ids
()
prompt
_len
=
len
(
prompt_tokens
)
seq
_len
=
len
(
prompt_tokens
)
prompt
_lens
.
append
(
prompt
_len
)
seq
_lens
.
append
(
seq
_len
)
input_tokens
.
append
(
prompt_tokens
)
input_tokens
.
append
(
prompt_tokens
)
input_positions
.
append
(
list
(
range
(
prompt
_len
)))
input_positions
.
append
(
list
(
range
(
seq
_len
)))
assert
seq_group_metadata
.
block_tables
is
not
None
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
assert
len
(
block_table
)
==
1
assert
len
(
block_table
)
==
1
input_block_ids
.
append
(
block_table
[
0
])
input_block_ids
.
append
(
block_table
[
0
])
max_
prompt
_len
=
max
(
prompt
_lens
)
max_
seq
_len
=
max
(
seq
_lens
)
assert
max_
prompt
_len
>
0
assert
max_
seq
_len
>
0
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_
prompt
_len
,
max_
seq
_len
,
pad
=
0
,
pad
=
0
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
input_positions
=
make_tensor_with_pad
(
input_positions
,
max_
prompt
_len
,
max_
seq
_len
,
pad
=
0
,
pad
=
0
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
...
@@ -88,7 +88,7 @@ class NeuronModelRunner:
...
@@ -88,7 +88,7 @@ class NeuronModelRunner:
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
return
input_tokens
,
input_positions
,
input_block_ids
,
prompt
_lens
return
input_tokens
,
input_positions
,
input_block_ids
,
seq
_lens
def
_prepare_decode
(
def
_prepare_decode
(
self
,
self
,
...
@@ -149,18 +149,18 @@ class NeuronModelRunner:
...
@@ -149,18 +149,18 @@ class NeuronModelRunner:
# Prepare input tensors.
# Prepare input tensors.
if
is_prompt
:
if
is_prompt
:
(
input_tokens
,
input_positions
,
input_block_ids
,
(
input_tokens
,
input_positions
,
input_block_ids
,
prompt
_lens
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
seq
_lens
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
else
:
(
input_tokens
,
input_positions
,
(
input_tokens
,
input_positions
,
input_block_ids
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
input_block_ids
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt
_lens
=
[]
seq
_lens
=
[]
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_group_metadata_list
,
prompt
_lens
,
seq
_lens
,
#
sub
query_lens is not needed if chunked prefill is not
# query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# supported. Since neuron worker doesn't support chunked prefill
# just use
prompt
_lens instead.
# just use
seq
_lens instead.
prompt
_lens
,
seq
_lens
,
self
.
device
,
self
.
device
,
self
.
pin_memory
)
self
.
pin_memory
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment