Commit 55fe8a81 (norm/vllm)
Authored Aug 02, 2023 by Woosuk Kwon; committed via GitHub on Aug 02, 2023
Refactor scheduler (#658)
Parent: e8ddc08e

Showing 4 changed files with 205 additions and 144 deletions (+205 -144):

  examples/llm_engine_example.py            +1   -1
  vllm/core/scheduler.py                    +96  -129
  vllm/engine/llm_engine.py                 +91  -9
  vllm/model_executor/layers/attention.py   +17  -5
examples/llm_engine_example.py
@@ -28,7 +28,7 @@ def main(args: argparse.Namespace):
     # Run the engine by calling `engine.step()` manually.
     request_id = 0
     while True:
-        # To test iteration-level scheduling, we add one request at each step.
+        # To test continuous batching, we add one request at each step.
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
             engine.add_request(str(request_id), prompt, sampling_params)
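For context, the full driver loop in this example script looks roughly like the sketch below. It is a sketch only: `engine`, `test_prompts`, and the `has_unfinished_requests()` / `finished` accessors are assumed to behave as they do in this version's example script, not verified here.

    # Minimal sketch of the manual engine-driving loop (assumptions noted above).
    request_id = 0
    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1

        # One step() call runs one scheduler iteration (continuous batching).
        request_outputs = engine.step()
        for request_output in request_outputs:
            if request_output.finished:  # assumed boolean attribute
                print(request_output)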
vllm/core/scheduler.py
@@ -12,8 +12,6 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup,

 logger = init_logger(__name__)

-_LOGGING_INTERVAL_SEC = 5
-

 class PreemptionMode(enum.Enum):
     """Preemption modes.
@@ -32,19 +30,28 @@ class SchedulerOutputs:

     def __init__(
         self,
+        scheduled_seq_groups: List[SequenceGroup],
+        prompt_run: bool,
+        num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        ignored_seq_groups: List[SequenceGroup],
     ) -> None:
+        self.scheduled_seq_groups = scheduled_seq_groups
+        self.prompt_run = prompt_run
+        self.num_batched_tokens = num_batched_tokens
         self.blocks_to_swap_in = blocks_to_swap_in
         self.blocks_to_swap_out = blocks_to_swap_out
         self.blocks_to_copy = blocks_to_copy
         # Swap in and swap out should never happen at the same time.
         assert not (blocks_to_swap_in and blocks_to_swap_out)
+        self.ignored_seq_groups = ignored_seq_groups

     def is_empty(self) -> bool:
-        return (not self.blocks_to_swap_in and
-                not self.blocks_to_swap_out and
-                not self.blocks_to_copy)
+        # NOTE: We do not consider the ignored sequence groups.
+        return (not self.scheduled_seq_groups and
+                not self.blocks_to_swap_in and
+                not self.blocks_to_swap_out and
+                not self.blocks_to_copy)


 class Scheduler:
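Note that the new `is_empty()` deliberately ignores `ignored_seq_groups`, so an outputs object that scheduled nothing but rejected an over-long prompt still reports empty; the caller has to inspect `ignored_seq_groups` itself, as the `step()` change later in this diff does. A minimal sketch, with a hypothetical `too_long_group` standing in for a real `SequenceGroup`:

    # Sketch: nothing scheduled, one prompt ignored.
    outputs = SchedulerOutputs(
        scheduled_seq_groups=[],
        prompt_run=True,
        num_batched_tokens=0,
        blocks_to_swap_in={},
        blocks_to_swap_out={},
        blocks_to_copy={},
        ignored_seq_groups=[too_long_group],  # hypothetical SequenceGroup
    )
    assert outputs.is_empty()          # ignored groups are not counted...
    assert outputs.ignored_seq_groups  # ...so callers must check this field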
@@ -53,11 +60,9 @@ class Scheduler:
         self,
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
-        log_stats: bool,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
-        self.log_stats = log_stats

         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name="fcfs")
@@ -75,10 +80,6 @@ class Scheduler:
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []

-        self.last_logging_time: float = 0.0
-        # List[timestamp, num_tokens]
-        self.num_input_tokens: List[Tuple[float, int]] = []
-
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
@@ -101,21 +102,80 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)

-    def _schedule(
-        self
-    ) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
+    def _schedule(self) -> SchedulerOutputs:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
-        ignored_seq_groups: List[SequenceGroup] = []

         # Fix the current time.
         now = time.time()

-        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
-        # in order to minimize the preemption overheads.
-        # Preemption happens only when there is no available slot to keep all
-        # the sequence groups in the RUNNING state.
+        # Join waiting sequences if possible.
+        if not self.swapped:
+            ignored_seq_groups: List[SequenceGroup] = []
+            scheduled: List[SequenceGroup] = []
+            num_batched_tokens = 0
+            # Optimization: We do not sort the waiting queue since the preempted
+            # sequence groups are added to the front and the new sequence groups
+            # are added to the back.
+            while self.waiting:
+                seq_group = self.waiting[0]
+
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                prompt_limit = min(
+                    self.scheduler_config.max_model_len,
+                    self.scheduler_config.max_num_batched_tokens)
+                if num_prompt_tokens > prompt_limit:
+                    logger.warning(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        f" and exceeds limit of {prompt_limit}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    break
+
+                # If the sequence group cannot be allocated, stop.
+                if not self.block_manager.can_allocate(seq_group):
+                    break
+
+                # If the number of batched tokens exceeds the limit, stop.
+                if (num_batched_tokens + num_prompt_tokens >
+                        self.scheduler_config.max_num_batched_tokens):
+                    break
+
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
+                num_curr_seqs = sum(
+                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
+                    for seq_group in self.running)
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
+
+                seq_group = self.waiting.pop(0)
+                self._allocate(seq_group)
+                self.running.append(seq_group)
+                num_batched_tokens += num_prompt_tokens
+                scheduled.append(seq_group)
+
+            if scheduled:
+                scheduler_outputs = SchedulerOutputs(
+                    scheduled_seq_groups=scheduled,
+                    prompt_run=True,
+                    num_batched_tokens=num_batched_tokens,
+                    blocks_to_swap_in=blocks_to_swap_in,
+                    blocks_to_swap_out=blocks_to_swap_out,
+                    blocks_to_copy=blocks_to_copy,
+                    ignored_seq_groups=ignored_seq_groups,
+                )
+                return scheduler_outputs
+
+        # NOTE(woosuk): Preemption happens only when there is no available slot
+        # to keep all the sequence groups in the RUNNING state.
+        # In this case, the policy is responsible for deciding which sequence
+        # groups to preempt.
         self.running = self.policy.sort_by_priority(now, self.running)
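A toy, self-contained illustration of the token-budget admission check in the new prompt-scheduling loop above; the numbers are invented and the real limits come from `SchedulerConfig`:

    # Toy run of the token-budget check (illustrative numbers only).
    max_num_batched_tokens = 2048
    waiting_prompt_lens = [700, 900, 600]   # prompt lengths of waiting groups

    num_batched_tokens = 0
    scheduled = []
    for num_prompt_tokens in list(waiting_prompt_lens):
        if num_batched_tokens + num_prompt_tokens > max_num_batched_tokens:
            break                           # 1600 + 600 > 2048: third prompt waits
        waiting_prompt_lens.pop(0)
        scheduled.append(num_prompt_tokens)
        num_batched_tokens += num_prompt_tokens

    assert scheduled == [700, 900] and num_batched_tokens == 1600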
@@ -173,124 +233,26 @@ class Scheduler:
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
             for seq_group in self.running)

-        # Join waiting sequences if possible.
-        prompt_group_ids: List[str] = []
-        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
-        # prioritized over the sequence groups in the WAITING state.
-        # This is because we want to bound the amount of CPU memory taken by
-        # the swapped sequence groups.
-        if not self.swapped:
-            # Optimization: We do not sort the waiting queue since the preempted
-            # sequence groups are added to the front and the new sequence groups
-            # are added to the back.
-            while self.waiting:
-                seq_group = self.waiting[0]
-                # If the sequence group has been preempted in this step, stop.
-                if seq_group in preempted:
-                    break
-
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-                prompt_limit = min(
-                    self.scheduler_config.max_model_len,
-                    self.scheduler_config.max_num_batched_tokens)
-                if num_prompt_tokens > prompt_limit:
-                    logger.warning(
-                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
-                        f" and exceeds limit of {prompt_limit}")
-                    for seq in seq_group.get_seqs():
-                        seq.status = SequenceStatus.FINISHED_IGNORED
-                    ignored_seq_groups.append(seq_group)
-                    self.waiting.pop(0)
-                    break
-
-                # If the sequence group cannot be allocated, stop.
-                if not self.block_manager.can_allocate(seq_group):
-                    break
-
-                # If the number of batched tokens exceeds the limit, stop.
-                if (num_batched_tokens + num_prompt_tokens >
-                        self.scheduler_config.max_num_batched_tokens):
-                    break
-
-                # The total number of sequences in the RUNNING state should not
-                # exceed the maximum number of sequences.
-                num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
-                num_curr_seqs = sum(
-                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                    for seq_group in self.running)
-                if (num_curr_seqs + num_new_seqs >
-                        self.scheduler_config.max_num_seqs):
-                    break
-
-                seq_group = self.waiting.pop(0)
-                self._allocate(seq_group)
-                self.running.append(seq_group)
-                num_batched_tokens += num_prompt_tokens
-                prompt_group_ids.append(seq_group.request_id)
-
         scheduler_outputs = SchedulerOutputs(
+            scheduled_seq_groups=self.running,
+            prompt_run=False,
+            num_batched_tokens=num_batched_tokens,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
+            ignored_seq_groups=[],
         )
-        if not self.log_stats:
-            return scheduler_outputs, prompt_group_ids, ignored_seq_groups
+        return scheduler_outputs

-        # TODO(woosuk): Move the below code to the engine.
-        now = time.time()
-        if num_batched_tokens > 0:
-            self.num_input_tokens.append((now, num_batched_tokens))
-        elapsed_time = now - self.last_logging_time
-        if elapsed_time > _LOGGING_INTERVAL_SEC:
-            self.last_logging_time = now
-            self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
-                                     if now - t < _LOGGING_INTERVAL_SEC]
-            if len(self.num_input_tokens) > 1:
-                total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
-                window = now - self.num_input_tokens[0][0]
-                avg_throughput = total_num_tokens / window
-            else:
-                avg_throughput = 0.0
-            total_num_gpu_blocks = self.cache_config.num_gpu_blocks
-            num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
-            num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
-            gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
-            total_num_cpu_blocks = self.cache_config.num_cpu_blocks
-            if total_num_cpu_blocks > 0:
-                num_free_cpu_blocks = (
-                    self.block_manager.get_num_free_cpu_blocks())
-                num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
-                cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
-            else:
-                cpu_cache_usage = 0.0
-            logger.info(
-                f"Throughput: {avg_throughput:.1f} tokens/s, "
-                f"Running: {len(self.running)} reqs, "
-                f"Swapped: {len(self.swapped)} reqs, "
-                f"Pending: {len(self.waiting)} reqs, "
-                f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
-                f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-
-        return scheduler_outputs, prompt_group_ids, ignored_seq_groups
-
-    def schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[SequenceGroup]]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata],
+                                SchedulerOutputs]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        (scheduler_outputs, prompt_group_ids,
-         ignored_seq_groups) = self._schedule()
+        scheduler_outputs = self._schedule()

         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for seq_group in self.running:
-            is_prompt = seq_group.request_id in prompt_group_ids
+        for seq_group in scheduler_outputs.scheduled_seq_groups:
             seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -300,20 +262,27 @@ class Scheduler:
             seq_group_metadata = SequenceGroupMetadata(
                 request_id=seq_group.request_id,
-                is_prompt=is_prompt,
+                is_prompt=scheduler_outputs.prompt_run,
                 seq_data=seq_data,
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
-        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
+        return seq_group_metadata_list, scheduler_outputs

     def update(
         self,
         seq_outputs: Dict[int, SequenceOutputs],
     ) -> List[SequenceGroup]:
-        # Update the running sequences and free blocks.
+        scheduled: List[SequenceGroup] = []
         for seq_group in self.running:
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                if seq.seq_id in seq_outputs:
+                    scheduled.append(seq_group)
+                    break
+
+        # Update the scheduled sequences and free blocks.
+        for seq_group in scheduled:
             # Process beam search results before processing the new tokens.
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 output = seq_outputs[seq.seq_id]
@@ -331,9 +300,7 @@ class Scheduler:
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token_id(output.output_token, output.logprobs)
-        # Return a shallow copy of the running queue to prevent the queue
-        # from being modified by the caller.
-        return self.running.copy()
+        return scheduled

     def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
         seq.status = finish_status
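Taken together, the scheduler's external API after this refactor is consumed roughly as sketched below. This is not the engine's literal code: `scheduler` is assumed to be a constructed `Scheduler`, and `run_model` is a hypothetical stand-in for the engine's worker execution.

    # Sketch of one scheduling iteration against the refactored API.
    seq_group_metadata_list, scheduler_outputs = scheduler.schedule()
    if not scheduler_outputs.is_empty():
        # Execute the model on the scheduled groups (hypothetical helper).
        seq_outputs = run_model(seq_group_metadata_list,
                                scheduler_outputs.blocks_to_swap_in,
                                scheduler_outputs.blocks_to_swap_out,
                                scheduler_outputs.blocks_to_copy)
        # update() now returns only the groups that actually ran this step,
        # instead of a copy of the whole running queue.
        scheduled_groups = scheduler.update(seq_outputs)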
vllm/engine/llm_engine.py
+import time
 import copy
 from functools import partial
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, List, Optional, Tuple, TYPE_CHECKING

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -25,6 +25,8 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)

+_LOGGING_INTERVAL_SEC = 5
+

 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -102,7 +104,14 @@ class LLMEngine:
         self._init_cache()

         # Create the scheduler.
-        self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
+        self.scheduler = Scheduler(scheduler_config, cache_config)
+
+        # Logging.
+        self.last_logging_time = 0.0
+        # List of (timestamp, num_tokens)
+        self.num_prompt_tokens: List[Tuple[float, int]] = []
+        # List of (timestamp, num_tokens)
+        self.num_generation_tokens: List[Tuple[float, int]] = []

     def _init_workers(self, distributed_init_method: str):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -288,12 +297,17 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        (seq_group_metadata_list, scheduler_outputs,
-         ignored_seq_groups) = self.scheduler.schedule()
-        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
-                and (not ignored_seq_groups)):
-            # Nothing to do.
-            return []
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+        if scheduler_outputs.is_empty():
+            if not scheduler_outputs.ignored_seq_groups:
+                # Nothing to do.
+                return []
+            # If there are ignored seq groups, we need to return them as the
+            # request outputs.
+            return [
+                RequestOutput.from_seq_group(seq_group)
+                for seq_group in scheduler_outputs.ignored_seq_groups
+            ]

         # Execute the model.
         output = self._run_workers(
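One user-visible effect of this change, sketched under the assumption that `very_long_prompt` exceeds `min(max_model_len, max_num_batched_tokens)` and that `RequestOutput` exposes `request_id` and a `finished` flag as in this version's examples:

    # Sketch: an over-long prompt is never executed; it is returned from the
    # next step() already finished (its sequences are FINISHED_IGNORED).
    engine.add_request("too-long", very_long_prompt, sampling_params)
    outputs = engine.step()
    assert any(out.request_id == "too-long" and out.finished for out in outputs)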
@@ -315,11 +329,79 @@ class LLMEngine:
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups + ignored_seq_groups:
+        for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)

+        if self.log_stats:
+            # Log the system stats.
+            self._log_system_stats(scheduler_outputs.prompt_run,
+                                   scheduler_outputs.num_batched_tokens)
         return request_outputs

+    def _log_system_stats(
+        self,
+        prompt_run: bool,
+        num_batched_tokens: int,
+    ) -> None:
+        now = time.time()
+        # Log the number of batched input tokens.
+        if prompt_run:
+            self.num_prompt_tokens.append((now, num_batched_tokens))
+        else:
+            self.num_generation_tokens.append((now, num_batched_tokens))
+
+        elapsed_time = now - self.last_logging_time
+        if elapsed_time < _LOGGING_INTERVAL_SEC:
+            return
+
+        # Discard the old stats.
+        self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
+                                  if now - t < _LOGGING_INTERVAL_SEC]
+        self.num_generation_tokens = [(t, n)
+                                      for t, n in self.num_generation_tokens
+                                      if now - t < _LOGGING_INTERVAL_SEC]
+
+        if len(self.num_prompt_tokens) > 1:
+            total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
+            window = now - self.num_prompt_tokens[0][0]
+            avg_prompt_throughput = total_num_tokens / window
+        else:
+            avg_prompt_throughput = 0.0
+        if len(self.num_generation_tokens) > 1:
+            total_num_tokens = sum(
+                n for _, n in self.num_generation_tokens[:-1])
+            window = now - self.num_generation_tokens[0][0]
+            avg_generation_throughput = total_num_tokens / window
+        else:
+            avg_generation_throughput = 0.0
+
+        total_num_gpu_blocks = self.cache_config.num_gpu_blocks
+        num_free_gpu_blocks = (
+            self.scheduler.block_manager.get_num_free_gpu_blocks())
+        num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
+        gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
+
+        total_num_cpu_blocks = self.cache_config.num_cpu_blocks
+        if total_num_cpu_blocks > 0:
+            num_free_cpu_blocks = (
+                self.scheduler.block_manager.get_num_free_cpu_blocks())
+            num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
+            cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
+        else:
+            cpu_cache_usage = 0.0
+
+        logger.info("Avg prompt throughput: "
+                    f"{avg_prompt_throughput:.1f} tokens/s, "
+                    "Avg generation throughput: "
+                    f"{avg_generation_throughput:.1f} tokens/s, "
+                    f"Running: {len(self.scheduler.running)} reqs, "
+                    f"Swapped: {len(self.scheduler.swapped)} reqs, "
+                    f"Pending: {len(self.scheduler.waiting)} reqs, "
+                    f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
+                    f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
+        self.last_logging_time = now
+
     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         """Decodes the sequence outputs."""
         for seq_group in seq_groups:
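A small numeric check of the windowed-average formula used by `_log_system_stats` above; the sample values are invented:

    # (timestamp, num_tokens) samples kept within the logging window.
    _LOGGING_INTERVAL_SEC = 5
    now = 10.0
    num_prompt_tokens = [(6.0, 512), (8.0, 256), (9.5, 128)]

    num_prompt_tokens = [(t, n) for t, n in num_prompt_tokens
                         if now - t < _LOGGING_INTERVAL_SEC]
    total_num_tokens = sum(n for _, n in num_prompt_tokens[:-1])   # 512 + 256
    window = now - num_prompt_tokens[0][0]                         # 10.0 - 6.0 = 4.0
    avg_prompt_throughput = total_num_tokens / window              # 192.0 tokens/s
    assert avg_prompt_throughput == 192.0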
vllm/model_executor/layers/attention.py
@@ -20,12 +20,20 @@ class PagedAttention(nn.Module):
     """GPT-style multi-head PagedAttention.

     This class takes flattened 1D query, key, and value tensors as input. The
-    input 1D tensors can be split into three parts: the prompt tokens, the
-    generation tokens, and the paddings.
+    input 1D tensors can either contain prompt tokens or generation tokens, in
+    addition to paddings.

-    |<------------------------------------- num_valid_tokens ------------------------------------->|
-    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
+    If the input tensors contain prompt tokens, the layout is as follows:
+
+    |<---------------------- num_valid_tokens ---------------------->|
+    |<--------------- num_prompt_tokens -------------->|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
+
+    Otherwise, the layout is as follows:
+
+    |<------------------ num_valid_tokens ------------------->|
+    |<------- num_generation_tokens (M) ------->|
+    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

     The prompts might have different lengths, while the generation tokens always
     have length 1. The paddings are appended to make the input length a multiple
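A tiny illustration of the prompt-run layout described in the new docstring; the prompt lengths and pad size are invented:

    # Two prompts (lengths 3 and 2) flattened into one padded 1D input.
    prompt_lens = [3, 2]
    num_valid_tokens = sum(prompt_lens)          # 5
    padded_len = 8                               # padded to an assumed multiple
    tokens = ["p0"] * 3 + ["p1"] * 2 + ["<pad>"] * (padded_len - num_valid_tokens)
    # |<-- prompt_0 -->|<-- prompt_1 -->|<-- padding -->|
    assert len(tokens) == padded_len
    # In a generation run, each sequence instead contributes exactly one token.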
@@ -188,6 +196,8 @@ class PagedAttention(nn.Module):
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
         if num_prompt_tokens > 0:
+            # Prompt run.
+            assert input_metadata.num_generation_tokens == 0
             self.set_attn_bias(input_metadata)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
@@ -217,6 +227,8 @@ class PagedAttention(nn.Module):
             )

         if input_metadata.num_generation_tokens > 0:
+            # Decoding run.
+            assert input_metadata.num_prompt_tokens == 0
             assert key_cache is not None and value_cache is not None, (
                 "key_cache and value_cache must be provided when "
                 "generating tokens.")