Commit 55f8b0a5 (Unverified)
Implement presence and frequency penalties (#95)

Authored May 10, 2023 by Woosuk Kwon · Committed by GitHub on May 10, 2023
Parent: 9f88db35
Showing 9 changed files with 215 additions and 82 deletions (+215 -82).
cacheflow/core/scheduler.py                    +7   -16
cacheflow/frontend/fastapi_frontend.py         +1   -1
cacheflow/frontend/simple_frontend.py          +3   -2
cacheflow/model_executor/input_metadata.py     +6   -4
cacheflow/model_executor/layers/sampler.py     +106 -8
cacheflow/sampling_params.py                   +26  -9
cacheflow/sequence.py                          +46  -23
cacheflow/worker/worker.py                     +18  -17
simple_server.py                               +2   -2
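Both penalties follow the OpenAI API definition referenced in sampler.py below: for every sequence, a token's logit is reduced by frequency_penalty times the number of times that token already appears in the output, plus presence_penalty once if it appears at all. A minimal standalone sketch of that rule (the function and variable names here are illustrative, not part of this commit):

import torch

def penalize_logits(logits: torch.Tensor,        # [vocab_size] logits of one sequence
                    output_token_ids: list,      # tokens generated so far
                    presence_penalty: float,
                    frequency_penalty: float) -> torch.Tensor:
    vocab_size = logits.shape[0]
    # How many times each vocab entry has already been generated.
    counts = torch.bincount(
        torch.tensor(output_token_ids, dtype=torch.long),
        minlength=vocab_size).to(logits.dtype)
    # Frequency penalty scales with the count; presence penalty is a flat cost
    # for any token that has appeared at least once.
    return (logits
            - frequency_penalty * counts
            - presence_penalty * (counts > 0).to(logits.dtype))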
cacheflow/core/scheduler.py (+7 -16)

@@ -3,11 +3,12 @@ import time
 from typing import Dict, List, Optional, Tuple

 from cacheflow.core.block_manager import BlockSpaceManager
-from cacheflow.logger import init_logger
 from cacheflow.core.policy import PolicyFactory
+from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
-                                SequenceOutputs, SequenceStatus)
+from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
+                                SequenceGroupMetadata, SequenceOutputs,
+                                SequenceStatus)

 logger = init_logger(__name__)

@@ -246,27 +247,17 @@ class Scheduler:
             group_id = seq_group.group_id
             is_prompt = group_id in prompt_group_ids

-            input_tokens: Dict[int, List[int]] = {}
-            seq_logprobs: Dict[int, float] = {}
+            seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
+                seq_data[seq_id] = seq.data
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
-                if is_prompt:
-                    input_tokens[seq_id] = seq.get_token_ids()
-                else:
-                    input_tokens[seq_id] = [seq.get_last_token_id()]
-                seq_logprobs[seq_id] = seq.cumulative_logprobs
-                # NOTE(woosuk): Sequences in the same group have the same
-                # sequence length
-                seq_len = seq.get_len()

             seq_group_metadata = SequenceGroupMetadata(
                 group_id=group_id,
                 is_prompt=is_prompt,
-                input_tokens=input_tokens,
-                context_len=seq_len,
-                seq_logprobs=seq_logprobs,
+                seq_data=seq_data,
                 sampling_params=self.sampling_params[group_id],
                 block_tables=block_tables,
             )
cacheflow/frontend/fastapi_frontend.py (+1 -1)

@@ -96,7 +96,7 @@ class FastAPIServer:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
+            seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
             seqs.append(seq)

         arrival_time = time.time()
cacheflow/frontend/simple_frontend.py (+3 -2)

@@ -35,10 +35,11 @@ class SimpleFrontend:
         sampling_params: SamplingParams,
     ) -> None:
         token_ids = self.tokenizer.encode(prompt)
-        self._add_query(token_ids, sampling_params)
+        self._add_query(prompt, token_ids, sampling_params)

     def _add_query(
         self,
+        prompt: str,
         token_ids: List[int],
         sampling_params: SamplingParams,
         arrival_time: Optional[float] = None,

@@ -48,7 +49,7 @@ class SimpleFrontend:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
+            seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
             seqs.append(seq)

         group_id = next(self.seq_group_counter)
cacheflow/model_executor/input_metadata.py (+6 -4)

-from typing import List, Dict, Tuple
+from typing import Dict, List, Tuple

 import torch
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceData


 class InputMetadata:

     def __init__(
         self,
         seq_groups: List[Tuple[List[int], SamplingParams]],  # List of (seq_ids, sampling_params).
-        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
+        seq_data: Dict[int, SequenceData],  # Seq id -> SequenceData.
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,

@@ -19,7 +20,7 @@ class InputMetadata:
         block_tables: torch.Tensor,
     ) -> None:
         self.seq_groups = seq_groups
-        self.seq_logprobs = seq_logprobs
+        self.seq_data = seq_data
         self.prompt_lens = prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens

@@ -39,6 +40,7 @@ class InputMetadata:
         assert context_lens.shape[0] == self.num_generation_tokens

     def __repr__(self) -> str:
+        # Print only useful metadata.
         return (f'InputMetadata('
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
cacheflow/model_executor/layers/sampler.py (+106 -8)

 from typing import Dict, List, Tuple

+import numpy as np
 import torch
 import torch.nn as nn

@@ -31,6 +32,16 @@ class Sampler(nn.Module):
         # Remove paddings in vocab (if any).
         logits = logits[:, :self.vocab_size]

+        # Apply presence and frequency penalties.
+        output_tokens = _get_output_tokens(input_metadata)
+        assert len(output_tokens) == logits.shape[0]
+        presence_penalties, frequency_penalties = _get_penalties(input_metadata)
+        assert len(presence_penalties) == logits.shape[0]
+        assert len(frequency_penalties) == logits.shape[0]
+        logits = _apply_penalties(
+            logits, output_tokens, presence_penalties, frequency_penalties,
+            self.vocab_size)
+
         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
         assert len(temperatures) == logits.shape[0]

@@ -43,16 +54,14 @@ class Sampler(nn.Module):
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities (before applying top-p).
+        # Compute the log probabilities (before applying top-p and top-k).
         logprobs = torch.log(probs)

         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == probs.shape[0]
         if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
-            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-            k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
-            probs = _apply_top_p_top_k(probs, p, k)
+            probs = _apply_top_p_top_k(probs, top_ps, top_ks)

         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)

@@ -72,6 +81,93 @@ def _prune_hidden_states(
     return hidden_states[last_token_indicies]


+def _get_penalties(
+    input_metadata: InputMetadata,
+) -> Tuple[List[float], List[float]]:
+    # Collect the presence and frequency penalties.
+    presence_penalties: List[float] = []
+    frequency_penalties: List[float] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        p = sampling_params.presence_penalty
+        f = sampling_params.frequency_penalty
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            presence_penalties.append(p)
+            frequency_penalties.append(f)
+        else:
+            # A generation token.
+            presence_penalties += [p] * len(seq_ids)
+            frequency_penalties += [f] * len(seq_ids)
+    return presence_penalties, frequency_penalties
+
+
+def _get_output_tokens(
+    input_metadata: InputMetadata,
+) -> List[List[int]]:
+    output_tokens: List[List[int]] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, _ = seq_group
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            # NOTE: While the prompt input usually has no output tokens,
+            # it may have output tokens in the case of recomputation.
+            seq_id = seq_ids[0]
+            seq_data = input_metadata.seq_data[seq_id]
+            output_tokens.append(seq_data.output_token_ids)
+        else:
+            # A generation token.
+            for seq_id in seq_ids:
+                seq_data = input_metadata.seq_data[seq_id]
+                output_tokens.append(seq_data.output_token_ids)
+    return output_tokens
+
+
+def _apply_penalties(
+    logits: torch.Tensor,
+    output_tokens: List[List[int]],
+    presence_penalties: List[float],
+    frequency_penalties: List[float],
+    vocab_size: int,
+) -> torch.Tensor:
+    num_seqs = logits.shape[0]
+    # Collect the indices of sequences that have non-zero penalties.
+    indices = []
+    for i in range(num_seqs):
+        if not output_tokens[i]:
+            continue
+        p = presence_penalties[i]
+        f = frequency_penalties[i]
+        if p == 0.0 and f == 0.0:
+            continue
+        indices.append(i)
+
+    # Return early if all sequences have zero penalties.
+    if not indices:
+        return logits
+
+    bin_counts = []
+    for i in indices:
+        bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
+    bin_counts = np.stack(bin_counts, axis=0)
+    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
+                                                 device=logits.device)
+
+    frequency_penalties = [frequency_penalties[i] for i in indices]
+    frequency_penalties = torch.tensor(frequency_penalties, dtype=logits.dtype,
+                                       device=logits.device)
+    presence_penalties = [presence_penalties[i] for i in indices]
+    presence_penalties = torch.tensor(presence_penalties, dtype=logits.dtype,
+                                      device=logits.device)
+
+    # We follow the definition in OpenAI API.
+    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
+    logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
+    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
+    logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
+    return logits
+
+
 def _get_temperatures(
     input_metadata: InputMetadata,
 ) -> List[float]:

@@ -121,10 +217,11 @@ def _get_top_p_top_k(

 def _apply_top_p_top_k(
     probs: torch.Tensor,
-    p: torch.Tensor,
-    k: torch.Tensor,
+    top_ps: List[float],
+    top_ks: List[int],
 ) -> torch.Tensor:
+    # TODO(woosuk): Optimize.
+    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
+    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

     # Apply top-p.

@@ -286,7 +383,8 @@ def _sample(
             # Sample the next tokens.
             seq_logprobs = [
-                input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
+                input_metadata.seq_data[seq_id].cumulative_logprobs
+                for seq_id in seq_ids]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)
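The new _apply_penalties helper builds per-sequence token histograms with np.bincount and subtracts the two penalty terms from the corresponding logit rows (skipping sequences whose penalties are both zero). A quick standalone check of that batched math, with made-up values and a toy vocabulary:

import numpy as np
import torch

vocab_size = 6
logits = torch.zeros(2, vocab_size)
output_tokens = [[1, 1, 3], [2]]                  # tokens generated so far, per sequence
frequency_penalties = torch.tensor([0.5, 0.0])
presence_penalties = torch.tensor([0.1, 0.0])

bin_counts = np.stack(
    [np.bincount(t, minlength=vocab_size) for t in output_tokens], axis=0)
bin_counts = torch.from_numpy(bin_counts).to(logits.dtype)

logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0).to(logits.dtype)
print(logits[0])  # token 1: -(0.5 * 2 + 0.1) = -1.1, token 3: -(0.5 + 0.1) = -0.6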
cacheflow/sampling_params.py (+26 -9)

@@ -6,6 +6,8 @@ class SamplingParams:
     def __init__(
         self,
         n: int,
+        presence_penalty: float,
+        frequency_penalty: float,
         temperature: float,
         top_p: float,
         top_k: int,

@@ -16,6 +18,12 @@ class SamplingParams:
     ) -> None:
         if n < 1:
             raise ValueError(f"n must be at least 1, got {n}.")
+        if not -2.0 <= presence_penalty <= 2.0:
+            raise ValueError(
+                f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
+        if not -2.0 <= frequency_penalty <= 2.0:
+            raise ValueError(
+                f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
         if temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {temperature}.")

@@ -57,6 +65,8 @@ class SamplingParams:
                 "top_k must be -1 when using greedy sampling.")

         self.n = n
+        self.presence_penalty = presence_penalty
+        self.frequency_penalty = frequency_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k

@@ -67,6 +77,8 @@ class SamplingParams:
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
+                f"presence_penalty={self.presence_penalty}, "
+                f"frequency_penalty={self.frequency_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k},"

@@ -77,13 +89,18 @@ class SamplingParams:
     @classmethod
     def from_dict(cls, d: Dict) -> "SamplingParams":
-        return cls(
-            n=d.get("n", 1),
-            temperature=d.get("temperature", 1.0),
-            top_p=d.get("top_p", 1.0),
-            top_k=d.get("top_k", -1),
-            use_beam_search=d.get("use_beam_search", False),
-            stop_token_ids=set(d.get("stop_token_ids", set())),
-            max_num_steps=d.get("max_num_steps", 16),
-            num_logprobs=d.get("num_logprobs", 0),
+        sampling_params = cls(
+            n=d.pop("n", 1),
+            presence_penalty=d.pop("presence_penalty", 0.0),
+            frequency_penalty=d.pop("frequency_penalty", 0.0),
+            temperature=d.pop("temperature", 1.0),
+            top_p=d.pop("top_p", 1.0),
+            top_k=d.pop("top_k", -1),
+            use_beam_search=d.pop("use_beam_search", False),
+            stop_token_ids=set(d.pop("stop_token_ids", set())),
+            max_num_steps=d.pop("max_num_steps", 16),
+            num_logprobs=d.pop("num_logprobs", 0),
         )
+        if d:
+            raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
+        return sampling_params
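With the switch from dict.get to dict.pop, SamplingParams.from_dict now consumes every key it recognizes and rejects anything left over, and the constructor validates the new penalties against the [-2, 2] range. A hedged usage sketch (assuming the cacheflow package from this commit is importable):

from cacheflow.sampling_params import SamplingParams

params = SamplingParams.from_dict({
    "n": 2,
    "temperature": 0.8,
    "presence_penalty": 0.2,
    "frequency_penalty": 0.1,
})
print(params)  # repr now includes presence_penalty and frequency_penalty

# A misspelled or unknown key is no longer silently ignored:
# SamplingParams.from_dict({"temprature": 0.8})        # raises ValueError (unrecognized key)
# Out-of-range penalties fail validation in __init__:
# SamplingParams.from_dict({"presence_penalty": 3.0})  # raises ValueError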
cacheflow/sequence.py (+46 -23)

@@ -13,26 +13,55 @@ class SequenceStatus(enum.Enum):
     FINISHED = enum.auto()


+class SequenceData:
+
+    def __init__(
+        self,
+        prompt_token_ids: List[int],
+    ) -> None:
+        self.prompt_token_ids = prompt_token_ids
+        self.output_token_ids: List[int] = []
+        self.cumulative_logprobs = 0.0
+
+    def get_len(self) -> int:
+        return len(self.output_token_ids) + len(self.prompt_token_ids)
+
+    def get_token_ids(self) -> List[int]:
+        return self.prompt_token_ids + self.output_token_ids
+
+    def get_last_token_id(self) -> int:
+        if not self.output_token_ids:
+            return self.prompt_token_ids[-1]
+        return self.output_token_ids[-1]
+
+    def __repr__(self) -> str:
+        return (f"SequenceData("
+                f"prompt={self.prompt}, "
+                f"prompt_token_ids={self.prompt_token_ids}, "
+                f"output_token_ids={self.output_token_ids})")
+
+
 class Sequence:

     def __init__(
         self,
         seq_id: int,
+        prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
     ) -> None:
         self.seq_id = seq_id
+        self.prompt = prompt
         self.block_size = block_size
-        self.prompt_len = len(prompt_token_ids)
-        self.output_logprobs: List[Dict[int, float]] = []
+
+        self.data = SequenceData(prompt_token_ids)

         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens(prompt_token_ids)
+        self._append_tokens_to_blocks(prompt_token_ids)

         self.status = SequenceStatus.WAITING

+        # Used for beam search.
+        self.output_logprobs: List[Dict[int, float]] = []
+        self.cumulative_logprobs = 0.0
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(

@@ -41,7 +70,7 @@ class Sequence:
         )
         self.logical_token_blocks.append(block)

-    def _append_tokens(self, token_ids: List[int]) -> None:
+    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
                 self._append_logical_block()

@@ -57,26 +86,24 @@ class Sequence:
     def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
         assert token_id in logprobs
-        self._append_tokens([token_id])
+        self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.cumulative_logprobs += logprobs[token_id]
+        self.data.output_token_ids.append(token_id)
+        self.data.cumulative_logprobs += logprobs[token_id]

     def get_len(self) -> int:
-        return sum(block.num_tokens for block in self.logical_token_blocks)
+        return self.data.get_len()

     def get_token_ids(self) -> List[int]:
-        token_ids: List[int] = []
-        for block in self.logical_token_blocks:
-            token_ids.extend(block.get_token_ids())
-        return token_ids
+        return self.data.get_token_ids()

     def get_last_token_id(self) -> int:
-        return self.logical_token_blocks[-1].get_last_token_id()
+        return self.data.get_last_token_id()

     def fork(self, child_seq: 'Sequence') -> 'Sequence':
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
-        child_seq.cumulative_logprobs = self.cumulative_logprobs
+        child_seq.data = copy.deepcopy(self.data)

     def __repr__(self) -> str:
         return (f'Sequence(seq_id={self.seq_id}, '

@@ -128,17 +155,13 @@ class SequenceGroupMetadata:
         self,
         group_id: int,
         is_prompt: bool,
-        input_tokens: Dict[int, List[int]],  # Seq id -> token ids.
-        context_len: int,
-        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
+        seq_data: Dict[int, SequenceData],  # Seq id -> sequence data.
         sampling_params: SamplingParams,
-        block_tables: Dict[int, List[int]],  # Seq id -> List of physical block numbers.
+        block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
     ) -> None:
         self.group_id = group_id
         self.is_prompt = is_prompt
-        self.input_tokens = input_tokens
-        self.context_len = context_len
-        self.seq_logprobs = seq_logprobs
+        self.seq_data = seq_data
         self.sampling_params = sampling_params
         self.block_tables = block_tables
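Sequence now delegates its token bookkeeping to the new SequenceData object, which is what the scheduler ships to the worker and what the sampler reads output_token_ids and cumulative_logprobs from. A condensed standalone restatement of that bookkeeping (mirroring the class added above rather than importing it):

from typing import List

class SequenceData:
    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprobs = 0.0

    def get_len(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def get_last_token_id(self) -> int:
        return (self.output_token_ids or self.prompt_token_ids)[-1]

data = SequenceData([5, 7, 11])       # prompt token ids
data.output_token_ids.append(42)      # what Sequence.append_token now records
data.cumulative_logprobs += -0.25
print(data.get_len(), data.get_last_token_id())   # 4 42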
cacheflow/worker/worker.py (+18 -17)

-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Tuple

 import torch

@@ -8,8 +8,8 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_all_reduce_launcher,
     get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceGroupMetadata
-from cacheflow.sequence import SequenceOutputs
+from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
+                                SequenceOutputs)
 from cacheflow.worker.cache_engine import CacheEngine

@@ -72,7 +72,6 @@ class Worker:
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
-
     def init_distributed_environment(self,
                                      distributed_init_method: str,
                                      rank: int,

@@ -96,7 +95,6 @@ class Worker:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        seq_logprobs: Dict[int, float] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []

@@ -107,15 +105,15 @@ class Worker:
             if not seq_group_metadata.is_prompt:
                 continue

-            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(seq_group_metadata.seq_logprobs)

             # Use any sequence in the group.
             seq_id = seq_ids[0]

-            prompt_tokens = seq_group_metadata.input_tokens[seq_id]
+            seq_data = seq_group_metadata.seq_data[seq_id]
+            prompt_tokens = seq_data.get_token_ids()
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)

@@ -141,27 +139,26 @@ class Worker:
             if seq_group_metadata.is_prompt:
                 continue

-            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(seq_group_metadata.seq_logprobs)

             for seq_id in seq_ids:
-                assert len(seq_group_metadata.input_tokens[seq_id]) == 1
-                generation_token = seq_group_metadata.input_tokens[seq_id][0]
+                seq_data = seq_group_metadata.seq_data[seq_id]
+                generation_token = seq_data.get_last_token_id()
                 input_tokens.append(generation_token)

-                position = seq_group_metadata.context_len - 1
+                context_len = seq_data.get_len()
+                position = context_len - 1
                 input_positions.append(position)

                 block_table = seq_group_metadata.block_tables[seq_id]
                 generation_block_tables.append(block_table)

                 max_context_len = max(
-                    max_context_len, seq_group_metadata.context_len)
+                    max_context_len, context_len)
                 max_num_blocks_per_seq = max(
                     max_num_blocks_per_seq, len(block_table))
-                context_lens.append(seq_group_metadata.context_len)
+                context_lens.append(context_len)

                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size

@@ -188,9 +185,13 @@ class Worker:
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')

+        seq_data: Dict[int, SequenceData] = {}
+        for seq_group_metadata in seq_group_metadata_list:
+            seq_data.update(seq_group_metadata.seq_data)
+
         input_metadata = InputMetadata(
             seq_groups=seq_groups,
-            seq_logprobs=seq_logprobs,
+            seq_data=seq_data,
             prompt_lens=prompt_lens,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
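For generation steps the worker no longer reads input_tokens and context_len off the group metadata; it derives both from each sequence's SequenceData. A hedged sketch of that derivation, reusing the toy SequenceData above (block_size is an assumed value for illustration):

data = SequenceData([5, 7, 11])
data.output_token_ids += [42, 43]

generation_token = data.get_last_token_id()   # 43: this step's input token
context_len = data.get_len()                  # 5: prompt plus generated tokens
position = context_len - 1                    # 4: index of the last token

block_size = 8                                # assumed block size
block_number = position // block_size         # entry in the sequence's block table
block_offset = position % block_size          # slot within that block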
simple_server.py (+2 -2)

@@ -11,8 +11,8 @@ def main(args: argparse.Namespace):
     # Test the following inputs.
     test_inputs = [
         ("A robot may not injure a human being", {}),  # Use default parameters.
-        ("To be or not to be,", {"temperature": 0.8, "top_k": 5}),
+        ("To be or not to be,",
+         {"temperature": 0.8, "top_k": 5, "presence_penalty": 0.2}),
         ("What is the meaning of life?",
-         {"n": 2, "temperature": 0.8, "top_p": 0.95}),
+         {"n": 2, "temperature": 0.8, "top_p": 0.95, "frequency_penalty": 0.1}),
         ("It is only with the heart that one can see rightly",
          {"n": 3, "use_beam_search": True, "temperature": 0.0}),
     ]
     while True: