norm / vllm · Commits

Commit 1a7eb7da (Unverified)
Support beam search & parallel generation (#7)

Authored Mar 10, 2023 by Woosuk Kwon; committed by GitHub on Mar 10, 2023.
Parent: 04e5acc0

Showing 16 changed files with 662 additions and 163 deletions (+662 / -163).
cacheflow/block.py                  +4    -0
cacheflow/master/frontend.py        +28   -5
cacheflow/master/scheduler.py       +55   -36
cacheflow/models/__init__.py        +3    -1
cacheflow/models/input_metadata.py  +11   -7
cacheflow/models/model_utils.py     +10   -0
cacheflow/models/opt.py             +2    -1
cacheflow/models/sample.py          +258  -17
cacheflow/sampling_params.py        +54   -19
cacheflow/sequence.py               +72   -6
cacheflow/worker/cache_engine.py    +27   -8
cacheflow/worker/controller.py      +7    -10
cacheflow/worker/worker.py          +76   -43
csrc/cache.cpp                      +14   -2
csrc/cache_kernels.cu               +31   -1
server.py                           +10   -7
cacheflow/block.py

@@ -35,6 +35,10 @@ class LogicalTokenBlock:
     def get_token_ids(self) -> List[int]:
         return self.token_ids[:self.num_tokens]

+    def get_last_token_id(self) -> int:
+        assert self.num_tokens > 0
+        return self.token_ids[self.num_tokens - 1]
+

 class PhysicalTokenBlock:
cacheflow/master/frontend.py

-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 from transformers import AutoTokenizer

@@ -25,12 +25,35 @@ class Frontend:
     def query(
         self,
         prompt: str,
-        sampling_params: Optional[SamplingParams] = None,
+        n: int = 1,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        use_beam_search: bool = False,
+        stop_token_ids: Set[int] = set(),
+        max_num_steps: int = 16,  # From OpenAI API.
+        num_logprobs: int = 0,
+        context_window_size: Optional[int] = None,
     ) -> None:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        token_ids: List[int] = self.tokenizer.encode(prompt)
+        # Stop when we see an EOS token.
+        stop_token_ids.add(self.tokenizer.eos_token_id)
+        sampling_params = SamplingParams(
+            n=n,
+            temperature=temperature,
+            top_p=top_p,
+            use_beam_search=use_beam_search,
+            stop_token_ids=stop_token_ids,
+            max_num_steps=max_num_steps,
+            num_logprobs=num_logprobs,
+            context_window_size=context_window_size,
+        )
+        token_ids = self.tokenizer.encode(prompt)
         self._add_query(token_ids, sampling_params)

     def _add_query(
         self,
         token_ids: List[int],
         sampling_params: SamplingParams,
     ) -> None:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
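A side note on the new query() signature (illustration only, not part of the diff): stop_token_ids uses a mutable default, set(), and the method mutates it by adding the tokenizer's EOS id, so every call that relies on the default shares the same set object. A minimal standalone sketch of that Python behavior:

def query(prompt: str, stop_token_ids: set = set()) -> tuple:
    # Mirrors the shape of the new Frontend.query() default: the same set
    # object is reused by every call that omits the argument.
    stop_token_ids.add(2)  # stand-in for tokenizer.eos_token_id
    return prompt, stop_token_ids

print(query('a'))  # ('a', {2})
print(query('b'))  # ('b', {2}) -- the same set object as in the first call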
cacheflow/master/scheduler.py

-from typing import Dict, List, Tuple
+from typing import Dict, List

 from cacheflow.master.block_manager import BlockSpaceManager
 from cacheflow.master.frontend import Frontend
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
+from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceOutputs
 from cacheflow.sequence import SequenceStatus

 _MAX_NUM_BATCHED_TOKENS = 2048

@@ -66,7 +68,7 @@ class Scheduler:
     def _append(
         self,
         seq_group: SequenceGroup,
-        blocks_to_copy: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         for seq in seq_group.seqs:
             if seq.status == SequenceStatus.FINISHED:

@@ -74,7 +76,10 @@ class Scheduler:
             ret = self.block_manager.append(seq)
             if ret is not None:
                 src_block, dst_block = ret
-                blocks_to_copy[src_block] = dst_block
+                if src_block in blocks_to_copy:
+                    blocks_to_copy[src_block].append(dst_block)
+                else:
+                    blocks_to_copy[src_block] = [dst_block]

     def _swap_in(
         self,

@@ -83,8 +88,7 @@ class Scheduler:
     ) -> None:
         mapping = self.block_manager.swap_in(seq_group)
         blocks_to_swap_in.update(mapping)
-        for seq in seq_group.seqs:
-            if seq.status == SequenceStatus.SWAPPED:
-                seq.status = SequenceStatus.RUNNING
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
+            seq.status = SequenceStatus.RUNNING
         self.running.append(seq_group)

@@ -96,8 +100,7 @@ class Scheduler:
         assert self.block_manager.can_swap_out(seq_group)
         mapping = self.block_manager.swap_out(seq_group)
         blocks_to_swap_out.update(mapping)
-        for seq in seq_group.seqs:
-            if seq.status == SequenceStatus.RUNNING:
-                seq.status = SequenceStatus.SWAPPED
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+            seq.status = SequenceStatus.SWAPPED
         self.swapped.append(seq_group)

@@ -105,7 +108,7 @@ class Scheduler:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
-        blocks_to_copy: Dict[int, int] = {}
+        blocks_to_copy: Dict[int, List[int]] = {}

         # 1. Reserve new slots for the running sequences.
         # NOTE: Here we implicitly assume FCFS scheduling.

@@ -143,6 +146,10 @@ class Scheduler:
             # All swapped sequences are swapped in.
             self.swapped.clear()

+        # Ensure that swap-in and swap-out never happen at the same timestep.
+        if blocks_to_swap_in:
+            assert not blocks_to_swap_out
+
         num_batched_tokens = sum(
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
             for seq_group in self.running

@@ -152,7 +159,6 @@ class Scheduler:
         # NOTE: Here we implicitly assume FCFS scheduling.
         # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
-            # FIXME(woosuk): Acquire a lock to protect pending.
             self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()

@@ -168,39 +174,45 @@ class Scheduler:
         else:
             self.pending.clear()

-        # Ensure that swap-in and swap-out never happen at the same timestep.
-        if blocks_to_swap_in:
-            assert not blocks_to_swap_out
-
         # 4. Create input data structures.
-        prompt_tokens: Dict[int, List[int]] = {}
-        generation_tokens: Dict[int, int] = {}
-        context_lens: Dict[int, int] = {}
-        block_tables: Dict[int, List[int]] = {}
+        input_seq_groups: List[SequenceGroupInputs] = []
         for seq_group in self.running:
             group_id = seq_group.group_id
             num_steps = self.num_steps[group_id]
             # NOTE(woosuk): We assume that the number of steps is 0
             # for the prompt sequences.
             is_prompt = num_steps == 0
-            for seq in seq_group.seqs:
-                if seq.status != SequenceStatus.RUNNING:
-                    continue
+
+            input_tokens: Dict[int, List[int]] = {}
+            seq_logprobs: Dict[int, float] = {}
+            block_tables: Dict[int, List[int]] = {}
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
                 if is_prompt:
-                    prompt_tokens[seq_id] = seq.get_token_ids()
+                    input_tokens[seq_id] = seq.get_token_ids()
                 else:
-                    generation_tokens[seq_id] = seq.get_token_ids()[-1]
-                    context_lens[seq_id] = seq.get_len()
+                    input_tokens[seq_id] = [seq.get_last_token_id()]
+                seq_logprobs[seq_id] = seq.cumulative_logprobs
+                # NOTE(woosuk): Sequences in the same group have the same
+                # sequence length
+                seq_len = seq.get_len()
+
+            input_seq_group = SequenceGroupInputs(
+                group_id=group_id,
+                is_prompt=is_prompt,
+                input_tokens=input_tokens,
+                context_len=seq_len,
+                seq_logprobs=seq_logprobs,
+                sampling_params=self.sampling_params[group_id],
+                block_tables=block_tables,
+            )
+            input_seq_groups.append(input_seq_group)

         # 5. Execute the first stage of the pipeline.
         self.controllers[0].execute_stage(
-            prompt_tokens,
-            generation_tokens,
-            context_lens,
-            block_tables,
+            input_seq_groups,
             blocks_to_swap_in,
             blocks_to_swap_out,
             blocks_to_copy,

@@ -208,7 +220,7 @@ class Scheduler:
     def post_step(
         self,
-        next_tokens: Dict[int, Tuple[int, int]],
+        seq_outputs: Dict[int, SequenceOutputs],
     ) -> None:
         # Update the running sequences and free blocks.
         for seq_group in self.running:

@@ -216,25 +228,32 @@ class Scheduler:
             self.num_steps[group_id] += 1
             stop_token_ids = self.sampling_params[group_id].stop_token_ids

+            # Process beam search results before processing the next tokens.
             for seq in seq_group.seqs:
                 if seq.status == SequenceStatus.FINISHED:
                     continue
-                parent_seq_id, next_token = next_tokens[seq.seq_id]
-                if seq.seq_id != parent_seq_id:
+                output = seq_outputs[seq.seq_id]
+                if seq.seq_id != output.parent_seq_id:
                     # The sequence is a fork of the parent sequence (beam search).
                     # Free the current sequence.
                     self.block_manager.free(seq)
                     # Fork the parent sequence.
-                    parent_seq = seq_group.find(parent_seq_id)
-                    seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
+                    parent_seq = seq_group.find(output.parent_seq_id)
+                    parent_seq.fork(seq)
                     self.block_manager.fork(parent_seq, seq)

+            # Process the next tokens.
+            for seq in seq_group.seqs:
+                if seq.status == SequenceStatus.FINISHED:
+                    continue
+
                 # Append a new token to the sequence.
-                seq.append([next_token])
+                output = seq_outputs[seq.seq_id]
+                seq.append(output.output_token, output.logprobs)

                 # Check if the sequence has generated a stop token.
-                if next_token in stop_token_ids:
+                if output.output_token in stop_token_ids:
                     self._free_seq(seq)
                     continue
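For context, the blocks_to_copy change above turns a one-to-one block mapping into a one-to-many mapping, since with parallel generation several forked sequences may need copies of the same source block in a single step. A small standalone sketch of the accumulation pattern (illustration only, not code from the commit):

from typing import Dict, List

def record_copy(blocks_to_copy: Dict[int, List[int]],
                src_block: int, dst_block: int) -> None:
    # One source block may be copied to several destination blocks within a
    # single scheduling step, so the mapping values are lists.
    if src_block in blocks_to_copy:
        blocks_to_copy[src_block].append(dst_block)
    else:
        blocks_to_copy[src_block] = [dst_block]

blocks_to_copy: Dict[int, List[int]] = {}
record_copy(blocks_to_copy, src_block=7, dst_block=12)
record_copy(blocks_to_copy, src_block=7, dst_block=13)
print(blocks_to_copy)  # {7: [12, 13]}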
cacheflow/models/__init__.py

 from cacheflow.models.input_metadata import InputMetadata
 from cacheflow.models.model_utils import get_model
+from cacheflow.models.model_utils import set_seed

 __all__ = [
-    'get_model',
     'InputMetadata',
+    'get_model',
+    'set_seed'
 ]
cacheflow/models/input_metadata.py

-from typing import List
+from typing import List, Dict, Tuple

 import torch

+from cacheflow.sampling_params import SamplingParams
+

 class InputMetadata:

     def __init__(
         self,
-        seq_ids: List[int],
+        seq_groups: List[Tuple[List[int], SamplingParams]],
+        seq_logprobs: Dict[int, float],     # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
+        # FIXME: Rename
         max_context_len: int,
         block_tables: torch.Tensor,
     ) -> None:
-        self.seq_ids = seq_ids
+        self.seq_groups = seq_groups
+        self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens

@@ -23,19 +26,20 @@ class InputMetadata:
         self.block_tables = block_tables

         self.num_prompts = len(prompt_lens)
+        self.num_prompt_tokens = sum(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
             self.max_num_blocks_per_seq = 0
-        assert self.num_generation_tokens == block_tables.shape[0]
-        assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
+        assert block_tables.shape[0] == self.num_generation_tokens
+        assert context_lens.shape[0] == self.num_generation_tokens

     def __repr__(self) -> str:
         return (f'InputMetadata('
-                f'seq_ids={self.seq_ids}, '
                 f'num_prompts={self.num_prompts}, '
+                f'num_prompt_tokens={self.num_prompt_tokens}, '
                 f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
cacheflow/models/model_utils.py

+import random
 from typing import Union

+import numpy as np
 import torch
 import torch.nn as nn

@@ -30,3 +32,11 @@ def get_model(
         model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
         return model.eval()
     raise ValueError(f'Invalid model name: {model_name}')
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
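As a quick illustration of what the new set_seed() helper buys (hypothetical usage, not part of the commit): seeding Python, NumPy, and Torch together makes the multinomial sampling path reproducible across runs.

import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    # Same seeding recipe as the helper added above.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(0)
first = torch.multinomial(torch.tensor([0.1, 0.2, 0.7]), num_samples=1)
set_seed(0)
second = torch.multinomial(torch.tensor([0.1, 0.2, 0.7]), num_samples=1)
assert torch.equal(first, second)  # identical draws after re-seeding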
cacheflow/models/opt.py

@@ -9,6 +9,7 @@ from transformers import PreTrainedModel
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import OPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
+from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -261,7 +262,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, Tuple[int, int]]:
+    ) -> Dict[int, SequenceOutputs]:
         hidden_states = self.model(
             input_ids, positions, kv_caches, input_metadata, cache_events)
         next_tokens = self.sampler(
cacheflow/models/sample.py

@@ -4,6 +4,8 @@ import torch
 import torch.nn as nn

 from cacheflow.models import InputMetadata
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceOutputs


 class Sampler(nn.Module):

@@ -16,8 +18,42 @@ class Sampler(nn.Module):
         embedding: torch.Tensor,
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
-    ) -> Dict[int, Tuple[int, int]]:
-        # Get the hidden states of the last tokens.
+    ) -> Dict[int, SequenceOutputs]:
+        # Get the hidden states that we use for sampling.
+        hidden_states = _prune_hidden_states(hidden_states, input_metadata)
+
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+
+        # Apply temperature scaling.
+        temperatures = _get_temperatures(input_metadata)
+        assert len(temperatures) == logits.shape[0]
+        if any(t != 1.0 for t in temperatures):
+            t = torch.tensor(
+                temperatures, dtype=logits.dtype, device=logits.device)
+            # Use in-place division to avoid creating a new tensor.
+            logits.div_(t.unsqueeze(dim=1))
+
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities (before applying top-p).
+        logprobs = torch.log(probs)
+
+        # Apply top-p truncation.
+        top_ps = _get_top_ps(input_metadata)
+        assert len(top_ps) == probs.shape[0]
+        if any(p < 1.0 for p in top_ps):
+            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
+            probs = _apply_top_p(probs, p)
+
+        # Sample the next tokens.
+        return _sample(probs, logprobs, input_metadata)
+
+
+def _prune_hidden_states(
+    hidden_states: torch.Tensor,
+    input_metadata: InputMetadata,
+) -> torch.Tensor:

@@ -25,18 +61,223 @@ class Sampler(nn.Module):
     start_idx = 0
     last_token_indicies: List[int] = []
     for prompt_len in input_metadata.prompt_lens:
         start_idx += prompt_len
     last_token_indicies.extend(
         range(start_idx, start_idx + input_metadata.num_generation_tokens))
-        hidden_states = hidden_states[last_token_indicies]
-
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-
-        # Sample the next tokens.
-        # TODO(woosuk): Implement other sampling methods.
-        next_token_ids = torch.argmax(logits, dim=-1)
-        next_token_ids = next_token_ids.tolist()
-
-        next_tokens: Dict[int, Tuple[int, int]] = {}
-        for seq_id, token_id in zip(input_metadata.seq_ids, next_token_ids):
-            next_tokens[seq_id] = (seq_id, token_id)
-        return next_tokens
+    return hidden_states[last_token_indicies]
+
+
+def _get_temperatures(
+    input_metadata: InputMetadata,
+) -> List[float]:
+    # Collect the temperatures for the logits.
+    temperatures: List[float] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        temperature = sampling_params.temperature
+        if temperature == 0.0:
+            # NOTE: Zero temperature means deterministic sampling
+            # (i.e., greedy sampling or beam search).
+            # Set the temperature to 1 to avoid division by zero.
+            temperature = 1.0
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            temperatures.append(temperature)
+        else:
+            # A generation token.
+            temperatures += [temperature] * len(seq_ids)
+    return temperatures
+
+
+def _get_top_ps(
+    input_metadata: InputMetadata,
+) -> List[float]:
+    top_ps: List[float] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            top_ps.append(sampling_params.top_p)
+        else:
+            # A generation token.
+            top_ps += [sampling_params.top_p] * len(seq_ids)
+    return top_ps
+
+
+def _apply_top_p(
+    probs: torch.Tensor,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    # TODO(woosuk): Optimize.
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    probs = torch.gather(
+        probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
+    return probs
+
+
+def _get_topk_logprobs(
+    logprobs: torch.Tensor,
+    num_logprobs: int,
+) -> Dict[int, float]:
+    if num_logprobs == 0:
+        return {}
+
+    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
+    if num_logprobs == 1:
+        topk_logprobs = [topk_logprobs.item()]
+        topk_ids = [topk_ids.item()]
+    else:
+        topk_logprobs = topk_logprobs.tolist()
+        topk_ids = topk_ids.tolist()
+
+    token_to_logprob: Dict[int, float] = {}
+    for token_id, logprob in zip(topk_ids, topk_logprobs):
+        token_to_logprob[token_id] = logprob
+    return token_to_logprob
+
+
+def _sample_from_prompt(
+    prob: torch.Tensor,
+    sampling_params: SamplingParams,
+) -> List[int]:
+    if sampling_params.use_beam_search:
+        # Beam search.
+        beam_width = sampling_params.n
+        _, next_token_ids = torch.topk(prob, beam_width)
+        next_token_ids = next_token_ids.tolist()
+    elif sampling_params.temperature == 0.0:
+        # Greedy sampling.
+        assert sampling_params.n == 1
+        next_token_id = torch.argmax(prob)
+        next_token_ids = [next_token_id.item()]
+    else:
+        # Neucleus sampling.
+        # Sample n tokens for the prompt.
+        n = sampling_params.n
+        next_token_ids = torch.multinomial(
+            prob, num_samples=n, replacement=True)
+        next_token_ids = next_token_ids.tolist()
+    return next_token_ids
+
+
+def _sample_from_generation_tokens(
+    seq_ids: List[int],
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    seq_logprobs: List[float],
+    sampling_params: SamplingParams,
+) -> Tuple[List[int], List[int]]:
+    # NOTE(woosuk): sampling_params.n can be greater than
+    # len(seq_ids) because some sequences in the group might have
+    # been already terminated.
+    if sampling_params.use_beam_search:
+        # Beam search.
+        # Add cumulative logprobs for the sequences in the group.
+        seq_logprobs = torch.tensor(
+            seq_logprobs, dtype=torch.float, device=logprobs.device)
+        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
+
+        vocab_size = logprobs.size(-1)
+        beam_width = len(seq_ids)
+        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
+        seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
+        beam_seq_ids = [seq_ids[i] for i in seq_idx]
+        token_ids = (topk_ids % vocab_size).tolist()
+
+        beam_outputs: Dict[int, Tuple[int, int]] = {}
+        outstanding_beams: List[Tuple[int, int]] = []
+        # If a beam survives, continue with it.
+        for seq_id, token_id in zip(beam_seq_ids, token_ids):
+            if seq_id not in beam_outputs:
+                beam_outputs[seq_id] = (seq_id, token_id)
+            else:
+                outstanding_beams.append((seq_id, token_id))
+
+        # If a beam is discarded, fork another beam.
+        for seq_id in seq_ids:
+            if seq_id not in beam_outputs:
+                beam_outputs[seq_id] = outstanding_beams.pop()
+        assert not outstanding_beams
+
+        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
+        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
+    elif sampling_params.temperature == 0.0:
+        # Greedy sampling.
+        assert len(seq_ids) == 1
+        next_token_id = torch.argmax(probs, dim=-1)
+        next_token_ids = [next_token_id.item()]
+        parent_seq_ids = seq_ids
+    else:
+        # Neucleus sampling.
+        # Sample 1 token for each sequence in the group.
+        next_token_ids = torch.multinomial(
+            probs, num_samples=1, replacement=True)
+        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
+        parent_seq_ids = seq_ids
+    return parent_seq_ids, next_token_ids
+
+
+def _sample(
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    input_metadata: InputMetadata,
+) -> Dict[int, SequenceOutputs]:
+    seq_outputs: Dict[int, SequenceOutputs] = {}
+
+    # TODO(woosuk): Optimize.
+    idx = 0
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        if i < input_metadata.num_prompts:
+            # Generate the next tokens for a prompt input.
+            assert len(seq_ids) == sampling_params.n
+            prob = probs[idx]
+            logprob = logprobs[idx]
+            idx += 1
+
+            # Sample the next tokens.
+            next_token_ids = _sample_from_prompt(prob, sampling_params)
+            # Get top-k log probabilities for the next tokens.
+            next_logprobs = _get_topk_logprobs(
+                logprob, sampling_params.num_logprobs)
+
+            # Build the output.
+            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
+                output_logprobs = next_logprobs.copy()
+                output_logprobs[next_token_id] = logprob[next_token_id].item()
+                seq_outputs[seq_id] = SequenceOutputs(
+                    seq_id, seq_id, next_token_id, output_logprobs)
+        else:
+            # Generate the next tokens for generation tokens.
+            prob = probs[idx:idx + len(seq_ids)]
+            logprob = logprobs[idx:idx + len(seq_ids)]
+            idx += len(seq_ids)
+
+            # Sample the next tokens.
+            seq_logprobs = [
+                input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
+            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
+                seq_ids, prob, logprob, seq_logprobs, sampling_params)
+
+            # Get top-k log probabilities for the next tokens.
+            next_logprobs: Dict[int, Dict[int, float]] = {}
+            for i, seq_id in enumerate(seq_ids):
+                next_logprobs[seq_id] = _get_topk_logprobs(
+                    logprob[i], sampling_params.num_logprobs)
+
+            # Build the output.
+            for seq_id, parent_seq_id, next_token_id in zip(
+                seq_ids, parent_seq_ids, next_token_ids):
+                i = seq_ids.index(parent_seq_id)
+                output_logprobs = next_logprobs[parent_seq_id].copy()
+                output_logprobs[next_token_id] = logprob[i, next_token_id].item()
+                seq_outputs[seq_id] = SequenceOutputs(
+                    seq_id,
+                    parent_seq_id,
+                    next_token_id,
+                    output_logprobs,
+                )
+
+    # Return the next tokens.
+    return seq_outputs
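To make the new top-p step concrete, here is a standalone sketch of the same truncation applied to a toy batch of distributions (illustration only; it mirrors _apply_top_p but is not the commit's code):

import torch

def apply_top_p(probs: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # Sort each row, zero out the tail whose cumulative mass exceeds p,
    # renormalize, and scatter back to the original token order.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    return torch.gather(probs_sort, dim=-1,
                        index=torch.argsort(probs_idx, dim=-1))

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
print(apply_top_p(probs, torch.tensor([0.7])))
# tensor([[0.6250, 0.3750, 0.0000, 0.0000]])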
cacheflow/sampling_params.py

@@ -5,27 +5,51 @@ class SamplingParams:
     def __init__(
         self,
-        n: int = 1,
-        temperature: float = 1.0,
-        top_p: float = 1.0,
-        use_beam_search: bool = False,
-        stop_token_ids: Set[int] = [],
-        max_num_steps: int = 16,  # From OpenAI API.
-        max_context_len: Optional[int] = None,
+        n: int,
+        temperature: float,
+        top_p: float,
+        use_beam_search: bool,
+        stop_token_ids: Set[int],
+        max_num_steps: int,
+        num_logprobs: int,
+        context_window_size: Optional[int],
     ) -> None:
-        assert n >= 1
-        assert temperature >= 0.0
-        assert 0.0 < top_p <= 1.0
+        if n < 1:
+            raise ValueError(f'n must be at least 1, got {n}.')
+        if temperature < 0.0:
+            raise ValueError(
+                f'temperature must be non-negative, got {temperature}.')
+        if not 0.0 < top_p <= 1.0:
+            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
+        if max_num_steps < 1:
+            raise ValueError(
+                f'max_num_steps must be at least 1, got {max_num_steps}.')
+        if num_logprobs < 0:
+            raise ValueError(
+                f'num_logprobs must be non-negative, got {num_logprobs}.')
+        if context_window_size is not None and context_window_size < 0:
+            raise ValueError(
+                'context_window_size must be non-negative, '
+                f'got {context_window_size}.')
+
         if use_beam_search:
-            assert n > 1
-            assert temperature > 0.0
-            assert top_p == 1.0
+            if n == 1:
+                raise ValueError(
+                    'n must be greater than 1 when using beam search.')
+            if temperature > 0.0:
+                raise ValueError(
+                    'temperature must be 0 when using beam search.')
+            if top_p < 1.0:
+                raise ValueError(
+                    'top_p must be 1 when using beam search.')
         elif temperature == 0.0:
-            # Zero temperature means greedy decoding.
-            assert n == 1
-            assert top_p == 1.0
-        assert max_num_steps >= 1
-        assert max_context_len is None or max_context_len >= 0
+            # Zero temperature means greedy sampling.
+            if n > 1:
+                raise ValueError(
+                    'n must be 1 when using greedy sampling.')
+            if top_p < 1.0:
+                raise ValueError(
+                    'top_p must be 1 when using greedy sampling.')

         self.n = n
         self.temperature = temperature

@@ -33,4 +57,15 @@ class SamplingParams:
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
-        self.max_context_len = max_context_len
+        self.num_logprobs = num_logprobs
+        self.context_window_size = context_window_size
+
+    def __repr__(self) -> str:
+        return (f'SamplingParams(n={self.n}, '
+                f'temperature={self.temperature}, '
+                f'top_p={self.top_p}, '
+                f'use_beam_search={self.use_beam_search}, '
+                f'stop_token_ids={self.stop_token_ids}, '
+                f'max_num_steps={self.max_num_steps}, '
+                f'num_logprobs={self.num_logprobs}, '
+                f'context_window_size={self.context_window_size})')
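The constructor now validates its arguments eagerly and raises ValueError instead of asserting. A hedged usage sketch (it assumes the cacheflow package from this repository is importable; the keyword values simply mirror the rules above):

from cacheflow.sampling_params import SamplingParams

# Valid: beam search with width 4 requires temperature 0.0 and top_p 1.0.
params = SamplingParams(
    n=4,
    temperature=0.0,
    top_p=1.0,
    use_beam_search=True,
    stop_token_ids=set(),
    max_num_steps=16,
    num_logprobs=0,
    context_window_size=None,
)

# Invalid: n > 1 with greedy sampling now raises instead of asserting.
try:
    SamplingParams(
        n=2,
        temperature=0.0,
        top_p=1.0,
        use_beam_search=False,
        stop_token_ids=set(),
        max_num_steps=16,
        num_logprobs=0,
        context_window_size=None,
    )
except ValueError as e:
    print(e)  # n must be 1 when using greedy sampling.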
cacheflow/sequence.py

+import copy
 import enum
-from typing import List, Optional
+from typing import Dict, List, Optional

 from cacheflow.block import LogicalTokenBlock
+from cacheflow.sampling_params import SamplingParams


 class SequenceStatus(enum.Enum):

@@ -24,9 +26,11 @@ class Sequence:
         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the given token ids.
-        self.append(token_ids)
+        self.add(token_ids)

         self.status = SequenceStatus.PENDING
+        self.output_logprobs: List[Dict[int, float]] = []
+        self.cumulative_logprobs = 1.0

     def add_block(self) -> None:
         block = LogicalTokenBlock(

@@ -35,7 +39,7 @@ class Sequence:
         )
         self.logical_token_blocks.append(block)

-    def append(self, token_ids: List[int]) -> None:
+    def add(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
                 self.add_block()

@@ -49,6 +53,12 @@ class Sequence:
             last_block.append(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]

+    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
+        assert token_id in logprobs
+        self.add([token_id])
+        self.output_logprobs.append(logprobs)
+        self.cumulative_logprobs += logprobs[token_id]
+
     def get_len(self) -> int:
         return sum(block.num_tokens for block in self.logical_token_blocks)

@@ -58,6 +68,14 @@ class Sequence:
             token_ids.extend(block.get_token_ids())
         return token_ids

+    def get_last_token_id(self) -> int:
+        return self.logical_token_blocks[-1].get_last_token_id()
+
+    def fork(self, child_seq: 'Sequence') -> 'Sequence':
+        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
+        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
+        child_seq.cumulative_logprobs = self.cumulative_logprobs
+
     def __repr__(self) -> str:
         return (f'Sequence(seq_id={self.seq_id}, '
                 f'status={self.status.name}, '

@@ -74,11 +92,17 @@ class SequenceGroup:
         self.group_id = group_id
         self.seqs = seqs

-    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+    def get_seqs(
+        self,
+        status: Optional[SequenceStatus] = None,
+    ) -> List[Sequence]:
         if status is None:
-            return len(self.seqs)
+            return self.seqs
         else:
-            return len([seq for seq in self.seqs if seq.status == status])
+            return [seq for seq in self.seqs if seq.status == status]
+
+    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+        return len(self.get_seqs(status))

     def find(self, seq_id: int) -> Sequence:
         for seq in self.seqs:

@@ -92,3 +116,45 @@ class SequenceGroup:
     def __repr__(self) -> str:
         return (f'SequenceGroup(group_id={self.group_id}, '
                 f'num_seqs={len(self.seqs)})')
+
+
+class SequenceGroupInputs:
+
+    def __init__(
+        self,
+        group_id: int,
+        is_prompt: bool,
+        input_tokens: Dict[int, List[int]],     # Seq id -> token ids.
+        context_len: int,
+        seq_logprobs: Dict[int, float],         # Seq id -> cumulative logprobs.
+        sampling_params: SamplingParams,
+        block_tables: Dict[int, List[int]],     # Seq id -> List of physical block numbers.
+    ) -> None:
+        self.group_id = group_id
+        self.is_prompt = is_prompt
+        self.input_tokens = input_tokens
+        self.context_len = context_len
+        self.seq_logprobs = seq_logprobs
+        self.sampling_params = sampling_params
+        self.block_tables = block_tables
+
+
+class SequenceOutputs:
+
+    def __init__(
+        self,
+        seq_id: int,
+        parent_seq_id: int,
+        output_token: int,
+        logprobs: Dict[int, float],             # Token id -> logP(x_i+1 | x_0, ..., x_i).
+    ) -> None:
+        self.seq_id = seq_id
+        self.parent_seq_id = parent_seq_id
+        self.output_token = output_token
+        self.logprobs = logprobs
+
+    def __repr__(self) -> str:
+        return (f'SequenceOutputs(seq_id={self.seq_id}, '
+                f'parent_seq_id={self.parent_seq_id}, '
+                f'output_token={self.output_token}), '
+                f'logprobs={self.logprobs}')
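The split between add() (extend with raw token ids) and append() (one sampled token plus its logprob dict) is what lets the scheduler track a per-sequence cumulative log-probability for beam search. A standalone sketch of that bookkeeping (illustration only, not the commit's class):

from typing import Dict, List

class ToySequence:
    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.token_ids = list(prompt_token_ids)
        self.output_logprobs: List[Dict[int, float]] = []
        self.cumulative_logprobs = 0.0  # the commit initializes this field to 1.0

    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
        # Mirrors Sequence.append(): store the sampled token, remember the
        # top-k logprob dict, and accumulate the chosen token's logprob.
        assert token_id in logprobs
        self.token_ids.append(token_id)
        self.output_logprobs.append(logprobs)
        self.cumulative_logprobs += logprobs[token_id]

seq = ToySequence([101, 7592])
seq.append(2088, {2088: -0.4, 1012: -1.6})
seq.append(1012, {1012: -0.9, 2088: -2.3})
print(seq.cumulative_logprobs)  # -1.3 (log-probability of the sampled path)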
cacheflow/worker/cache_engine.py

@@ -97,7 +97,7 @@ class CacheEngine:
             cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache

-    def _copy_blocks(
+    def _swap(
         self,
         src: List[KVCache],
         dst: List[KVCache],

@@ -108,19 +108,38 @@ class CacheEngine:
                 src_key_cache, src_value_cache = src[i]
                 dst_key_cache, dst_value_cache = dst[i]
                 # Copy the key blocks.
-                cache_ops.copy_cache_blocks(
-                    src_key_cache, dst_key_cache, src_to_dst)
+                cache_ops.swap_blocks(
+                    src_key_cache, dst_key_cache, src_to_dst)
                 # Copy the value blocks.
-                cache_ops.copy_cache_blocks(
-                    src_value_cache, dst_value_cache, src_to_dst)
+                cache_ops.swap_blocks(
+                    src_value_cache, dst_value_cache, src_to_dst)
                 event = self.events[i]
                 event.record(stream=self.cache_stream)

-    def copy(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)
-
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)
+        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
+        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
+
+    def _copy(
+        self,
+        src: List[KVCache],
+        dst: List[KVCache],
+        src_to_dsts: Dict[int, List[int]],
+    ) -> None:
+        with torch.cuda.stream(self.cache_stream):
+            for i in range(self.num_layers):
+                src_key_cache, src_value_cache = src[i]
+                dst_key_cache, dst_value_cache = dst[i]
+                # Copy the key blocks.
+                cache_ops.copy_blocks(
+                    src_key_cache, dst_key_cache, src_to_dsts)
+                # Copy the value blocks.
+                cache_ops.copy_blocks(
+                    src_value_cache, dst_value_cache, src_to_dsts)
+                event = self.events[i]
+                event.record(stream=self.cache_stream)
+
+    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
+        self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
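The rename makes the two data paths explicit: _swap moves whole blocks between CPU and GPU with a one-to-one mapping, while the new _copy duplicates GPU blocks one-to-many when sequences fork. A tiny standalone sketch of the two mapping shapes (illustration only):

from typing import Dict, List

# Swap: each source block has exactly one destination (CPU <-> GPU).
blocks_to_swap_in: Dict[int, int] = {3: 40, 5: 41}

# Copy: a forked sequence can fan one GPU block out to several destinations.
blocks_to_copy: Dict[int, List[int]] = {7: [12, 13]}

def expand(copy_map: Dict[int, List[int]]) -> List[tuple]:
    # Flatten the one-to-many mapping into (src, dst) pairs, which is the
    # iteration the copy kernel performs per mapping entry.
    return [(src, dst) for src, dsts in copy_map.items() for dst in dsts]

print(expand(blocks_to_copy))  # [(7, 12), (7, 13)]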
cacheflow/worker/controller.py

 from typing import Dict, List, Union

 from cacheflow.master.scheduler import Scheduler
+from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.worker.worker import Worker

@@ -14,7 +15,8 @@ class Controller:
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
-        dtype: str = 'half',
+        dtype: str,
+        seed: int,
     ) -> None:
         self.node_id = node_id
         self.num_workers = num_workers

@@ -37,6 +39,7 @@ class Controller:
                 num_gpu_blocks=num_gpu_blocks,
                 num_cpu_blocks=num_cpu_blocks,
                 dtype=dtype,
+                seed=seed,
             )
             self.workers.append(worker)

@@ -49,22 +52,16 @@ class Controller:
     def execute_stage(
         self,
-        prompt_tokens: Dict[int, List[int]],
-        generation_tokens: Dict[int, int],
-        context_lens: Dict[int, int],
-        block_tables: Dict[int, List[int]],
+        input_seq_groups: List[SequenceGroupInputs],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         # FIXME: Support tensor parallelism.
         assert len(self.workers) == 1
         worker = self.workers[0]
         output = worker.execute_stage(
-            prompt_tokens,
-            generation_tokens,
-            context_lens,
-            block_tables,
+            input_seq_groups,
             blocks_to_swap_in,
             blocks_to_swap_out,
             blocks_to_copy,
cacheflow/worker/worker.py

-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple

 import torch

 from cacheflow.models import get_model
+from cacheflow.models import set_seed
 from cacheflow.models import InputMetadata
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine

@@ -18,6 +22,7 @@ class Worker:
         num_gpu_blocks: int,
         num_cpu_blocks: int,
         dtype: str,
+        seed: int,
     ) -> None:
         self.worker_id = worker_id
         self.gpu_id = gpu_id

@@ -33,6 +38,11 @@ class Worker:
         self.head_size = self.model.config.hidden_size // self.num_heads
         self.dtype = self.model.dtype

+        # Set the seed.
+        # We set the seed after initializing the model to ensure that
+        # the random state is not affected by the model initialization.
+        set_seed(seed)
+
         self.cache_engine = CacheEngine(
             worker_id=worker_id,
             gpu_id=gpu_id,

@@ -49,53 +59,79 @@ class Worker:
     def prepare_inputs(
         self,
-        prompt_tokens: Dict[int, List[int]],    # Seq id -> List of input token ids.
-        generation_tokens: Dict[int, int],      # Seq id -> Input token id.
-        context_lens: Dict[int, int],           # Seq id -> Number of tokens participating in attention.
-        block_tables: Dict[int, List[int]],     # Seq id -> List of physical block numbers.
+        input_seq_groups: List[SequenceGroupInputs],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
-        # TODO(woosuk): Support interactive generation.
-        # Add the prompt tokens.
-        prompt_lens: List[int] = []
+        seq_groups: List[Tuple[List[int], SamplingParams]] = []
+        seq_logprobs: Dict[int, float] = {}
+        sampling_params: Dict[int, SamplingParams] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []

-        prompt_seq_ids = sorted(prompt_tokens.keys())
-        for seq_id in prompt_seq_ids:
-            prompt_len = len(prompt_tokens[seq_id])
+        # Add prompt tokens.
+        prompt_lens: List[int] = []
+        for input_seq_group in input_seq_groups:
+            if not input_seq_group.is_prompt:
+                continue
+
+            seq_ids = list(input_seq_group.input_tokens.keys())
+            sampling_params = input_seq_group.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+            seq_logprobs.update(input_seq_group.seq_logprobs)
+
+            # Use any sequence in the group.
+            seq_id = seq_ids[0]
+
+            prompt_tokens = input_seq_group.input_tokens[seq_id]
+            prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
-            input_tokens.extend(prompt_tokens[seq_id])
-            input_positions.extend(range(len(prompt_tokens[seq_id])))
-            block_table = block_tables[seq_id]
+
+            input_tokens.extend(prompt_tokens)
+            # NOTE(woosuk): Here we assume that the first token in the prompt
+            # is always the first token in the sequence.
+            input_positions.extend(range(len(prompt_tokens)))
+
+            # Compute the slot mapping.
+            block_table = input_seq_group.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

-        # Add the generation tokens.
+        # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
+        context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        generation_seq_ids = sorted(generation_tokens.keys())
-        for seq_id in generation_seq_ids:
-            input_tokens.append(generation_tokens[seq_id])
-            position_id = context_lens[seq_id] - 1
-            input_positions.append(position_id)
-            block_table = block_tables[seq_id]
-            generation_block_tables.append(block_table)
-            max_context_len = max(max_context_len, context_lens[seq_id])
-            max_num_blocks_per_seq = max(
-                max_num_blocks_per_seq, len(block_table))
-            block_number = block_table[position_id // self.block_size]
-            block_offset = position_id % self.block_size
-            slot = block_number * self.block_size + block_offset
-            slot_mapping.append(slot)
+        for input_seq_group in input_seq_groups:
+            if input_seq_group.is_prompt:
+                continue
+
+            seq_ids = list(input_seq_group.input_tokens.keys())
+            sampling_params = input_seq_group.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+            seq_logprobs.update(input_seq_group.seq_logprobs)
+
+            for seq_id in seq_ids:
+                assert len(input_seq_group.input_tokens[seq_id]) == 1
+                generation_token = input_seq_group.input_tokens[seq_id][0]
+                input_tokens.append(generation_token)
+
+                position = input_seq_group.context_len - 1
+                input_positions.append(position)
+
+                block_table = input_seq_group.block_tables[seq_id]
+                generation_block_tables.append(block_table)
+
+                max_context_len = max(
+                    max_context_len, input_seq_group.context_len)
+                max_num_blocks_per_seq = max(
+                    max_num_blocks_per_seq, len(block_table))
+                context_lens.append(input_seq_group.context_len)
+
+                block_number = block_table[position // self.block_size]
+                block_offset = position % self.block_size
+                slot = block_number * self.block_size + block_offset
+                slot_mapping.append(slot)

@@ -112,8 +148,7 @@ class Worker:
         slot_mapping_tensor = torch.tensor(
             slot_mapping, dtype=torch.int, device=self.device)
         context_lens_tensor = torch.tensor(
-            [context_lens[seq_id] for seq_id in generation_seq_ids],
-            dtype=torch.int, device=self.device)
+            context_lens, dtype=torch.int, device=self.device)
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
             for block_table in generation_block_tables]

@@ -121,7 +156,8 @@ class Worker:
             padded_block_tables, dtype=torch.int, device=self.device)

         input_metadata = InputMetadata(
-            seq_ids=prompt_seq_ids + generation_seq_ids,
+            seq_groups=seq_groups,
+            seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,

@@ -133,14 +169,11 @@ class Worker:
     @torch.inference_mode()
     def execute_stage(
         self,
-        prompt_tokens: Dict[int, List[int]],    # Seq id -> List of input token ids.
-        generation_tokens: Dict[int, int],      # Seq id -> Input token id.
-        context_lens: Dict[int, int],           # Seq id -> Number of tokens participating in attention.
-        block_tables: Dict[int, List[int]],     # Seq id -> List of physical block numbers.
+        input_seq_groups: List[SequenceGroupInputs],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, int],
-    ) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> Dict[int, SequenceOutputs]:
         # Issue cache operations.
         command_issued = False
         if blocks_to_swap_in:

@@ -160,7 +193,7 @@ class Worker:
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
-            prompt_tokens, generation_tokens, context_lens, block_tables)
+            input_seq_groups)

         # Execute the model.
         output = self.model(
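The reworked prepare_inputs() above flattens per-group inputs into one batched token list: prompts contribute all of their tokens (positions 0..len-1), while each running generation sequence contributes exactly one token at position context_len - 1. A simplified standalone sketch of that flattening (hypothetical helper and types, not the commit's code):

from typing import Dict, List, Tuple

def flatten_inputs(
    prompt_groups: List[List[int]],           # one token list per new prompt
    generation_groups: List[Dict[int, int]],  # seq id -> last token, per running group
    context_lens: Dict[int, int],             # seq id -> current sequence length
) -> Tuple[List[int], List[int]]:
    input_tokens: List[int] = []
    input_positions: List[int] = []
    # Prompts first: every prompt token enters the batch.
    for prompt in prompt_groups:
        input_tokens.extend(prompt)
        input_positions.extend(range(len(prompt)))
    # Then generation tokens: one token per sequence, at position len - 1.
    for group in generation_groups:
        for seq_id, token in group.items():
            input_tokens.append(token)
            input_positions.append(context_lens[seq_id] - 1)
    return input_tokens, input_positions

tokens, positions = flatten_inputs(
    prompt_groups=[[11, 12, 13]],
    generation_groups=[{0: 99, 1: 98}],
    context_lens={0: 6, 1: 6},
)
print(tokens)     # [11, 12, 13, 99, 98]
print(positions)  # [0, 1, 2, 5, 5]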
csrc/cache.cpp

 #include <torch/extension.h>

+#include <map>
+#include <vector>
+
-void copy_blocks(
+void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
   const std::map<int64_t, int64_t>& block_mapping);

+void copy_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping);
+
 void reshape_and_cache(
   torch::Tensor& key,
   torch::Tensor& value,

@@ -14,7 +22,11 @@ void reshape_and_cache(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "copy_cache_blocks",
+    "swap_blocks",
+    &swap_blocks,
+    "Swap in (out) the cache blocks from src to dst");
+  m.def(
+    "copy_blocks",
     &copy_blocks,
     "Copy the cache blocks from src to dst");
   m.def(
csrc/cache_kernels.cu

@@ -5,8 +5,9 @@
 #include <algorithm>
 #include <cassert>
 #include <map>
+#include <vector>

-void copy_blocks(
+void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
   const std::map<int64_t, int64_t>& block_mapping) {

@@ -43,6 +44,35 @@ void copy_blocks(
   }
 }

+void copy_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
+  torch::Device src_device = src.device();
+  torch::Device dst_device = dst.device();
+  assert(src_device.is_cuda() && dst_device.is_cuda());
+  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
+
+  void *src_ptr = src.data_ptr();
+  void *dst_ptr = dst.data_ptr();
+
+  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  for (const auto& pair : block_mapping) {
+    int64_t src_block_number = pair.first;
+    for (int64_t dst_block_number : pair.second) {
+      int64_t src_offset = src_block_number * block_size_in_bytes;
+      int64_t dst_offset = dst_block_number * block_size_in_bytes;
+      cudaMemcpyAsync(
+        dst_ptr + dst_offset,
+        src_ptr + src_offset,
+        block_size_in_bytes,
+        memcpy_type,
+        stream);
+    }
+  }
+}
+
 template<typename scalar_t>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
server.py

@@ -15,6 +15,8 @@ parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of
 parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
 # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
 parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
+# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
+parser.add_argument('--seed', type=int, default=0, help='random seed')
 args = parser.parse_args()

@@ -30,6 +32,7 @@ def main():
             num_gpu_blocks=args.num_gpu_blocks,
             num_cpu_blocks=args.num_cpu_blocks,
             dtype=args.dtype,
+            seed=args.seed,
         )
         controllers.append(controller)

@@ -52,18 +55,18 @@ def main():
         controllers[i].set_next(controllers[i + 1])
     controllers[-1].set_next(scheduler)

+    # Test the following inputs.
     test_inputs = [
-        'Ion Stoica is a',
-        'UC Berkeley is',
-        'The future of cloud computing is',
+        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
+        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
+        ('The future of cloud computing is', {}),   # Use default parameters.
     ]

-    # FIXME
     while True:
         if test_inputs:
-            frontend.query(test_inputs.pop())
+            text, sampling_params = test_inputs.pop(0)
+            frontend.query(text, **sampling_params)
         scheduler.step()
-        if not scheduler.pending and not scheduler.running:
+        if not (scheduler.pending or scheduler.running or test_inputs):
            break
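Each test input now carries its own sampling settings, which are splatted into the keyword-based Frontend.query(). A minimal sketch of that unpacking (query is stubbed out here, since running the real server needs a GPU and model weights):

def query(prompt: str, **sampling_params) -> None:
    # Stand-in for Frontend.query(); the real call forwards these keyword
    # arguments into SamplingParams.
    print(prompt, sampling_params)

test_inputs = [
    ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
    ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
    ('The future of cloud computing is', {}),   # default parameters
]

while test_inputs:
    text, sampling_params = test_inputs.pop(0)
    query(text, **sampling_params)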