norm / vllm · commit 1a7eb7da

Support beam search & parallel generation (#7)

Unverified commit 1a7eb7da, authored Mar 10, 2023 by Woosuk Kwon, committed by GitHub on Mar 10, 2023. Parent: 04e5acc0.

Showing 16 changed files with 662 additions and 163 deletions (+662 / -163).
cacheflow/block.py                  +4    -0
cacheflow/master/frontend.py        +28   -5
cacheflow/master/scheduler.py       +55   -36
cacheflow/models/__init__.py        +3    -1
cacheflow/models/input_metadata.py  +11   -7
cacheflow/models/model_utils.py     +10   -0
cacheflow/models/opt.py             +2    -1
cacheflow/models/sample.py          +258  -17
cacheflow/sampling_params.py        +54   -19
cacheflow/sequence.py               +72   -6
cacheflow/worker/cache_engine.py    +27   -8
cacheflow/worker/controller.py      +7    -10
cacheflow/worker/worker.py          +76   -43
csrc/cache.cpp                      +14   -2
csrc/cache_kernels.cu               +31   -1
server.py                           +10   -7
cacheflow/block.py
@@ -35,6 +35,10 @@ class LogicalTokenBlock:
     def get_token_ids(self) -> List[int]:
         return self.token_ids[:self.num_tokens]

+    def get_last_token_id(self) -> int:
+        assert self.num_tokens > 0
+        return self.token_ids[self.num_tokens - 1]
+

 class PhysicalTokenBlock:
cacheflow/master/frontend.py
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 from transformers import AutoTokenizer
@@ -25,12 +25,35 @@ class Frontend:
     def query(
         self,
         prompt: str,
-        sampling_params: Optional[SamplingParams] = None,
+        n: int = 1,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        use_beam_search: bool = False,
+        stop_token_ids: Set[int] = set(),
+        max_num_steps: int = 16,
+        # From OpenAI API.
+        num_logprobs: int = 0,
+        context_window_size: Optional[int] = None,
     ) -> None:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        token_ids: List[int] = self.tokenizer.encode(prompt)
+        # Stop when we see an EOS token.
+        stop_token_ids.add(self.tokenizer.eos_token_id)
+        sampling_params = SamplingParams(
+            n=n,
+            temperature=temperature,
+            top_p=top_p,
+            use_beam_search=use_beam_search,
+            stop_token_ids=stop_token_ids,
+            max_num_steps=max_num_steps,
+            num_logprobs=num_logprobs,
+            context_window_size=context_window_size,
+        )
+        token_ids = self.tokenizer.encode(prompt)
         self._add_query(token_ids, sampling_params)

     def _add_query(
         self,
         token_ids: List[int],
         sampling_params: SamplingParams,
     ) -> None:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
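The widened query() surface maps one request onto a group of n parallel sequences. A minimal usage sketch, assuming an already-constructed frontend object; the prompt and keyword values are taken from the server.py test inputs in this same commit:

    # Request four beam-search candidates for one prompt; beam search requires
    # temperature == 0 under the new SamplingParams validation rules.
    frontend.query('Ion Stoica is a', n=4, use_beam_search=True, temperature=0.0)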
cacheflow/master/scheduler.py
-from typing import Dict, List, Tuple
+from typing import Dict, List

 from cacheflow.master.block_manager import BlockSpaceManager
 from cacheflow.master.frontend import Frontend
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
+from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceOutputs
 from cacheflow.sequence import SequenceStatus

 _MAX_NUM_BATCHED_TOKENS = 2048
@@ -66,7 +68,7 @@ class Scheduler:
     def _append(
         self,
         seq_group: SequenceGroup,
-        blocks_to_copy: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         for seq in seq_group.seqs:
             if seq.status == SequenceStatus.FINISHED:
@@ -74,7 +76,10 @@ class Scheduler:
             ret = self.block_manager.append(seq)
             if ret is not None:
                 src_block, dst_block = ret
-                blocks_to_copy[src_block] = dst_block
+                if src_block in blocks_to_copy:
+                    blocks_to_copy[src_block].append(dst_block)
+                else:
+                    blocks_to_copy[src_block] = [dst_block]

     def _swap_in(
         self,
@@ -83,9 +88,8 @@ class Scheduler:
     ) -> None:
         mapping = self.block_manager.swap_in(seq_group)
         blocks_to_swap_in.update(mapping)
-        for seq in seq_group.seqs:
-            if seq.status == SequenceStatus.SWAPPED:
-                seq.status = SequenceStatus.RUNNING
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
+            seq.status = SequenceStatus.RUNNING
         self.running.append(seq_group)

     def _swap_out(
@@ -96,16 +100,15 @@ class Scheduler:
         assert self.block_manager.can_swap_out(seq_group)
         mapping = self.block_manager.swap_out(seq_group)
         blocks_to_swap_out.update(mapping)
-        for seq in seq_group.seqs:
-            if seq.status == SequenceStatus.RUNNING:
-                seq.status = SequenceStatus.SWAPPED
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+            seq.status = SequenceStatus.SWAPPED
         self.swapped.append(seq_group)

     def step(self) -> None:
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
-        blocks_to_copy: Dict[int, int] = {}
+        blocks_to_copy: Dict[int, List[int]] = {}

         # 1. Reserve new slots for the running sequences.
         # NOTE: Here we implicitly assume FCFS scheduling.
@@ -143,6 +146,10 @@ class Scheduler:
         # All swapped sequences are swapped in.
         self.swapped.clear()

+        # Ensure that swap-in and swap-out never happen at the same timestep.
+        if blocks_to_swap_in:
+            assert not blocks_to_swap_out
+
         num_batched_tokens = sum(
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
             for seq_group in self.running
@@ -152,7 +159,6 @@ class Scheduler:
         # NOTE: Here we implicitly assume FCFS scheduling.
         # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
             # FIXME(woosuk): Acquire a lock to protect pending.
             self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -168,39 +174,45 @@ class Scheduler:
             else:
                 self.pending.clear()

-        # Ensure that swap-in and swap-out never happen at the same timestep.
-        if blocks_to_swap_in:
-            assert not blocks_to_swap_out
-
         # 4. Create input data structures.
-        prompt_tokens: Dict[int, List[int]] = {}
-        generation_tokens: Dict[int, int] = {}
-        context_lens: Dict[int, int] = {}
-        block_tables: Dict[int, List[int]] = {}
+        input_seq_groups: List[SequenceGroupInputs] = []
         for seq_group in self.running:
             group_id = seq_group.group_id
             num_steps = self.num_steps[group_id]
             # NOTE(woosuk): We assume that the number of steps is 0
             # for the prompt sequences.
             is_prompt = num_steps == 0
-            for seq in seq_group.seqs:
-                if seq.status != SequenceStatus.RUNNING:
-                    continue
+
+            input_tokens: Dict[int, List[int]] = {}
+            seq_logprobs: Dict[int, float] = {}
+            block_tables: Dict[int, List[int]] = {}
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
                 if is_prompt:
-                    prompt_tokens[seq_id] = seq.get_token_ids()
+                    input_tokens[seq_id] = seq.get_token_ids()
                 else:
-                    generation_tokens[seq_id] = seq.get_token_ids()[-1]
-                    context_lens[seq_id] = seq.get_len()
+                    input_tokens[seq_id] = [seq.get_last_token_id()]
+                seq_logprobs[seq_id] = seq.cumulative_logprobs
+                # NOTE(woosuk): Sequences in the same group have the same
+                # sequence length.
+                seq_len = seq.get_len()
+
+            input_seq_group = SequenceGroupInputs(
+                group_id=group_id,
+                is_prompt=is_prompt,
+                input_tokens=input_tokens,
+                context_len=seq_len,
+                seq_logprobs=seq_logprobs,
+                sampling_params=self.sampling_params[group_id],
+                block_tables=block_tables,
+            )
+            input_seq_groups.append(input_seq_group)

         # 5. Execute the first stage of the pipeline.
         self.controllers[0].execute_stage(
-            prompt_tokens,
-            generation_tokens,
-            context_lens,
-            block_tables,
+            input_seq_groups,
             blocks_to_swap_in,
             blocks_to_swap_out,
             blocks_to_copy,
@@ -208,7 +220,7 @@ class Scheduler:
     def post_step(
         self,
-        next_tokens: Dict[int, Tuple[int, int]],
+        seq_outputs: Dict[int, SequenceOutputs],
     ) -> None:
         # Update the running sequences and free blocks.
         for seq_group in self.running:
@@ -216,25 +228,32 @@ class Scheduler:
             self.num_steps[group_id] += 1
             stop_token_ids = self.sampling_params[group_id].stop_token_ids

+            # Process beam search results before processing the next tokens.
             for seq in seq_group.seqs:
                 if seq.status == SequenceStatus.FINISHED:
                     continue

-                parent_seq_id, next_token = next_tokens[seq.seq_id]
-                if seq.seq_id != parent_seq_id:
+                output = seq_outputs[seq.seq_id]
+                if seq.seq_id != output.parent_seq_id:
                     # The sequence is a fork of the parent sequence (beam search).
                     # Free the current sequence.
                     self.block_manager.free(seq)
                     # Fork the parent sequence.
-                    parent_seq = seq_group.find(parent_seq_id)
-                    seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
+                    parent_seq = seq_group.find(output.parent_seq_id)
+                    parent_seq.fork(seq)
                     self.block_manager.fork(parent_seq, seq)

+            # Process the next tokens.
+            for seq in seq_group.seqs:
+                if seq.status == SequenceStatus.FINISHED:
+                    continue
+
                 # Append a new token to the sequence.
-                seq.append([next_token])
+                output = seq_outputs[seq.seq_id]
+                seq.append(output.output_token, output.logprobs)

                 # Check if the sequence has generated a stop token.
-                if next_token in stop_token_ids:
+                if output.output_token in stop_token_ids:
                     self._free_seq(seq)
                     continue
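Because several forked beams can now grow out of the same physical block in one step, the copy mapping becomes one-to-many. A small self-contained sketch of how the dict above fills up; the block numbers are hypothetical:

    from typing import Dict, List

    blocks_to_copy: Dict[int, List[int]] = {}
    for src_block, dst_block in [(3, 7), (3, 9), (5, 8)]:  # hypothetical (src, dst) pairs
        if src_block in blocks_to_copy:
            blocks_to_copy[src_block].append(dst_block)
        else:
            blocks_to_copy[src_block] = [dst_block]
    # blocks_to_copy == {3: [7, 9], 5: [8]}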
cacheflow/models/__init__.py
 from cacheflow.models.input_metadata import InputMetadata
 from cacheflow.models.model_utils import get_model
+from cacheflow.models.model_utils import set_seed

-__all__ = ['get_model', 'InputMetadata']
+__all__ = ['InputMetadata', 'get_model', 'set_seed']
cacheflow/models/input_metadata.py
-from typing import List
+from typing import List, Dict, Tuple

 import torch

 from cacheflow.sampling_params import SamplingParams


 class InputMetadata:

     def __init__(
         self,
         seq_ids: List[int],
         seq_groups: List[Tuple[List[int], SamplingParams]],
+        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         # FIXME: Rename
         max_context_len: int,
         block_tables: torch.Tensor,
     ) -> None:
         self.seq_ids = seq_ids
         self.seq_groups = seq_groups
+        self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
@@ -23,19 +26,20 @@ class InputMetadata:
         self.block_tables = block_tables

         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
             self.max_num_blocks_per_seq = 0
-        assert self.num_generation_tokens == block_tables.shape[0]
         assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
+        assert block_tables.shape[0] == self.num_generation_tokens
+        assert context_lens.shape[0] == self.num_generation_tokens

     def __repr__(self) -> str:
         return (f'InputMetadata('
                 f'seq_ids={self.seq_ids}, '
                 f'num_prompts={self.num_prompts}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
                 f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
cacheflow/models/model_utils.py
+import random
 from typing import Union

+import numpy as np
 import torch
 import torch.nn as nn
@@ -30,3 +32,11 @@ def get_model(
         model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
         return model.eval()
     raise ValueError(f'Invalid model name: {model_name}')
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
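A short sketch of what the new helper buys: after reseeding, the stochastic sampling path (torch.multinomial in the sampler) reproduces the same draw. The seed value 0 matches the server default; the import path follows the re-export added to cacheflow/models/__init__.py above.

    import torch
    from cacheflow.models import set_seed

    probs = torch.tensor([0.1, 0.2, 0.7])
    set_seed(0)
    a = torch.multinomial(probs, num_samples=1)
    set_seed(0)
    b = torch.multinomial(probs, num_samples=1)
    assert a.item() == b.item()  # identical draws after identical seeding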
cacheflow/models/opt.py
@@ -9,6 +9,7 @@ from transformers import PreTrainedModel
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import OPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
+from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -261,7 +262,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, Tuple[int, int]]:
+    ) -> Dict[int, SequenceOutputs]:
         hidden_states = self.model(
             input_ids, positions, kv_caches, input_metadata, cache_events)
         next_tokens = self.sampler(
cacheflow/models/sample.py
@@ -4,6 +4,8 @@ import torch
 import torch.nn as nn

 from cacheflow.models import InputMetadata
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceOutputs


 class Sampler(nn.Module):
@@ -16,27 +18,266 @@ class Sampler(nn.Module):
         embedding: torch.Tensor,
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
-    ) -> Dict[int, Tuple[int, int]]:
-        # Get the hidden states of the last tokens.
-        start_idx = 0
-        last_token_indicies: List[int] = []
-        for prompt_len in input_metadata.prompt_lens:
-            last_token_indicies.append(start_idx + prompt_len - 1)
-            start_idx += prompt_len
-        last_token_indicies.extend(
-            range(start_idx, start_idx + input_metadata.num_generation_tokens))
-        hidden_states = hidden_states[last_token_indicies]
+    ) -> Dict[int, SequenceOutputs]:
+        # Get the hidden states that we use for sampling.
+        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

         # Get the logits for the next tokens.
         logits = torch.matmul(hidden_states, embedding.t())

         # Apply temperature scaling.
+        temperatures = _get_temperatures(input_metadata)
+        assert len(temperatures) == logits.shape[0]
+        if any(t != 1.0 for t in temperatures):
+            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
+            # Use in-place division to avoid creating a new tensor.
+            logits.div_(t.unsqueeze(dim=1))

         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities (before applying top-p).
+        logprobs = torch.log(probs)

+        # Apply top-p truncation.
+        top_ps = _get_top_ps(input_metadata)
+        assert len(top_ps) == probs.shape[0]
+        if any(p < 1.0 for p in top_ps):
+            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
+            probs = _apply_top_p(probs, p)

         # Sample the next tokens.
-        # TODO(woosuk): Implement other sampling methods.
-        next_token_ids = torch.argmax(logits, dim=-1)
-
-        # Return the next tokens.
-        next_tokens: Dict[int, Tuple[int, int]] = {}
-        for seq_id, token_id in zip(input_metadata.seq_ids, next_token_ids):
-            next_tokens[seq_id] = (seq_id, token_id)
-        return next_tokens
+        return _sample(probs, logprobs, input_metadata)
+
+
+def _prune_hidden_states(
+    hidden_states: torch.Tensor,
+    input_metadata: InputMetadata,
+) -> torch.Tensor:
+    start_idx = 0
+    last_token_indicies: List[int] = []
+    for prompt_len in input_metadata.prompt_lens:
+        last_token_indicies.append(start_idx + prompt_len - 1)
+        start_idx += prompt_len
+    last_token_indicies.extend(
+        range(start_idx, start_idx + input_metadata.num_generation_tokens))
+    return hidden_states[last_token_indicies]
+
+
+def _get_temperatures(
+    input_metadata: InputMetadata,
+) -> List[float]:
+    # Collect the temperatures for the logits.
+    temperatures: List[float] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        temperature = sampling_params.temperature
+        if temperature == 0.0:
+            # NOTE: Zero temperature means deterministic sampling
+            # (i.e., greedy sampling or beam search).
+            # Set the temperature to 1 to avoid division by zero.
+            temperature = 1.0
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            temperatures.append(temperature)
+        else:
+            # A generation token.
+            temperatures += [temperature] * len(seq_ids)
+    return temperatures
+
+
+def _get_top_ps(
+    input_metadata: InputMetadata,
+) -> List[float]:
+    top_ps: List[float] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            top_ps.append(sampling_params.top_p)
+        else:
+            # A generation token.
+            top_ps += [sampling_params.top_p] * len(seq_ids)
+    return top_ps
+
+
+def _apply_top_p(
+    probs: torch.Tensor,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    # TODO(woosuk): Optimize.
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
+    return probs
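To make the masking rule in _apply_top_p concrete, here is a small worked example with hypothetical probabilities (a single row that is already sorted, top_p = 0.75):

    import torch

    probs = torch.tensor([[0.40, 0.30, 0.20, 0.07, 0.03]])
    p = torch.tensor([0.75])

    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)          # [0.40, 0.70, 0.90, 0.97, 1.00]
    # A token is dropped once the mass accumulated *before* it already exceeds p.
    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)  # [False, False, False, True, True]
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # probs_sort is now roughly [[0.444, 0.333, 0.222, 0.0, 0.0]]: the top three
    # tokens are kept and renormalized, the tail is truncated.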
+def _get_topk_logprobs(
+    logprobs: torch.Tensor,
+    num_logprobs: int,
+) -> Dict[int, float]:
+    if num_logprobs == 0:
+        return {}
+
+    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
+    if num_logprobs == 1:
+        topk_logprobs = [topk_logprobs.item()]
+        topk_ids = [topk_ids.item()]
+    else:
+        topk_logprobs = topk_logprobs.tolist()
+        topk_ids = topk_ids.tolist()
+
+    token_to_logprob: Dict[int, float] = {}
+    for token_id, logprob in zip(topk_ids, topk_logprobs):
+        token_to_logprob[token_id] = logprob
+    return token_to_logprob
+
+
+def _sample_from_prompt(
+    prob: torch.Tensor,
+    sampling_params: SamplingParams,
+) -> List[int]:
+    if sampling_params.use_beam_search:
+        # Beam search.
+        beam_width = sampling_params.n
+        _, next_token_ids = torch.topk(prob, beam_width)
+        next_token_ids = next_token_ids.tolist()
+    elif sampling_params.temperature == 0.0:
+        # Greedy sampling.
+        assert sampling_params.n == 1
+        next_token_id = torch.argmax(prob)
+        next_token_ids = [next_token_id.item()]
+    else:
+        # Nucleus sampling.
+        # Sample n tokens for the prompt.
+        n = sampling_params.n
+        next_token_ids = torch.multinomial(prob, num_samples=n, replacement=True)
+        next_token_ids = next_token_ids.tolist()
+    return next_token_ids
+
+
+def _sample_from_generation_tokens(
+    seq_ids: List[int],
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    seq_logprobs: List[float],
+    sampling_params: SamplingParams,
+) -> Tuple[List[int], List[int]]:
+    # NOTE(woosuk): sampling_params.n can be greater than
+    # len(seq_ids) because some sequences in the group might have
+    # been already terminated.
+    if sampling_params.use_beam_search:
+        # Beam search.
+        # Add cumulative logprobs for the sequences in the group.
+        seq_logprobs = torch.tensor(seq_logprobs, dtype=torch.float, device=logprobs.device)
+        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
+
+        vocab_size = logprobs.size(-1)
+        beam_width = len(seq_ids)
+        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
+        seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
+        beam_seq_ids = [seq_ids[i] for i in seq_idx]
+        token_ids = (topk_ids % vocab_size).tolist()
+
+        beam_outputs: Dict[int, Tuple[int, int]] = {}
+        outstanding_beams: List[Tuple[int, int]] = []
+        # If a beam survives, continue with it.
+        for seq_id, token_id in zip(beam_seq_ids, token_ids):
+            if seq_id not in beam_outputs:
+                beam_outputs[seq_id] = (seq_id, token_id)
+            else:
+                outstanding_beams.append((seq_id, token_id))
+        # If a beam is discarded, fork another beam.
+        for seq_id in seq_ids:
+            if seq_id not in beam_outputs:
+                beam_outputs[seq_id] = outstanding_beams.pop()
+        assert not outstanding_beams
+
+        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
+        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
+    elif sampling_params.temperature == 0.0:
+        # Greedy sampling.
+        assert len(seq_ids) == 1
+        next_token_id = torch.argmax(probs, dim=-1)
+        next_token_ids = [next_token_id.item()]
+        parent_seq_ids = seq_ids
+    else:
+        # Nucleus sampling.
+        # Sample 1 token for each sequence in the group.
+        next_token_ids = torch.multinomial(probs, num_samples=1, replacement=True)
+        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
+        parent_seq_ids = seq_ids
+    return parent_seq_ids, next_token_ids
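The beam bookkeeping above keeps exactly one candidate per live sequence id, reusing a surviving beam's own id where possible and reassigning discarded ids to the leftover candidates. A self-contained illustration with hypothetical sequence and token ids:

    from typing import Dict, List, Tuple

    seq_ids = [0, 1]                 # two live beams in the group
    beam_seq_ids = [0, 0]            # both top-2 candidates extend beam 0
    token_ids = [42, 7]              # their sampled token ids

    beam_outputs: Dict[int, Tuple[int, int]] = {}
    outstanding_beams: List[Tuple[int, int]] = []
    for seq_id, token_id in zip(beam_seq_ids, token_ids):
        if seq_id not in beam_outputs:
            beam_outputs[seq_id] = (seq_id, token_id)
        else:
            outstanding_beams.append((seq_id, token_id))
    for seq_id in seq_ids:
        if seq_id not in beam_outputs:
            beam_outputs[seq_id] = outstanding_beams.pop()

    parent_seq_ids = [beam_outputs[s][0] for s in seq_ids]   # [0, 0]
    next_token_ids = [beam_outputs[s][1] for s in seq_ids]   # [42, 7]
    # Sequence 1 is later re-forked from sequence 0 in Scheduler.post_step
    # because its parent_seq_id (0) differs from its own id (1).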
+def _sample(
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    input_metadata: InputMetadata,
+) -> Dict[int, SequenceOutputs]:
+    seq_outputs: Dict[int, SequenceOutputs] = {}
+
+    # TODO(woosuk): Optimize.
+    idx = 0
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        if i < input_metadata.num_prompts:
+            # Generate the next tokens for a prompt input.
+            assert len(seq_ids) == sampling_params.n
+            prob = probs[idx]
+            logprob = logprobs[idx]
+            idx += 1
+
+            # Sample the next tokens.
+            next_token_ids = _sample_from_prompt(prob, sampling_params)
+            # Get top-k log probabilities for the next tokens.
+            next_logprobs = _get_topk_logprobs(logprob, sampling_params.num_logprobs)
+
+            # Build the output.
+            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
+                output_logprobs = next_logprobs.copy()
+                output_logprobs[next_token_id] = logprob[next_token_id].item()
+                seq_outputs[seq_id] = SequenceOutputs(
+                    seq_id, seq_id, next_token_id, output_logprobs)
+        else:
+            # Generate the next tokens for generation tokens.
+            prob = probs[idx:idx + len(seq_ids)]
+            logprob = logprobs[idx:idx + len(seq_ids)]
+            idx += len(seq_ids)
+
+            # Sample the next tokens.
+            seq_logprobs = [
+                input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
+            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
+                seq_ids, prob, logprob, seq_logprobs, sampling_params)
+
+            # Get top-k log probabilities for the next tokens.
+            next_logprobs: Dict[int, Dict[int, float]] = {}
+            for i, seq_id in enumerate(seq_ids):
+                next_logprobs[seq_id] = _get_topk_logprobs(
+                    logprob[i], sampling_params.num_logprobs)
+
+            # Build the output.
+            for seq_id, parent_seq_id, next_token_id in zip(
+                    seq_ids, parent_seq_ids, next_token_ids):
+                i = seq_ids.index(parent_seq_id)
+                output_logprobs = next_logprobs[parent_seq_id].copy()
+                output_logprobs[next_token_id] = logprob[i, next_token_id].item()
+                seq_outputs[seq_id] = SequenceOutputs(
+                    seq_id,
+                    parent_seq_id,
+                    next_token_id,
+                    output_logprobs,
+                )
+
+    return seq_outputs
cacheflow/sampling_params.py
@@ -5,27 +5,51 @@ class SamplingParams:
     def __init__(
         self,
-        n: int = 1,
-        temperature: float = 1.0,
-        top_p: float = 1.0,
-        use_beam_search: bool = False,
-        stop_token_ids: Set[int] = [],
-        max_num_steps: int = 16,
-        # From OpenAI API.
-        max_context_len: Optional[int] = None,
+        n: int,
+        temperature: float,
+        top_p: float,
+        use_beam_search: bool,
+        stop_token_ids: Set[int],
+        max_num_steps: int,
+        num_logprobs: int,
+        context_window_size: Optional[int],
     ) -> None:
-        assert n >= 1
-        assert temperature >= 0.0
-        assert 0.0 < top_p <= 1.0
+        if n < 1:
+            raise ValueError(f'n must be at least 1, got {n}.')
+        if temperature < 0.0:
+            raise ValueError(
+                f'temperature must be non-negative, got {temperature}.')
+        if not 0.0 < top_p <= 1.0:
+            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
+        if max_num_steps < 1:
+            raise ValueError(
+                f'max_num_steps must be at least 1, got {max_num_steps}.')
+        if num_logprobs < 0:
+            raise ValueError(
+                f'num_logprobs must be non-negative, got {num_logprobs}.')
+        if context_window_size is not None and context_window_size < 0:
+            raise ValueError('context_window_size must be non-negative, '
+                             f'got {context_window_size}.')

         if use_beam_search:
-            assert n > 1
-            assert temperature > 0.0
-            assert top_p == 1.0
+            if n == 1:
+                raise ValueError('n must be greater than 1 when using beam search.')
+            if temperature > 0.0:
+                raise ValueError('temperature must be 0 when using beam search.')
+            if top_p < 1.0:
+                raise ValueError('top_p must be 1 when using beam search.')
         elif temperature == 0.0:
-            # Zero temperature means greedy decoding.
-            assert n == 1
-            assert top_p == 1.0
-        assert max_num_steps >= 1
-        assert max_context_len is None or max_context_len >= 0
+            # Zero temperature means greedy sampling.
+            if n > 1:
+                raise ValueError('n must be 1 when using greedy sampling.')
+            if top_p < 1.0:
+                raise ValueError('top_p must be 1 when using greedy sampling.')

         self.n = n
         self.temperature = temperature
@@ -33,4 +57,15 @@ class SamplingParams:
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
-        self.max_context_len = max_context_len
+        self.num_logprobs = num_logprobs
+        self.context_window_size = context_window_size
+
+    def __repr__(self) -> str:
+        return (f'SamplingParams(n={self.n}, '
+                f'temperature={self.temperature}, '
+                f'top_p={self.top_p}, '
+                f'use_beam_search={self.use_beam_search}, '
+                f'stop_token_ids={self.stop_token_ids}, '
+                f'max_num_steps={self.max_num_steps}, '
+                f'num_logprobs={self.num_logprobs}, '
+                f'context_window_size={self.context_window_size})')
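Since the defaults now live in Frontend.query rather than here, SamplingParams itself is strict. A sketch of the validation rules above; the argument values are hypothetical:

    from cacheflow.sampling_params import SamplingParams

    params = SamplingParams(
        n=4, temperature=0.0, top_p=1.0, use_beam_search=True,
        stop_token_ids=set(), max_num_steps=16, num_logprobs=0,
        context_window_size=None,
    )  # Accepted: beam search needs n > 1, temperature == 0, and top_p == 1.

    SamplingParams(
        n=2, temperature=0.0, top_p=1.0, use_beam_search=False,
        stop_token_ids=set(), max_num_steps=16, num_logprobs=0,
        context_window_size=None,
    )  # Raises ValueError: greedy sampling (temperature == 0) requires n == 1.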
cacheflow/sequence.py
+import copy
 import enum
-from typing import List, Optional
+from typing import Dict, List, Optional

 from cacheflow.block import LogicalTokenBlock
+from cacheflow.sampling_params import SamplingParams


 class SequenceStatus(enum.Enum):
@@ -24,9 +26,11 @@ class Sequence:
         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the given token ids.
-        self.append(token_ids)
+        self.add(token_ids)

         self.status = SequenceStatus.PENDING
+        self.output_logprobs: List[Dict[int, float]] = []
+        self.cumulative_logprobs = 1.0

     def add_block(self) -> None:
         block = LogicalTokenBlock(
@@ -35,7 +39,7 @@ class Sequence:
         )
         self.logical_token_blocks.append(block)

-    def append(self, token_ids: List[int]) -> None:
+    def add(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
                 self.add_block()
@@ -49,6 +53,12 @@ class Sequence:
             last_block.append(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]

+    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
+        assert token_id in logprobs
+        self.add([token_id])
+        self.output_logprobs.append(logprobs)
+        self.cumulative_logprobs += logprobs[token_id]
+
     def get_len(self) -> int:
         return sum(block.num_tokens for block in self.logical_token_blocks)
@@ -58,6 +68,14 @@ class Sequence:
             token_ids.extend(block.get_token_ids())
         return token_ids

+    def get_last_token_id(self) -> int:
+        return self.logical_token_blocks[-1].get_last_token_id()
+
+    def fork(self, child_seq: 'Sequence') -> 'Sequence':
+        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
+        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
+        child_seq.cumulative_logprobs = self.cumulative_logprobs
+
     def __repr__(self) -> str:
         return (f'Sequence(seq_id={self.seq_id}, '
                 f'status={self.status.name}, '
@@ -74,11 +92,17 @@ class SequenceGroup:
         self.group_id = group_id
         self.seqs = seqs

-    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+    def get_seqs(
+        self,
+        status: Optional[SequenceStatus] = None,
+    ) -> List[Sequence]:
         if status is None:
-            return len(self.seqs)
+            return self.seqs
         else:
-            return len([seq for seq in self.seqs if seq.status == status])
+            return [seq for seq in self.seqs if seq.status == status]
+
+    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+        return len(self.get_seqs(status))

     def find(self, seq_id: int) -> Sequence:
         for seq in self.seqs:
@@ -92,3 +116,45 @@ class SequenceGroup:
     def __repr__(self) -> str:
         return (f'SequenceGroup(group_id={self.group_id}, '
                 f'num_seqs={len(self.seqs)})')
+
+
+class SequenceGroupInputs:
+
+    def __init__(
+        self,
+        group_id: int,
+        is_prompt: bool,
+        input_tokens: Dict[int, List[int]],  # Seq id -> token ids.
+        context_len: int,
+        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
+        sampling_params: SamplingParams,
+        block_tables: Dict[int, List[int]],  # Seq id -> List of physical block numbers.
+    ) -> None:
+        self.group_id = group_id
+        self.is_prompt = is_prompt
+        self.input_tokens = input_tokens
+        self.context_len = context_len
+        self.seq_logprobs = seq_logprobs
+        self.sampling_params = sampling_params
+        self.block_tables = block_tables
+
+
+class SequenceOutputs:
+
+    def __init__(
+        self,
+        seq_id: int,
+        parent_seq_id: int,
+        output_token: int,
+        logprobs: Dict[int, float],  # Token id -> logP(x_i+1 | x_0, ..., x_i).
+    ) -> None:
+        self.seq_id = seq_id
+        self.parent_seq_id = parent_seq_id
+        self.output_token = output_token
+        self.logprobs = logprobs
+
+    def __repr__(self) -> str:
+        return (f'SequenceOutputs(seq_id={self.seq_id}, '
+                f'parent_seq_id={self.parent_seq_id}, '
+                f'output_token={self.output_token}), '
+                f'logprobs={self.logprobs}')
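A brief sketch of how the new fork/append pair is meant to be used by the scheduler. Ids, token values, and logprob values are hypothetical, and parent/child are assumed to be already-constructed Sequence objects:

    # The parent copies its block layout and logprob history into the child,
    # then the sampled token and its logprobs are appended to the child.
    parent.fork(child)
    child.append(42, {42: -0.69})           # token id 42 with logP(42 | prefix) = -0.69
    assert child.get_last_token_id() == 42  # cumulative_logprobs grew by -0.69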
cacheflow/worker/cache_engine.py
@@ -97,7 +97,7 @@ class CacheEngine:
             cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache

-    def _copy_blocks(
+    def _swap(
         self,
         src: List[KVCache],
         dst: List[KVCache],
@@ -108,19 +108,38 @@ class CacheEngine:
                 src_key_cache, src_value_cache = src[i]
                 dst_key_cache, dst_value_cache = dst[i]
                 # Copy the key blocks.
-                cache_ops.copy_cache_blocks(
+                cache_ops.swap_blocks(
                     src_key_cache, dst_key_cache, src_to_dst)
                 # Copy the value blocks.
-                cache_ops.copy_cache_blocks(
+                cache_ops.swap_blocks(
                     src_value_cache, dst_value_cache, src_to_dst)
                 event = self.events[i]
                 event.record(stream=self.cache_stream)

-    def copy(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)
-
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)
+        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
+        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
+
+    def _copy(
+        self,
+        src: List[KVCache],
+        dst: List[KVCache],
+        src_to_dsts: Dict[int, List[int]],
+    ) -> None:
+        with torch.cuda.stream(self.cache_stream):
+            for i in range(self.num_layers):
+                src_key_cache, src_value_cache = src[i]
+                dst_key_cache, dst_value_cache = dst[i]
+                # Copy the key blocks.
+                cache_ops.copy_blocks(
+                    src_key_cache, dst_key_cache, src_to_dsts)
+                # Copy the value blocks.
+                cache_ops.copy_blocks(
+                    src_value_cache, dst_value_cache, src_to_dsts)
+                event = self.events[i]
+                event.record(stream=self.cache_stream)
+
+    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
+        self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
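The two operations now take differently shaped mappings: swaps move whole blocks one-to-one between CPU and GPU, while copies fan a GPU block out to several GPU blocks for forked beams. A sketch with hypothetical block numbers and an assumed cache_engine instance:

    blocks_to_swap_in = {0: 5, 1: 6}   # CPU block -> GPU block (one-to-one)
    blocks_to_copy = {5: [7, 9]}       # GPU src block -> list of GPU dst blocks (one-to-many)

    cache_engine.swap_in(blocks_to_swap_in)
    cache_engine.copy(blocks_to_copy)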
cacheflow/worker/controller.py
 from typing import Dict, List, Union

 from cacheflow.master.scheduler import Scheduler
+from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.worker.worker import Worker
@@ -14,7 +15,8 @@ class Controller:
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
-        dtype: str = 'half',
+        dtype: str,
+        seed: int,
     ) -> None:
         self.node_id = node_id
         self.num_workers = num_workers
@@ -37,6 +39,7 @@ class Controller:
                 num_gpu_blocks=num_gpu_blocks,
                 num_cpu_blocks=num_cpu_blocks,
                 dtype=dtype,
+                seed=seed,
             )
             self.workers.append(worker)
@@ -49,22 +52,16 @@ class Controller:
     def execute_stage(
         self,
-        prompt_tokens: Dict[int, List[int]],
-        generation_tokens: Dict[int, int],
-        context_lens: Dict[int, int],
-        block_tables: Dict[int, List[int]],
+        input_seq_groups: List[SequenceGroupInputs],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         # FIXME: Support tensor parallelism.
         assert len(self.workers) == 1
         worker = self.workers[0]
         output = worker.execute_stage(
-            prompt_tokens,
-            generation_tokens,
-            context_lens,
-            block_tables,
+            input_seq_groups,
             blocks_to_swap_in,
             blocks_to_swap_out,
             blocks_to_copy,
cacheflow/worker/worker.py
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple

 import torch

 from cacheflow.models import get_model
+from cacheflow.models import set_seed
 from cacheflow.models import InputMetadata
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
@@ -18,6 +22,7 @@ class Worker:
         num_gpu_blocks: int,
         num_cpu_blocks: int,
         dtype: str,
+        seed: int,
     ) -> None:
         self.worker_id = worker_id
         self.gpu_id = gpu_id
@@ -33,6 +38,11 @@ class Worker:
         self.head_size = self.model.config.hidden_size // self.num_heads
         self.dtype = self.model.dtype

+        # Set the seed.
+        # We set the seed after initializing the model to ensure that
+        # the random state is not affected by the model initialization.
+        set_seed(seed)
+
         self.cache_engine = CacheEngine(
             worker_id=worker_id,
             gpu_id=gpu_id,
@@ -49,55 +59,81 @@ class Worker:
     def prepare_inputs(
         self,
-        prompt_tokens: Dict[int, List[int]],        # Seq id -> List of input token ids.
-        generation_tokens: Dict[int, int],          # Seq id -> Input token id.
-        context_lens: Dict[int, int],               # Seq id -> Number of tokens participating in attention.
-        block_tables: Dict[int, List[int]],         # Seq id -> List of physical block numbers.
+        input_seq_groups: List[SequenceGroupInputs],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         # TODO(woosuk): Support interactive generation.
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
+        seq_logprobs: Dict[int, float] = {}
-        sampling_params: Dict[int, SamplingParams] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []

-        # Add the prompt tokens.
-        prompt_lens: List[int] = []
-        prompt_seq_ids = sorted(prompt_tokens.keys())
-        for seq_id in prompt_seq_ids:
-            prompt_len = len(prompt_tokens[seq_id])
+        # Add prompt tokens.
+        prompt_lens: List[int] = []
+        for input_seq_group in input_seq_groups:
+            if not input_seq_group.is_prompt:
+                continue
+
+            seq_ids = list(input_seq_group.input_tokens.keys())
+            sampling_params = input_seq_group.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+            seq_logprobs.update(input_seq_group.seq_logprobs)
+
+            # Use any sequence in the group.
+            seq_id = seq_ids[0]
+
+            prompt_tokens = input_seq_group.input_tokens[seq_id]
+            prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
-            input_tokens.extend(prompt_tokens[seq_id])
-            input_positions.extend(range(len(prompt_tokens[seq_id])))
+            input_tokens.extend(prompt_tokens)
+            # NOTE(woosuk): Here we assume that the first token in the prompt
+            # is always the first token in the sequence.
+            input_positions.extend(range(len(prompt_tokens)))

-            block_table = block_tables[seq_id]
+            # Compute the slot mapping.
+            block_table = input_seq_group.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

-        # Add the generation tokens.
+        # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
+        context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        generation_seq_ids = sorted(generation_tokens.keys())
-        for seq_id in generation_seq_ids:
-            input_tokens.append(generation_tokens[seq_id])
-            position_id = context_lens[seq_id] - 1
-            input_positions.append(position_id)
-            block_table = block_tables[seq_id]
-            generation_block_tables.append(block_table)
-            max_context_len = max(max_context_len, context_lens[seq_id])
-            max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
-            block_number = block_table[position_id // self.block_size]
-            block_offset = position_id % self.block_size
-            slot = block_number * self.block_size + block_offset
-            slot_mapping.append(slot)
+        for input_seq_group in input_seq_groups:
+            if input_seq_group.is_prompt:
+                continue
+
+            seq_ids = list(input_seq_group.input_tokens.keys())
+            sampling_params = input_seq_group.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+            seq_logprobs.update(input_seq_group.seq_logprobs)
+
+            for seq_id in seq_ids:
+                assert len(input_seq_group.input_tokens[seq_id]) == 1
+                generation_token = input_seq_group.input_tokens[seq_id][0]
+                input_tokens.append(generation_token)
+
+                position = input_seq_group.context_len - 1
+                input_positions.append(position)
+
+                block_table = input_seq_group.block_tables[seq_id]
+                generation_block_tables.append(block_table)
+
+                max_context_len = max(max_context_len, input_seq_group.context_len)
+                max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
+                context_lens.append(input_seq_group.context_len)
+
+                block_number = block_table[position // self.block_size]
+                block_offset = position % self.block_size
+                slot = block_number * self.block_size + block_offset
+                slot_mapping.append(slot)

         # Optimization: Pad the input length to be a multiple of 8.
         # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
@@ -112,8 +148,7 @@ class Worker:
         slot_mapping_tensor = torch.tensor(
             slot_mapping, dtype=torch.int, device=self.device)
         context_lens_tensor = torch.tensor(
-            [context_lens[seq_id] for seq_id in generation_seq_ids],
-            dtype=torch.int, device=self.device)
+            context_lens, dtype=torch.int, device=self.device)
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
             for block_table in generation_block_tables
         ]
@@ -121,7 +156,8 @@ class Worker:
             padded_block_tables, dtype=torch.int, device=self.device)

         input_metadata = InputMetadata(
-            seq_ids=prompt_seq_ids + generation_seq_ids,
+            seq_groups=seq_groups,
+            seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
@@ -133,14 +169,11 @@ class Worker:
     @torch.inference_mode()
     def execute_stage(
         self,
-        prompt_tokens: Dict[int, List[int]],        # Seq id -> List of input token ids.
-        generation_tokens: Dict[int, int],          # Seq id -> Input token id.
-        context_lens: Dict[int, int],               # Seq id -> Number of tokens participating in attention.
-        block_tables: Dict[int, List[int]],         # Seq id -> List of physical block numbers.
+        input_seq_groups: List[SequenceGroupInputs],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, int],
-    ) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> Dict[int, SequenceOutputs]:
         # Issue cache operations.
         command_issued = False
         if blocks_to_swap_in:
@@ -160,7 +193,7 @@ class Worker:
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
-            prompt_tokens, generation_tokens, context_lens, block_tables)
+            input_seq_groups)

         # Execute the model.
         output = self.model(
csrc/cache.cpp
 #include <torch/extension.h>

+#include <map>
+#include <vector>
+
-void copy_blocks(
+void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
   const std::map<int64_t, int64_t>& block_mapping);

+void copy_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping);
+
 void reshape_and_cache(
   torch::Tensor& key,
   torch::Tensor& value,
@@ -14,7 +22,11 @@ void reshape_and_cache(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "copy_cache_blocks",
+    "swap_blocks",
+    &swap_blocks,
+    "Swap in (out) the cache blocks from src to dst");
+  m.def(
+    "copy_blocks",
     &copy_blocks,
     "Copy the cache blocks from src to dst");
   m.def(
csrc/cache_kernels.cu
@@ -5,8 +5,9 @@
 #include <algorithm>
 #include <cassert>
 #include <map>
+#include <vector>

-void copy_blocks(
+void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
   const std::map<int64_t, int64_t>& block_mapping) {
@@ -43,6 +44,35 @@ void copy_blocks(
   }
 }

+void copy_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
+  torch::Device src_device = src.device();
+  torch::Device dst_device = dst.device();
+  assert(src_device.is_cuda() && dst_device.is_cuda());
+  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
+
+  void *src_ptr = src.data_ptr();
+  void *dst_ptr = dst.data_ptr();
+
+  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  for (const auto& pair : block_mapping) {
+    int64_t src_block_number = pair.first;
+    for (int64_t dst_block_number : pair.second) {
+      int64_t src_offset = src_block_number * block_size_in_bytes;
+      int64_t dst_offset = dst_block_number * block_size_in_bytes;
+      cudaMemcpyAsync(
+        dst_ptr + dst_offset,
+        src_ptr + src_offset,
+        block_size_in_bytes,
+        memcpy_type,
+        stream);
+    }
+  }
+}
+
 template<typename scalar_t>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
server.py
@@ -15,6 +15,8 @@ parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of
 parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
 # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
 parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
+# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
+parser.add_argument('--seed', type=int, default=0, help='random seed')
 args = parser.parse_args()
@@ -30,6 +32,7 @@ def main():
             num_gpu_blocks=args.num_gpu_blocks,
             num_cpu_blocks=args.num_cpu_blocks,
             dtype=args.dtype,
+            seed=args.seed,
         )
         controllers.append(controller)
@@ -52,18 +55,18 @@ def main():
         controllers[i].set_next(controllers[i + 1])
     controllers[-1].set_next(scheduler)

     # Test the following inputs.
     test_inputs = [
-        'Ion Stoica is a',
-        'UC Berkeley is',
-        'The future of cloud computing is',
+        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
+        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
+        ('The future of cloud computing is', {}),  # Use default parameters.
     ]

     # FIXME
     while True:
         if test_inputs:
-            frontend.query(test_inputs.pop())
+            text, sampling_params = test_inputs.pop(0)
+            frontend.query(text, **sampling_params)
         scheduler.step()
-        if not scheduler.pending and not scheduler.running:
+        if not (scheduler.pending or scheduler.running or test_inputs):
             break