Commit 1a7eb7da (unverified)
Authored Mar 10, 2023 by Woosuk Kwon; committed via GitHub on Mar 10, 2023

Support beam search & parallel generation (#7)

Parent: 04e5acc0
Changes: 16 changed files with 662 additions and 163 deletions (+662 -163)
cacheflow/block.py                  +4    -0
cacheflow/master/frontend.py        +28   -5
cacheflow/master/scheduler.py       +55   -36
cacheflow/models/__init__.py        +3    -1
cacheflow/models/input_metadata.py  +11   -7
cacheflow/models/model_utils.py     +10   -0
cacheflow/models/opt.py             +2    -1
cacheflow/models/sample.py          +258  -17
cacheflow/sampling_params.py        +54   -19
cacheflow/sequence.py               +72   -6
cacheflow/worker/cache_engine.py    +27   -8
cacheflow/worker/controller.py      +7    -10
cacheflow/worker/worker.py          +76   -43
csrc/cache.cpp                      +14   -2
csrc/cache_kernels.cu               +31   -1
server.py                           +10   -7
cacheflow/block.py

@@ -35,6 +35,10 @@ class LogicalTokenBlock:
     def get_token_ids(self) -> List[int]:
         return self.token_ids[:self.num_tokens]

+    def get_last_token_id(self) -> int:
+        assert self.num_tokens > 0
+        return self.token_ids[self.num_tokens - 1]
+

 class PhysicalTokenBlock:
cacheflow/master/frontend.py

-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 from transformers import AutoTokenizer

@@ -25,12 +25,35 @@ class Frontend:
     def query(
         self,
         prompt: str,
-        sampling_params: Optional[SamplingParams] = None,
+        n: int = 1,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        use_beam_search: bool = False,
+        stop_token_ids: Set[int] = set(),
+        max_num_steps: int = 16,
+        # From OpenAI API.
+        num_logprobs: int = 0,
+        context_window_size: Optional[int] = None,
     ) -> None:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        token_ids: List[int] = self.tokenizer.encode(prompt)
+        # Stop when we see an EOS token.
+        stop_token_ids.add(self.tokenizer.eos_token_id)
+        sampling_params = SamplingParams(
+            n=n,
+            temperature=temperature,
+            top_p=top_p,
+            use_beam_search=use_beam_search,
+            stop_token_ids=stop_token_ids,
+            max_num_steps=max_num_steps,
+            num_logprobs=num_logprobs,
+            context_window_size=context_window_size,
+        )
+        token_ids = self.tokenizer.encode(prompt)
         self._add_query(token_ids, sampling_params)

     def _add_query(
         self,
         token_ids: List[int],
         sampling_params: SamplingParams,
     ) -> None:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
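Aside (not part of the diff): a minimal sketch of how the reworked query() signature is meant to be called, mirroring the test inputs added to server.py below. The Frontend construction itself is omitted and assumed to already exist.

# Sketch only: `frontend` is assumed to be an already-constructed Frontend
# wired to a tokenizer and scheduler (construction not shown in this commit view).
def submit_examples(frontend) -> None:
    # Parallel generation: three independent nucleus-sampled continuations.
    frontend.query('UC Berkeley is', n=3, temperature=0.8, top_p=0.99)
    # Beam search: n acts as the beam width; temperature must be 0.0 and top_p 1.0.
    frontend.query('Ion Stoica is a', n=4, use_beam_search=True, temperature=0.0)
    # Default parameters (n=1, temperature=1.0, top_p=1.0).
    frontend.query('The future of cloud computing is')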
cacheflow/master/scheduler.py

-from typing import Dict, List, Tuple
+from typing import Dict, List

 from cacheflow.master.block_manager import BlockSpaceManager
 from cacheflow.master.frontend import Frontend
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
+from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceOutputs
 from cacheflow.sequence import SequenceStatus

 _MAX_NUM_BATCHED_TOKENS = 2048

@@ -66,7 +68,7 @@ class Scheduler:
     def _append(
         self,
         seq_group: SequenceGroup,
-        blocks_to_copy: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         for seq in seq_group.seqs:
             if seq.status == SequenceStatus.FINISHED:
@@ -74,7 +76,10 @@ class Scheduler:
             ret = self.block_manager.append(seq)
             if ret is not None:
                 src_block, dst_block = ret
-                blocks_to_copy[src_block] = dst_block
+                if src_block in blocks_to_copy:
+                    blocks_to_copy[src_block].append(dst_block)
+                else:
+                    blocks_to_copy[src_block] = [dst_block]

     def _swap_in(
         self,
@@ -83,8 +88,7 @@ class Scheduler:
     ) -> None:
         mapping = self.block_manager.swap_in(seq_group)
         blocks_to_swap_in.update(mapping)
-        for seq in seq_group.seqs:
-            if seq.status == SequenceStatus.SWAPPED:
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             seq.status = SequenceStatus.RUNNING
         self.running.append(seq_group)
@@ -96,8 +100,7 @@ class Scheduler:
         assert self.block_manager.can_swap_out(seq_group)
         mapping = self.block_manager.swap_out(seq_group)
         blocks_to_swap_out.update(mapping)
-        for seq in seq_group.seqs:
-            if seq.status == SequenceStatus.RUNNING:
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             seq.status = SequenceStatus.SWAPPED
         self.swapped.append(seq_group)
@@ -105,7 +108,7 @@ class Scheduler:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
-        blocks_to_copy: Dict[int, int] = {}
+        blocks_to_copy: Dict[int, List[int]] = {}

         # 1. Reserve new slots for the running sequences.
         # NOTE: Here we implicitly assume FCFS scheduling.
@@ -143,6 +146,10 @@ class Scheduler:
             # All swapped sequences are swapped in.
             self.swapped.clear()

+        # Ensure that swap-in and swap-out never happen at the same timestep.
+        if blocks_to_swap_in:
+            assert not blocks_to_swap_out
+
         num_batched_tokens = sum(
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
             for seq_group in self.running
@@ -152,7 +159,6 @@ class Scheduler:
         # NOTE: Here we implicitly assume FCFS scheduling.
         # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
-            # FIXME(woosuk): Acquire a lock to protect pending.
             self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -168,39 +174,45 @@ class Scheduler:
             else:
                 self.pending.clear()

-        # Ensure that swap-in and swap-out never happen at the same timestep.
-        if blocks_to_swap_in:
-            assert not blocks_to_swap_out
-
         # 4. Create input data structures.
-        prompt_tokens: Dict[int, List[int]] = {}
-        generation_tokens: Dict[int, int] = {}
-        context_lens: Dict[int, int] = {}
-        block_tables: Dict[int, List[int]] = {}
+        input_seq_groups: List[SequenceGroupInputs] = []
         for seq_group in self.running:
             group_id = seq_group.group_id
             num_steps = self.num_steps[group_id]
             # NOTE(woosuk): We assume that the number of steps is 0
             # for the prompt sequences.
             is_prompt = num_steps == 0
-            for seq in seq_group.seqs:
-                if seq.status != SequenceStatus.RUNNING:
-                    continue
+
+            input_tokens: Dict[int, List[int]] = {}
+            seq_logprobs: Dict[int, float] = {}
+            block_tables: Dict[int, List[int]] = {}
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
                 if is_prompt:
-                    prompt_tokens[seq_id] = seq.get_token_ids()
+                    input_tokens[seq_id] = seq.get_token_ids()
                 else:
-                    generation_tokens[seq_id] = seq.get_token_ids()[-1]
-                    context_lens[seq_id] = seq.get_len()
+                    input_tokens[seq_id] = [seq.get_last_token_id()]
+                seq_logprobs[seq_id] = seq.cumulative_logprobs
+                # NOTE(woosuk): Sequences in the same group have the same
+                # sequence length
+                seq_len = seq.get_len()
+
+            input_seq_group = SequenceGroupInputs(
+                group_id=group_id,
+                is_prompt=is_prompt,
+                input_tokens=input_tokens,
+                context_len=seq_len,
+                seq_logprobs=seq_logprobs,
+                sampling_params=self.sampling_params[group_id],
+                block_tables=block_tables,
+            )
+            input_seq_groups.append(input_seq_group)

         # 5. Execute the first stage of the pipeline.
         self.controllers[0].execute_stage(
-            prompt_tokens,
-            generation_tokens,
-            context_lens,
-            block_tables,
+            input_seq_groups,
             blocks_to_swap_in,
             blocks_to_swap_out,
             blocks_to_copy,
@@ -208,7 +220,7 @@ class Scheduler:
     def post_step(
         self,
-        next_tokens: Dict[int, Tuple[int, int]],
+        seq_outputs: Dict[int, SequenceOutputs],
     ) -> None:
         # Update the running sequences and free blocks.
         for seq_group in self.running:
@@ -216,25 +228,32 @@ class Scheduler:
             self.num_steps[group_id] += 1
             stop_token_ids = self.sampling_params[group_id].stop_token_ids

+            # Process beam search results before processing the next tokens.
             for seq in seq_group.seqs:
                 if seq.status == SequenceStatus.FINISHED:
                     continue
-                parent_seq_id, next_token = next_tokens[seq.seq_id]
-                if seq.seq_id != parent_seq_id:
+                output = seq_outputs[seq.seq_id]
+                if seq.seq_id != output.parent_seq_id:
                     # The sequence is a fork of the parent sequence (beam search).
                     # Free the current sequence.
                     self.block_manager.free(seq)
                     # Fork the parent sequence.
-                    parent_seq = seq_group.find(parent_seq_id)
-                    seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
+                    parent_seq = seq_group.find(output.parent_seq_id)
+                    parent_seq.fork(seq)
                     self.block_manager.fork(parent_seq, seq)

+            # Process the next tokens.
+            for seq in seq_group.seqs:
+                if seq.status == SequenceStatus.FINISHED:
+                    continue
+
                 # Append a new token to the sequence.
-                seq.append([next_token])
+                output = seq_outputs[seq.seq_id]
+                seq.append(output.output_token, output.logprobs)

                 # Check if the sequence has generated a stop token.
-                if next_token in stop_token_ids:
+                if output.output_token in stop_token_ids:
                     self._free_seq(seq)
                     continue
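Aside (not part of the diff): the switch from Dict[int, int] to Dict[int, List[int]] lets one source block be copied to several destinations in a single step, which happens when multiple beams fork from the same parent block. A toy, self-contained sketch of the accumulation pattern used in _append:

from typing import Dict, List

blocks_to_copy: Dict[int, List[int]] = {}

def record_copy(src_block: int, dst_block: int) -> None:
    # Same accumulation logic as Scheduler._append: append if the source
    # block is already scheduled for copying, otherwise start a new list.
    if src_block in blocks_to_copy:
        blocks_to_copy[src_block].append(dst_block)
    else:
        blocks_to_copy[src_block] = [dst_block]

record_copy(3, 7)
record_copy(3, 9)   # two beams forked from source block 3
assert blocks_to_copy == {3: [7, 9]}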
cacheflow/models/__init__.py

 from cacheflow.models.input_metadata import InputMetadata
 from cacheflow.models.model_utils import get_model
+from cacheflow.models.model_utils import set_seed

 __all__ = ['InputMetadata', 'get_model', 'set_seed']
cacheflow/models/input_metadata.py

-from typing import List
+from typing import List, Dict, Tuple

 import torch

+from cacheflow.sampling_params import SamplingParams
+

 class InputMetadata:

     def __init__(
         self,
-        seq_ids: List[int],
+        seq_groups: List[Tuple[List[int], SamplingParams]],
+        seq_logprobs: Dict[int, float],     # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         # FIXME: Rename
         max_context_len: int,
         block_tables: torch.Tensor,
     ) -> None:
-        self.seq_ids = seq_ids
+        self.seq_groups = seq_groups
+        self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
@@ -23,19 +26,20 @@ class InputMetadata:
         self.block_tables = block_tables

         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
             self.max_num_blocks_per_seq = 0
-        assert self.num_generation_tokens == block_tables.shape[0]
-        assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
+        assert block_tables.shape[0] == self.num_generation_tokens
+        assert context_lens.shape[0] == self.num_generation_tokens

     def __repr__(self) -> str:
         return (f'InputMetadata('
-                f'seq_ids={self.seq_ids}, '
                 f'num_prompts={self.num_prompts}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
                 f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
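Aside (not part of the diff): the new seq_groups metadata pairs the sequence ids of one request with that request's sampling parameters, and the sampler fans the group's settings out to one entry per logits row. A toy, pure-Python sketch of that fan-out, mirroring _get_temperatures in sample.py; a bare temperature value stands in for the real SamplingParams object here.

# Hypothetical per-step metadata: two prompt groups, then one generation group
# with three live beams. The second tuple element stands in for SamplingParams.
seq_groups = [([0], 0.8), ([1], 1.0), ([2, 3, 4], 0.0)]  # (seq_ids, temperature)
num_prompts = 2

temperatures = []
for i, (seq_ids, temperature) in enumerate(seq_groups):
    if temperature == 0.0:        # greedy / beam search: avoid dividing by zero
        temperature = 1.0
    if i < num_prompts:           # a prompt contributes one logits row
        temperatures.append(temperature)
    else:                         # a generation group contributes one row per sequence
        temperatures += [temperature] * len(seq_ids)

assert temperatures == [0.8, 1.0, 1.0, 1.0, 1.0]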
cacheflow/models/model_utils.py

+import random
 from typing import Union

+import numpy as np
 import torch
 import torch.nn as nn

@@ -30,3 +32,11 @@ def get_model(
         model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
         return model.eval()
     raise ValueError(f'Invalid model name: {model_name}')
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
cacheflow/models/opt.py

@@ -9,6 +9,7 @@ from transformers import PreTrainedModel
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import OPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
+from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -261,7 +262,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, Tuple[int, int]]:
+    ) -> Dict[int, SequenceOutputs]:
         hidden_states = self.model(
             input_ids, positions, kv_caches, input_metadata, cache_events)
         next_tokens = self.sampler(
cacheflow/models/sample.py

@@ -4,6 +4,8 @@ import torch
 import torch.nn as nn

 from cacheflow.models import InputMetadata
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceOutputs


 class Sampler(nn.Module):
@@ -16,8 +18,42 @@ class Sampler(nn.Module):
         embedding: torch.Tensor,
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
-    ) -> Dict[int, Tuple[int, int]]:
-        # Get the hidden states of the last tokens.
+    ) -> Dict[int, SequenceOutputs]:
+        # Get the hidden states that we use for sampling.
+        hidden_states = _prune_hidden_states(hidden_states, input_metadata)
+
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+
+        # Apply temperature scaling.
+        temperatures = _get_temperatures(input_metadata)
+        assert len(temperatures) == logits.shape[0]
+        if any(t != 1.0 for t in temperatures):
+            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
+            # Use in-place division to avoid creating a new tensor.
+            logits.div_(t.unsqueeze(dim=1))
+
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities (before applying top-p).
+        logprobs = torch.log(probs)
+
+        # Apply top-p truncation.
+        top_ps = _get_top_ps(input_metadata)
+        assert len(top_ps) == probs.shape[0]
+        if any(p < 1.0 for p in top_ps):
+            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
+            probs = _apply_top_p(probs, p)
+
+        # Sample the next tokens.
+        return _sample(probs, logprobs, input_metadata)
+
+
+def _prune_hidden_states(
+    hidden_states: torch.Tensor,
+    input_metadata: InputMetadata,
+) -> torch.Tensor:
+    start_idx = 0
+    last_token_indicies: List[int] = []
+    for prompt_len in input_metadata.prompt_lens:
@@ -25,18 +61,223 @@ class Sampler(nn.Module):
         start_idx += prompt_len
     last_token_indicies.extend(
         range(start_idx, start_idx + input_metadata.num_generation_tokens))
-        hidden_states = hidden_states[last_token_indicies]
+    return hidden_states[last_token_indicies]

-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-
-        # Sample the next tokens.
-        # TODO(woosuk): Implement other sampling methods.
-        next_token_ids = torch.argmax(logits, dim=-1)

+def _get_temperatures(
+    input_metadata: InputMetadata,
+) -> List[float]:
+    # Collect the temperatures for the logits.
+    temperatures: List[float] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        temperature = sampling_params.temperature
+        if temperature == 0.0:
+            # NOTE: Zero temperature means deterministic sampling
+            # (i.e., greedy sampling or beam search).
+            # Set the temperature to 1 to avoid division by zero.
+            temperature = 1.0
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            temperatures.append(temperature)
+        else:
+            # A generation token.
+            temperatures += [temperature] * len(seq_ids)
+    return temperatures
+
+
+def _get_top_ps(
+    input_metadata: InputMetadata,
+) -> List[float]:
+    top_ps: List[float] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            top_ps.append(sampling_params.top_p)
+        else:
+            # A generation token.
+            top_ps += [sampling_params.top_p] * len(seq_ids)
+    return top_ps
+
+
+def _apply_top_p(
+    probs: torch.Tensor,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    # TODO(woosuk): Optimize.
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
+    return probs
+
+
+def _get_topk_logprobs(
+    logprobs: torch.Tensor,
+    num_logprobs: int,
+) -> Dict[int, float]:
+    if num_logprobs == 0:
+        return {}
+
+    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
+    if num_logprobs == 1:
+        topk_logprobs = [topk_logprobs.item()]
+        topk_ids = [topk_ids.item()]
+    else:
+        topk_logprobs = topk_logprobs.tolist()
+        topk_ids = topk_ids.tolist()
+
+    token_to_logprob: Dict[int, float] = {}
+    for token_id, logprob in zip(topk_ids, topk_logprobs):
+        token_to_logprob[token_id] = logprob
+    return token_to_logprob
+
+
+def _sample_from_prompt(
+    prob: torch.Tensor,
+    sampling_params: SamplingParams,
+) -> List[int]:
+    if sampling_params.use_beam_search:
+        # Beam search.
+        beam_width = sampling_params.n
+        _, next_token_ids = torch.topk(prob, beam_width)
+        next_token_ids = next_token_ids.tolist()
+    elif sampling_params.temperature == 0.0:
+        # Greedy sampling.
+        assert sampling_params.n == 1
+        next_token_id = torch.argmax(prob)
+        next_token_ids = [next_token_id.item()]
+    else:
+        # Neucleus sampling.
+        # Sample n tokens for the prompt.
+        n = sampling_params.n
+        next_token_ids = torch.multinomial(prob, num_samples=n, replacement=True)
+        next_token_ids = next_token_ids.tolist()
+    return next_token_ids
+
+
+def _sample_from_generation_tokens(
+    seq_ids: List[int],
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    seq_logprobs: List[float],
+    sampling_params: SamplingParams,
+) -> Tuple[List[int], List[int]]:
+    # NOTE(woosuk): sampling_params.n can be greater than
+    # len(seq_ids) because some sequences in the group might have
+    # been already terminated.
+    if sampling_params.use_beam_search:
+        # Beam search.
+        # Add cumulative logprobs for the sequences in the group.
+        seq_logprobs = torch.tensor(
+            seq_logprobs, dtype=torch.float, device=logprobs.device)
+        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
+
+        vocab_size = logprobs.size(-1)
+        beam_width = len(seq_ids)
+        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
+        seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
+        beam_seq_ids = [seq_ids[i] for i in seq_idx]
+        token_ids = (topk_ids % vocab_size).tolist()
+
+        beam_outputs: Dict[int, Tuple[int, int]] = {}
+        outstanding_beams: List[Tuple[int, int]] = []
+        # If a beam survives, continue with it.
+        for seq_id, token_id in zip(beam_seq_ids, token_ids):
+            if seq_id not in beam_outputs:
+                beam_outputs[seq_id] = (seq_id, token_id)
+            else:
+                outstanding_beams.append((seq_id, token_id))
+
+        # If a beam is discarded, fork another beam.
+        for seq_id in seq_ids:
+            if seq_id not in beam_outputs:
+                beam_outputs[seq_id] = outstanding_beams.pop()
+        assert not outstanding_beams
+
+        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
+        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
+    elif sampling_params.temperature == 0.0:
+        # Greedy sampling.
+        assert len(seq_ids) == 1
+        next_token_id = torch.argmax(probs, dim=-1)
+        next_token_ids = [next_token_id.item()]
+        parent_seq_ids = seq_ids
+    else:
+        # Neucleus sampling.
+        # Sample 1 token for each sequence in the group.
+        next_token_ids = torch.multinomial(probs, num_samples=1, replacement=True)
+        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
+        parent_seq_ids = seq_ids
+    return parent_seq_ids, next_token_ids
+
+
+def _sample(
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    input_metadata: InputMetadata,
+) -> Dict[int, SequenceOutputs]:
+    seq_outputs: Dict[int, SequenceOutputs] = {}
+
+    # TODO(woosuk): Optimize.
+    idx = 0
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        if i < input_metadata.num_prompts:
+            # Generate the next tokens for a prompt input.
+            assert len(seq_ids) == sampling_params.n
+            prob = probs[idx]
+            logprob = logprobs[idx]
+            idx += 1
+
+            # Sample the next tokens.
+            next_token_ids = _sample_from_prompt(prob, sampling_params)
+            # Get top-k log probabilities for the next tokens.
+            next_logprobs = _get_topk_logprobs(logprob, sampling_params.num_logprobs)
+
+            # Build the output.
+            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
+                output_logprobs = next_logprobs.copy()
+                output_logprobs[next_token_id] = logprob[next_token_id].item()
+                seq_outputs[seq_id] = SequenceOutputs(
+                    seq_id, seq_id, next_token_id, output_logprobs)
+        else:
+            # Generate the next tokens for generation tokens.
+            prob = probs[idx:idx + len(seq_ids)]
+            logprob = logprobs[idx:idx + len(seq_ids)]
+            idx += len(seq_ids)
+
+            # Sample the next tokens.
+            seq_logprobs = [
+                input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
+            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
+                seq_ids, prob, logprob, seq_logprobs, sampling_params)
+
+            # Get top-k log probabilities for the next tokens.
+            next_logprobs: Dict[int, Dict[int, float]] = {}
+            for i, seq_id in enumerate(seq_ids):
+                next_logprobs[seq_id] = _get_topk_logprobs(
+                    logprob[i], sampling_params.num_logprobs)
+
+            # Build the output.
+            for seq_id, parent_seq_id, next_token_id in zip(
+                seq_ids, parent_seq_ids, next_token_ids):
+                i = seq_ids.index(parent_seq_id)
+                output_logprobs = next_logprobs[parent_seq_id].copy()
+                output_logprobs[next_token_id] = logprob[i, next_token_id].item()
+                seq_outputs[seq_id] = SequenceOutputs(
+                    seq_id,
+                    parent_seq_id,
+                    next_token_id,
+                    output_logprobs,
+                )

-        # Return the next tokens.
-        next_tokens: Dict[int, Tuple[int, int]] = {}
-        for seq_id, token_id in zip(input_metadata.seq_ids, next_token_ids):
-            next_tokens[seq_id] = (seq_id, token_id)
-        return next_tokens
+    return seq_outputs
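Aside (not part of the diff): a small, self-contained sketch of the same top-p (nucleus) truncation trick used in _apply_top_p above, run on a toy distribution. The mask keeps the smallest set of tokens whose cumulative probability stays within p, then renormalizes.

import torch

def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Same idea as Sampler._apply_top_p: sort, cumulative-sum, mask, renormalize.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens whose preceding cumulative mass already exceeds p.
    mask = (probs_sum - probs_sort) > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Undo the sort so probabilities line up with their original token ids.
    return torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))

probs = torch.tensor([[0.5, 0.3, 0.1, 0.1]])
print(apply_top_p(probs, p=0.7))   # keeps the top two tokens: [[0.6250, 0.3750, 0.0, 0.0]]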
cacheflow/sampling_params.py

@@ -5,27 +5,51 @@ class SamplingParams:
     def __init__(
         self,
-        n: int = 1,
-        temperature: float = 1.0,
-        top_p: float = 1.0,
-        use_beam_search: bool = False,
-        stop_token_ids: Set[int] = [],
-        max_num_steps: int = 16,
-        # From OpenAI API.
-        max_context_len: Optional[int] = None,
+        n: int,
+        temperature: float,
+        top_p: float,
+        use_beam_search: bool,
+        stop_token_ids: Set[int],
+        max_num_steps: int,
+        num_logprobs: int,
+        context_window_size: Optional[int],
     ) -> None:
-        assert n >= 1
-        assert temperature >= 0.0
-        assert 0.0 < top_p <= 1.0
+        if n < 1:
+            raise ValueError(f'n must be at least 1, got {n}.')
+        if temperature < 0.0:
+            raise ValueError(
+                f'temperature must be non-negative, got {temperature}.')
+        if not 0.0 < top_p <= 1.0:
+            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
+        if max_num_steps < 1:
+            raise ValueError(
+                f'max_num_steps must be at least 1, got {max_num_steps}.')
+        if num_logprobs < 0:
+            raise ValueError(
+                f'num_logprobs must be non-negative, got {num_logprobs}.')
+        if context_window_size is not None and context_window_size < 0:
+            raise ValueError('context_window_size must be non-negative, '
+                             f'got {context_window_size}.')
+
         if use_beam_search:
-            assert n > 1
-            assert temperature > 0.0
-            assert top_p == 1.0
+            if n == 1:
+                raise ValueError('n must be greater than 1 when using beam search.')
+            if temperature > 0.0:
+                raise ValueError('temperature must be 0 when using beam search.')
+            if top_p < 1.0:
+                raise ValueError('top_p must be 1 when using beam search.')
         elif temperature == 0.0:
-            # Zero temperature means greedy decoding.
-            assert n == 1
-            assert top_p == 1.0
-        assert max_num_steps >= 1
-        assert max_context_len is None or max_context_len >= 0
+            # Zero temperature means greedy sampling.
+            if n > 1:
+                raise ValueError('n must be 1 when using greedy sampling.')
+            if top_p < 1.0:
+                raise ValueError('top_p must be 1 when using greedy sampling.')

         self.n = n
         self.temperature = temperature
@@ -33,4 +57,15 @@ class SamplingParams:
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
-        self.max_context_len = max_context_len
+        self.num_logprobs = num_logprobs
+        self.context_window_size = context_window_size
+
+    def __repr__(self) -> str:
+        return (f'SamplingParams(n={self.n}, '
+                f'temperature={self.temperature}, '
+                f'top_p={self.top_p}, '
+                f'use_beam_search={self.use_beam_search}, '
+                f'stop_token_ids={self.stop_token_ids}, '
+                f'max_num_steps={self.max_num_steps}, '
+                f'num_logprobs={self.num_logprobs}, '
+                f'context_window_size={self.context_window_size})')
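Aside (not part of the diff): a quick sketch of how the new constructor-time validation behaves. Beam search requires n > 1, temperature == 0.0, and top_p == 1.0; invalid combinations now raise ValueError instead of failing an assert.

from cacheflow.sampling_params import SamplingParams

# Valid beam-search configuration (beam width 4).
params = SamplingParams(
    n=4, temperature=0.0, top_p=1.0, use_beam_search=True,
    stop_token_ids=set(), max_num_steps=16, num_logprobs=0,
    context_window_size=None)

# Invalid: beam search with a non-zero temperature is rejected.
try:
    SamplingParams(
        n=4, temperature=0.8, top_p=1.0, use_beam_search=True,
        stop_token_ids=set(), max_num_steps=16, num_logprobs=0,
        context_window_size=None)
except ValueError as e:
    print(e)   # temperature must be 0 when using beam search.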
cacheflow/sequence.py

+import copy
 import enum
-from typing import List, Optional
+from typing import Dict, List, Optional

 from cacheflow.block import LogicalTokenBlock
+from cacheflow.sampling_params import SamplingParams


 class SequenceStatus(enum.Enum):

@@ -24,9 +26,11 @@ class Sequence:
         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the given token ids.
-        self.append(token_ids)
+        self.add(token_ids)

         self.status = SequenceStatus.PENDING
+        self.output_logprobs: List[Dict[int, float]] = []
+        self.cumulative_logprobs = 1.0

     def add_block(self) -> None:
         block = LogicalTokenBlock(
@@ -35,7 +39,7 @@ class Sequence:
         )
         self.logical_token_blocks.append(block)

-    def append(self, token_ids: List[int]) -> None:
+    def add(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
                 self.add_block()
@@ -49,6 +53,12 @@ class Sequence:
             last_block.append(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]

+    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
+        assert token_id in logprobs
+        self.add([token_id])
+        self.output_logprobs.append(logprobs)
+        self.cumulative_logprobs += logprobs[token_id]
+
     def get_len(self) -> int:
         return sum(block.num_tokens for block in self.logical_token_blocks)
@@ -58,6 +68,14 @@ class Sequence:
             token_ids.extend(block.get_token_ids())
         return token_ids

+    def get_last_token_id(self) -> int:
+        return self.logical_token_blocks[-1].get_last_token_id()
+
+    def fork(self, child_seq: 'Sequence') -> 'Sequence':
+        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
+        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
+        child_seq.cumulative_logprobs = self.cumulative_logprobs
+
     def __repr__(self) -> str:
         return (f'Sequence(seq_id={self.seq_id}, '
                 f'status={self.status.name}, '
@@ -74,11 +92,17 @@ class SequenceGroup:
         self.group_id = group_id
         self.seqs = seqs

-    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+    def get_seqs(
+        self,
+        status: Optional[SequenceStatus] = None,
+    ) -> List[Sequence]:
         if status is None:
-            return len(self.seqs)
+            return self.seqs
         else:
-            return len([seq for seq in self.seqs if seq.status == status])
+            return [seq for seq in self.seqs if seq.status == status]
+
+    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+        return len(self.get_seqs(status))

     def find(self, seq_id: int) -> Sequence:
         for seq in self.seqs:
@@ -92,3 +116,45 @@ class SequenceGroup:
     def __repr__(self) -> str:
         return (f'SequenceGroup(group_id={self.group_id}, '
                 f'num_seqs={len(self.seqs)})')
+
+
+class SequenceGroupInputs:
+
+    def __init__(
+        self,
+        group_id: int,
+        is_prompt: bool,
+        input_tokens: Dict[int, List[int]],     # Seq id -> token ids.
+        context_len: int,
+        seq_logprobs: Dict[int, float],         # Seq id -> cumulative logprobs.
+        sampling_params: SamplingParams,
+        block_tables: Dict[int, List[int]],     # Seq id -> List of physical block numbers.
+    ) -> None:
+        self.group_id = group_id
+        self.is_prompt = is_prompt
+        self.input_tokens = input_tokens
+        self.context_len = context_len
+        self.seq_logprobs = seq_logprobs
+        self.sampling_params = sampling_params
+        self.block_tables = block_tables
+
+
+class SequenceOutputs:
+
+    def __init__(
+        self,
+        seq_id: int,
+        parent_seq_id: int,
+        output_token: int,
+        logprobs: Dict[int, float],             # Token id -> logP(x_i+1 | x_0, ..., x_i).
+    ) -> None:
+        self.seq_id = seq_id
+        self.parent_seq_id = parent_seq_id
+        self.output_token = output_token
+        self.logprobs = logprobs
+
+    def __repr__(self) -> str:
+        return (f'SequenceOutputs(seq_id={self.seq_id}, '
+                f'parent_seq_id={self.parent_seq_id}, '
+                f'output_token={self.output_token}), '
+                f'logprobs={self.logprobs}')
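Aside (not part of the diff): SequenceOutputs.parent_seq_id is how a discarded beam gets re-seeded from a surviving one; the scheduler then calls parent_seq.fork(seq) and block_manager.fork. A toy, pure-Python reproduction of the beam reassignment done in _sample_from_generation_tokens above, with made-up ids:

# Live beams in one group, and the (parent, token) picks from the flattened top-k.
seq_ids = [0, 1, 2]
beam_seq_ids = [1, 1, 0]     # parents chosen by top-k (beam 1 was picked twice)
token_ids = [11, 12, 13]     # tokens chosen by top-k

beam_outputs, outstanding = {}, []
for seq_id, token_id in zip(beam_seq_ids, token_ids):
    if seq_id not in beam_outputs:
        beam_outputs[seq_id] = (seq_id, token_id)   # surviving beam keeps its pick
    else:
        outstanding.append((seq_id, token_id))      # surplus pick from the same parent
for seq_id in seq_ids:
    if seq_id not in beam_outputs:                  # beam 2 was discarded; fork from beam 1
        beam_outputs[seq_id] = outstanding.pop()

assert beam_outputs == {1: (1, 11), 0: (0, 13), 2: (1, 12)}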
cacheflow/worker/cache_engine.py

@@ -97,7 +97,7 @@ class CacheEngine:
             cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache

-    def _copy_blocks(
+    def _swap(
         self,
         src: List[KVCache],
         dst: List[KVCache],
@@ -108,19 +108,38 @@
                 src_key_cache, src_value_cache = src[i]
                 dst_key_cache, dst_value_cache = dst[i]
                 # Copy the key blocks.
-                cache_ops.copy_cache_blocks(
+                cache_ops.swap_blocks(
                     src_key_cache, dst_key_cache, src_to_dst)
                 # Copy the value blocks.
-                cache_ops.copy_cache_blocks(
+                cache_ops.swap_blocks(
                     src_value_cache, dst_value_cache, src_to_dst)
                 event = self.events[i]
                 event.record(stream=self.cache_stream)

-    def copy(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)
-
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)
+        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
+        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
+
+    def _copy(
+        self,
+        src: List[KVCache],
+        dst: List[KVCache],
+        src_to_dsts: Dict[int, List[int]],
+    ) -> None:
+        with torch.cuda.stream(self.cache_stream):
+            for i in range(self.num_layers):
+                src_key_cache, src_value_cache = src[i]
+                dst_key_cache, dst_value_cache = dst[i]
+                # Copy the key blocks.
+                cache_ops.copy_blocks(
+                    src_key_cache, dst_key_cache, src_to_dsts)
+                # Copy the value blocks.
+                cache_ops.copy_blocks(
+                    src_value_cache, dst_value_cache, src_to_dsts)
+                event = self.events[i]
+                event.record(stream=self.cache_stream)
+
+    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
+        self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
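Aside (not part of the diff): the two mapping shapes the cache engine now distinguishes. Swaps move a block between CPU and GPU one-to-one, while copies duplicate a GPU block into one or more new GPU blocks when beams fork. The block numbers below are made up, and the trailing calls assume an already-constructed cache_engine.

# One-to-one mapping used by swap_in / swap_out (cache_ops.swap_blocks).
blocks_to_swap_in = {12: 3, 13: 4}          # CPU block -> GPU block

# One-to-many mapping used by copy (cache_ops.copy_blocks).
blocks_to_copy = {3: [7, 9]}                # GPU block -> copies for forked beams

# A scheduler step hands both to the worker, which forwards them here, e.g.:
#   cache_engine.swap_in(blocks_to_swap_in)
#   cache_engine.copy(blocks_to_copy)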
cacheflow/worker/controller.py

 from typing import Dict, List, Union

 from cacheflow.master.scheduler import Scheduler
+from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.worker.worker import Worker

@@ -14,7 +15,8 @@ class Controller:
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
-        dtype: str = 'half',
+        dtype: str,
+        seed: int,
     ) -> None:
         self.node_id = node_id
         self.num_workers = num_workers
@@ -37,6 +39,7 @@ class Controller:
                 num_gpu_blocks=num_gpu_blocks,
                 num_cpu_blocks=num_cpu_blocks,
                 dtype=dtype,
+                seed=seed,
             )
             self.workers.append(worker)
@@ -49,22 +52,16 @@ class Controller:
     def execute_stage(
         self,
-        prompt_tokens: Dict[int, List[int]],
-        generation_tokens: Dict[int, int],
-        context_lens: Dict[int, int],
-        block_tables: Dict[int, List[int]],
+        input_seq_groups: List[SequenceGroupInputs],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         # FIXME: Support tensor parallelism.
         assert len(self.workers) == 1
         worker = self.workers[0]
         output = worker.execute_stage(
-            prompt_tokens,
-            generation_tokens,
-            context_lens,
-            block_tables,
+            input_seq_groups,
             blocks_to_swap_in,
             blocks_to_swap_out,
             blocks_to_copy,
cacheflow/worker/worker.py

-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple

 import torch

 from cacheflow.models import get_model
+from cacheflow.models import set_seed
 from cacheflow.models import InputMetadata
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine

@@ -18,6 +22,7 @@ class Worker:
         num_gpu_blocks: int,
         num_cpu_blocks: int,
         dtype: str,
+        seed: int,
     ) -> None:
         self.worker_id = worker_id
         self.gpu_id = gpu_id
@@ -33,6 +38,11 @@ class Worker:
         self.head_size = self.model.config.hidden_size // self.num_heads
         self.dtype = self.model.dtype

+        # Set the seed.
+        # We set the seed after initializing the model to ensure that
+        # the random state is not affected by the model initialization.
+        set_seed(seed)
+
         self.cache_engine = CacheEngine(
             worker_id=worker_id,
             gpu_id=gpu_id,
@@ -49,53 +59,79 @@ class Worker:
     def prepare_inputs(
         self,
-        prompt_tokens: Dict[int, List[int]],        # Seq id -> List of input token ids.
-        generation_tokens: Dict[int, int],          # Seq id -> Input token id.
-        context_lens: Dict[int, int],               # Seq id -> Number of tokens participating in attention.
-        block_tables: Dict[int, List[int]],         # Seq id -> List of physical block numbers.
+        input_seq_groups: List[SequenceGroupInputs],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
-        # TODO(woosuk): Support interactive generation.
-        # Add the prompt tokens.
-        prompt_lens: List[int] = []
+        seq_groups: List[Tuple[List[int], SamplingParams]] = []
+        seq_logprobs: Dict[int, float] = {}
+        sampling_params: Dict[int, SamplingParams] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []

-        prompt_seq_ids = sorted(prompt_tokens.keys())
-        for seq_id in prompt_seq_ids:
-            prompt_len = len(prompt_tokens[seq_id])
+        # Add prompt tokens.
+        prompt_lens: List[int] = []
+        for input_seq_group in input_seq_groups:
+            if not input_seq_group.is_prompt:
+                continue
+
+            seq_ids = list(input_seq_group.input_tokens.keys())
+            sampling_params = input_seq_group.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+            seq_logprobs.update(input_seq_group.seq_logprobs)
+
+            # Use any sequence in the group.
+            seq_id = seq_ids[0]
+
+            prompt_tokens = input_seq_group.input_tokens[seq_id]
+            prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)

-            input_tokens.extend(prompt_tokens[seq_id])
-            input_positions.extend(range(len(prompt_tokens[seq_id])))
+            input_tokens.extend(prompt_tokens)
+            # NOTE(woosuk): Here we assume that the first token in the prompt
+            # is always the first token in the sequence.
+            input_positions.extend(range(len(prompt_tokens)))

-            block_table = block_tables[seq_id]
+            # Compute the slot mapping.
+            block_table = input_seq_group.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

-        # Add the generation tokens.
+        # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
+        context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        generation_seq_ids = sorted(generation_tokens.keys())
-        for seq_id in generation_seq_ids:
-            input_tokens.append(generation_tokens[seq_id])
-            position_id = context_lens[seq_id] - 1
-            input_positions.append(position_id)
+        for input_seq_group in input_seq_groups:
+            if input_seq_group.is_prompt:
+                continue
+
+            seq_ids = list(input_seq_group.input_tokens.keys())
+            sampling_params = input_seq_group.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+            seq_logprobs.update(input_seq_group.seq_logprobs)
+
+            for seq_id in seq_ids:
+                assert len(input_seq_group.input_tokens[seq_id]) == 1
+                generation_token = input_seq_group.input_tokens[seq_id][0]
+                input_tokens.append(generation_token)
+
+                position = input_seq_group.context_len - 1
+                input_positions.append(position)

-            block_table = block_tables[seq_id]
-            generation_block_tables.append(block_table)
+                block_table = input_seq_group.block_tables[seq_id]
+                generation_block_tables.append(block_table)

-            max_context_len = max(max_context_len, context_lens[seq_id])
-            max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
+                max_context_len = max(max_context_len, input_seq_group.context_len)
+                max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
+                context_lens.append(input_seq_group.context_len)

-            block_number = block_table[position_id // self.block_size]
-            block_offset = position_id % self.block_size
-            slot = block_number * self.block_size + block_offset
-            slot_mapping.append(slot)
+                block_number = block_table[position // self.block_size]
+                block_offset = position % self.block_size
+                slot = block_number * self.block_size + block_offset
+                slot_mapping.append(slot)
@@ -112,8 +148,7 @@ class Worker:
         slot_mapping_tensor = torch.tensor(
             slot_mapping, dtype=torch.int, device=self.device)
         context_lens_tensor = torch.tensor(
-            [context_lens[seq_id] for seq_id in generation_seq_ids],
-            dtype=torch.int, device=self.device)
+            context_lens, dtype=torch.int, device=self.device)
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
             for block_table in generation_block_tables]
@@ -121,7 +156,8 @@ class Worker:
             padded_block_tables, dtype=torch.int, device=self.device)

         input_metadata = InputMetadata(
-            seq_ids=prompt_seq_ids + generation_seq_ids,
+            seq_groups=seq_groups,
+            seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
@@ -133,14 +169,11 @@ class Worker:
     @torch.inference_mode()
     def execute_stage(
         self,
-        prompt_tokens: Dict[int, List[int]],        # Seq id -> List of input token ids.
-        generation_tokens: Dict[int, int],          # Seq id -> Input token id.
-        context_lens: Dict[int, int],               # Seq id -> Number of tokens participating in attention.
-        block_tables: Dict[int, List[int]],         # Seq id -> List of physical block numbers.
+        input_seq_groups: List[SequenceGroupInputs],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, int],
-    ) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> Dict[int, SequenceOutputs]:
         # Issue cache operations.
         command_issued = False
         if blocks_to_swap_in:
@@ -160,7 +193,7 @@ class Worker:
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
-            prompt_tokens, generation_tokens, context_lens, block_tables)
+            input_seq_groups)

         # Execute the model.
         output = self.model(
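Aside (not part of the diff): a tiny worked example of the slot-mapping arithmetic prepare_inputs performs. With block_size = 4 and a block table [7, 2], token position 5 sits in the second logical block, so it maps to physical block 2, offset 1, i.e. slot 9.

block_size = 4
block_table = [7, 2]        # logical block index -> physical block number

def slot_for(position: int) -> int:
    # Same arithmetic as in Worker.prepare_inputs.
    block_number = block_table[position // block_size]
    block_offset = position % block_size
    return block_number * block_size + block_offset

assert [slot_for(i) for i in range(6)] == [28, 29, 30, 31, 8, 9]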
csrc/cache.cpp

 #include <torch/extension.h>

+#include <map>
+#include <vector>
+
+void swap_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, int64_t>& block_mapping);
+
 void copy_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping);
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping);

 void reshape_and_cache(
   torch::Tensor& key,
   torch::Tensor& value,
@@ -14,7 +22,11 @@ void reshape_and_cache(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "copy_cache_blocks",
+    "swap_blocks",
+    &swap_blocks,
+    "Swap in (out) the cache blocks from src to dst");
+  m.def(
+    "copy_blocks",
     &copy_blocks,
     "Copy the cache blocks from src to dst");
   m.def(
csrc/cache_kernels.cu

@@ -5,8 +5,9 @@
 #include <algorithm>
 #include <cassert>
 #include <map>
+#include <vector>

-void copy_blocks(
+void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
   const std::map<int64_t, int64_t>& block_mapping) {
@@ -43,6 +44,35 @@ void copy_blocks(
   }
 }

+void copy_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
+  torch::Device src_device = src.device();
+  torch::Device dst_device = dst.device();
+  assert(src_device.is_cuda() && dst_device.is_cuda());
+  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
+
+  void *src_ptr = src.data_ptr();
+  void *dst_ptr = dst.data_ptr();
+
+  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  for (const auto& pair : block_mapping) {
+    int64_t src_block_number = pair.first;
+    for (int64_t dst_block_number : pair.second) {
+      int64_t src_offset = src_block_number * block_size_in_bytes;
+      int64_t dst_offset = dst_block_number * block_size_in_bytes;
+      cudaMemcpyAsync(
+        dst_ptr + dst_offset,
+        src_ptr + src_offset,
+        block_size_in_bytes,
+        memcpy_type,
+        stream);
+    }
+  }
+}
+
 template<typename scalar_t>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
server.py

@@ -15,6 +15,8 @@ parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of
 parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
 # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
 parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
+# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
+parser.add_argument('--seed', type=int, default=0, help='random seed')
 args = parser.parse_args()
@@ -30,6 +32,7 @@ def main():
             num_gpu_blocks=args.num_gpu_blocks,
             num_cpu_blocks=args.num_cpu_blocks,
             dtype=args.dtype,
+            seed=args.seed,
         )
         controllers.append(controller)
@@ -52,18 +55,18 @@ def main():
         controllers[i].set_next(controllers[i + 1])
     controllers[-1].set_next(scheduler)

     # Test the following inputs.
     test_inputs = [
-        'Ion Stoica is a',
-        'UC Berkeley is',
-        'The future of cloud computing is',
+        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
+        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
+        ('The future of cloud computing is', {}),  # Use default parameters.
     ]
     # FIXME
     while True:
         if test_inputs:
-            frontend.query(test_inputs.pop())
+            text, sampling_params = test_inputs.pop(0)
+            frontend.query(text, **sampling_params)
         scheduler.step()
-        if not scheduler.pending and not scheduler.running:
+        if not (scheduler.pending or scheduler.running or test_inputs):
             break