norm / vllm · Commits

Commit 55f8b0a5 (unverified)
Authored May 10, 2023 by Woosuk Kwon; committed by GitHub on May 10, 2023

Implement presence and frequency penalties (#95)
Parent: 9f88db35

Showing 9 changed files with 215 additions and 82 deletions.
cacheflow/core/scheduler.py                   +7   -16
cacheflow/frontend/fastapi_frontend.py        +1   -1
cacheflow/frontend/simple_frontend.py         +3   -2
cacheflow/model_executor/input_metadata.py    +6   -4
cacheflow/model_executor/layers/sampler.py    +106 -8
cacheflow/sampling_params.py                  +26  -9
cacheflow/sequence.py                         +46  -23
cacheflow/worker/worker.py                    +18  -17
simple_server.py                              +2   -2
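For context, the penalties follow the OpenAI API definition cited in the sampler change below: for every vocabulary token t, the logit is reduced by frequency_penalty times the number of times t already appears in the sequence's output, plus presence_penalty once if t has appeared at all. A minimal reference sketch in plain Python (illustrative only, not part of this commit):

from collections import Counter
from typing import List

def apply_penalties_reference(logits: List[float],
                              output_token_ids: List[int],
                              presence_penalty: float,
                              frequency_penalty: float) -> List[float]:
    # logits[t] -= frequency_penalty * count(t) + presence_penalty * (count(t) > 0)
    counts = Counter(output_token_ids)
    penalized = list(logits)
    for token_id, count in counts.items():
        penalized[token_id] -= frequency_penalty * count
        penalized[token_id] -= presence_penalty  # applied once per distinct generated token
    return penalized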
cacheflow/core/scheduler.py

@@ -3,11 +3,12 @@ import time
 from typing import Dict, List, Optional, Tuple

 from cacheflow.core.block_manager import BlockSpaceManager
-from cacheflow.logger import init_logger
 from cacheflow.core.policy import PolicyFactory
+from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
-                                SequenceOutputs, SequenceStatus)
+from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
+                                SequenceGroupMetadata, SequenceOutputs,
+                                SequenceStatus)

 logger = init_logger(__name__)

@@ -246,27 +247,17 @@ class Scheduler:
             group_id = seq_group.group_id
             is_prompt = group_id in prompt_group_ids

-            input_tokens: Dict[int, List[int]] = {}
-            seq_logprobs: Dict[int, float] = {}
+            seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
+                seq_data[seq_id] = seq.data
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
-                if is_prompt:
-                    input_tokens[seq_id] = seq.get_token_ids()
-                else:
-                    input_tokens[seq_id] = [seq.get_last_token_id()]
-                seq_logprobs[seq_id] = seq.cumulative_logprobs
-                # NOTE(woosuk): Sequences in the same group have the same
-                # sequence length
-                seq_len = seq.get_len()

             seq_group_metadata = SequenceGroupMetadata(
                 group_id=group_id,
                 is_prompt=is_prompt,
-                input_tokens=input_tokens,
-                context_len=seq_len,
-                seq_logprobs=seq_logprobs,
+                seq_data=seq_data,
                 sampling_params=self.sampling_params[group_id],
                 block_tables=block_tables,
             )
cacheflow/frontend/fastapi_frontend.py

@@ -96,7 +96,7 @@ class FastAPIServer:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
+            seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
             seqs.append(seq)

         arrival_time = time.time()
cacheflow/frontend/simple_frontend.py

@@ -35,10 +35,11 @@ class SimpleFrontend:
         sampling_params: SamplingParams,
     ) -> None:
         token_ids = self.tokenizer.encode(prompt)
-        self._add_query(token_ids, sampling_params)
+        self._add_query(prompt, token_ids, sampling_params)

     def _add_query(
         self,
+        prompt: str,
         token_ids: List[int],
         sampling_params: SamplingParams,
         arrival_time: Optional[float] = None,

@@ -48,7 +49,7 @@ class SimpleFrontend:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
+            seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
             seqs.append(seq)
         group_id = next(self.seq_group_counter)
cacheflow/model_executor/input_metadata.py

-from typing import List, Dict, Tuple
+from typing import Dict, List, Tuple

 import torch
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceData


 class InputMetadata:

     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_logprobs: Dict[int, float],      # Seq id -> cumulative logprobs.
+        seq_groups: List[Tuple[List[int], SamplingParams]],  # List of (seq_ids, sampling_params).
+        seq_data: Dict[int, SequenceData],   # Seq_id -> SequenceData.
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,

@@ -19,7 +20,7 @@ class InputMetadata:
         block_tables: torch.Tensor,
     ) -> None:
         self.seq_groups = seq_groups
-        self.seq_logprobs = seq_logprobs
+        self.seq_data = seq_data
         self.prompt_lens = prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens

@@ -39,6 +40,7 @@ class InputMetadata:
         assert context_lens.shape[0] == self.num_generation_tokens

     def __repr__(self) -> str:
+        # Print only useful metadata.
         return (f'InputMetadata('
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
cacheflow/model_executor/layers/sampler.py

 from typing import Dict, List, Tuple

+import numpy as np
 import torch
 import torch.nn as nn

@@ -31,6 +32,16 @@ class Sampler(nn.Module):
         # Remove paddings in vocab (if any).
         logits = logits[:, :self.vocab_size]

+        # Apply presence and frequency penalties.
+        output_tokens = _get_output_tokens(input_metadata)
+        assert len(output_tokens) == logits.shape[0]
+        presence_penalties, frequency_penalties = _get_penalties(input_metadata)
+        assert len(presence_penalties) == logits.shape[0]
+        assert len(frequency_penalties) == logits.shape[0]
+        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+                                  frequency_penalties, self.vocab_size)
+
         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
         assert len(temperatures) == logits.shape[0]

@@ -43,16 +54,14 @@ class Sampler(nn.Module):
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities (before applying top-p).
+        # Compute the log probabilities (before applying top-p and top-k).
         logprobs = torch.log(probs)

         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == probs.shape[0]
         if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
-            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-            k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
-            probs = _apply_top_p_top_k(probs, p, k)
+            probs = _apply_top_p_top_k(probs, top_ps, top_ks)

         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)

@@ -72,6 +81,93 @@ def _prune_hidden_states(
     return hidden_states[last_token_indicies]


+def _get_penalties(
+    input_metadata: InputMetadata,
+) -> Tuple[List[float], List[float]]:
+    # Collect the presence and frequency penalties.
+    presence_penalties: List[float] = []
+    frequency_penalties: List[float] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        p = sampling_params.presence_penalty
+        f = sampling_params.frequency_penalty
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            presence_penalties.append(p)
+            frequency_penalties.append(f)
+        else:
+            # A generation token.
+            presence_penalties += [p] * len(seq_ids)
+            frequency_penalties += [f] * len(seq_ids)
+    return presence_penalties, frequency_penalties
+
+
+def _get_output_tokens(
+    input_metadata: InputMetadata,
+) -> List[List[int]]:
+    output_tokens: List[List[int]] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, _ = seq_group
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            # NOTE: While the prompt input usually has no output tokens,
+            # it may have output tokens in the case of recomputation.
+            seq_id = seq_ids[0]
+            seq_data = input_metadata.seq_data[seq_id]
+            output_tokens.append(seq_data.output_token_ids)
+        else:
+            # A generation token.
+            for seq_id in seq_ids:
+                seq_data = input_metadata.seq_data[seq_id]
+                output_tokens.append(seq_data.output_token_ids)
+    return output_tokens
+
+
+def _apply_penalties(
+    logits: torch.Tensor,
+    output_tokens: List[List[int]],
+    presence_penalties: List[float],
+    frequency_penalties: List[float],
+    vocab_size: int,
+) -> torch.Tensor:
+    num_seqs = logits.shape[0]
+    # Collect the indices of sequences that have non-zero penalties.
+    indices = []
+    for i in range(num_seqs):
+        if not output_tokens[i]:
+            continue
+        p = presence_penalties[i]
+        f = frequency_penalties[i]
+        if p == 0.0 and f == 0.0:
+            continue
+        indices.append(i)
+
+    # Return early if all sequences have zero penalties.
+    if not indices:
+        return logits
+
+    bin_counts = []
+    for i in indices:
+        bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
+    bin_counts = np.stack(bin_counts, axis=0)
+    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
+                                                 device=logits.device)
+
+    frequency_penalties = [frequency_penalties[i] for i in indices]
+    frequency_penalties = torch.tensor(frequency_penalties, dtype=logits.dtype,
+                                       device=logits.device)
+    presence_penalties = [presence_penalties[i] for i in indices]
+    presence_penalties = torch.tensor(presence_penalties, dtype=logits.dtype,
+                                      device=logits.device)
+
+    # We follow the definition in OpenAI API.
+    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
+    logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
+    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
+    logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
+    return logits
+
+
 def _get_temperatures(
     input_metadata: InputMetadata,
 ) -> List[float]:

@@ -121,10 +217,11 @@ def _get_top_p_top_k(
 def _apply_top_p_top_k(
     probs: torch.Tensor,
-    p: torch.Tensor,
-    k: torch.Tensor,
+    top_ps: List[float],
+    top_ks: List[int],
 ) -> torch.Tensor:
+    # TODO(woosuk): Optimize.
+    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
+    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     # Apply top-p.

@@ -286,7 +383,8 @@ def _sample(
             # Sample the next tokens.
-            seq_logprobs = [input_metadata.seq_logprobs[seq_id]
-                            for seq_id in seq_ids]
+            seq_logprobs = [
+                input_metadata.seq_data[seq_id].cumulative_logprobs
+                for seq_id in seq_ids]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)
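A hypothetical usage sketch of the new _apply_penalties helper, assuming it can be imported from cacheflow.model_executor.layers.sampler (it is a private helper; calling it directly here is purely illustrative):

import torch
from cacheflow.model_executor.layers.sampler import _apply_penalties

logits = torch.zeros(2, 6)           # 2 sequences, toy vocabulary of 6 tokens
output_tokens = [[3, 3, 5], []]      # sequence 0 already emitted token 3 twice and token 5 once
presence_penalties = [0.5, 0.0]
frequency_penalties = [0.1, 0.0]

logits = _apply_penalties(logits, output_tokens, presence_penalties,
                          frequency_penalties, vocab_size=6)
# Sequence 0: the logit of token 3 drops by 0.5 + 2 * 0.1 = 0.7 and the logit of
# token 5 drops by 0.5 + 1 * 0.1 = 0.6. Sequence 1 has no output tokens and zero
# penalties, so its row is left untouched (the early-return path skips it).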
cacheflow/sampling_params.py

@@ -6,6 +6,8 @@ class SamplingParams:
     def __init__(
         self,
         n: int,
+        presence_penalty: float,
+        frequency_penalty: float,
         temperature: float,
         top_p: float,
         top_k: int,

@@ -16,6 +18,12 @@ class SamplingParams:
     ) -> None:
         if n < 1:
             raise ValueError(f"n must be at least 1, got {n}.")
+        if not -2.0 <= presence_penalty <= 2.0:
+            raise ValueError(
+                f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
+        if not -2.0 <= frequency_penalty <= 2.0:
+            raise ValueError(
+                f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
         if temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {temperature}.")

@@ -57,6 +65,8 @@ class SamplingParams:
                 "top_k must be -1 when using greedy sampling.")

         self.n = n
+        self.presence_penalty = presence_penalty
+        self.frequency_penalty = frequency_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k

@@ -67,6 +77,8 @@ class SamplingParams:
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
+                f"presence_penalty={self.presence_penalty}, "
+                f"frequency_penalty={self.frequency_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k},"

@@ -77,13 +89,18 @@ class SamplingParams:
     @classmethod
     def from_dict(cls, d: Dict) -> "SamplingParams":
-        return cls(
-            n=d.get("n", 1),
-            temperature=d.get("temperature", 1.0),
-            top_p=d.get("top_p", 1.0),
-            top_k=d.get("top_k", -1),
-            use_beam_search=d.get("use_beam_search", False),
-            stop_token_ids=set(d.get("stop_token_ids", set())),
-            max_num_steps=d.get("max_num_steps", 16),
-            num_logprobs=d.get("num_logprobs", 0),
+        sampling_params = cls(
+            n=d.pop("n", 1),
+            presence_penalty=d.pop("presence_penalty", 0.0),
+            frequency_penalty=d.pop("frequency_penalty", 0.0),
+            temperature=d.pop("temperature", 1.0),
+            top_p=d.pop("top_p", 1.0),
+            top_k=d.pop("top_k", -1),
+            use_beam_search=d.pop("use_beam_search", False),
+            stop_token_ids=set(d.pop("stop_token_ids", set())),
+            max_num_steps=d.pop("max_num_steps", 16),
+            num_logprobs=d.pop("num_logprobs", 0),
         )
+        if d:
+            raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
+        return sampling_params
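A usage sketch of the extended from_dict (values illustrative only). Note that from_dict now pops keys from the input dict, so unrecognized keys are rejected instead of silently ignored:

from cacheflow.sampling_params import SamplingParams

params = SamplingParams.from_dict({
    "temperature": 0.8,
    "top_p": 0.95,
    "presence_penalty": 0.2,    # must lie in [-2, 2]
    "frequency_penalty": 0.1,   # must lie in [-2, 2]
})
print(params)  # the repr now also shows presence_penalty and frequency_penalty

# A misspelled key is no longer dropped silently, e.g.
# SamplingParams.from_dict({"presense_penalty": 0.2}) raises
# ValueError: Unrecognized keys in dict: dict_keys(['presense_penalty'])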
cacheflow/sequence.py

@@ -13,26 +13,55 @@ class SequenceStatus(enum.Enum):
     FINISHED = enum.auto()


+class SequenceData:
+
+    def __init__(
+        self,
+        prompt_token_ids: List[int],
+    ) -> None:
+        self.prompt_token_ids = prompt_token_ids
+        self.output_token_ids: List[int] = []
+        self.cumulative_logprobs = 0.0
+
+    def get_len(self) -> int:
+        return len(self.output_token_ids) + len(self.prompt_token_ids)
+
+    def get_token_ids(self) -> List[int]:
+        return self.prompt_token_ids + self.output_token_ids
+
+    def get_last_token_id(self) -> int:
+        if not self.output_token_ids:
+            return self.prompt_token_ids[-1]
+        return self.output_token_ids[-1]
+
+    def __repr__(self) -> str:
+        return (f"SequenceData("
+                f"prompt={self.prompt}, "
+                f"prompt_token_ids={self.prompt_token_ids}, "
+                f"output_token_ids={self.output_token_ids})")
+
+
 class Sequence:

     def __init__(
         self,
         seq_id: int,
+        prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
     ) -> None:
         self.seq_id = seq_id
+        self.prompt = prompt
         self.block_size = block_size
-        self.prompt_len = len(prompt_token_ids)
+
+        self.data = SequenceData(prompt_token_ids)
+        self.output_logprobs: List[Dict[int, float]] = []

         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens(prompt_token_ids)
+        self._append_tokens_to_blocks(prompt_token_ids)

         self.status = SequenceStatus.WAITING
-        # Used for beam search.
-        self.output_logprobs: List[Dict[int, float]] = []
-        self.cumulative_logprobs = 0.0

     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(

@@ -41,7 +70,7 @@ class Sequence:
         )
         self.logical_token_blocks.append(block)

-    def _append_tokens(self, token_ids: List[int]) -> None:
+    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
                 self._append_logical_block()

@@ -57,26 +86,24 @@ class Sequence:
     def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
         assert token_id in logprobs
-        self._append_tokens([token_id])
+        self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.cumulative_logprobs += logprobs[token_id]
+        self.data.output_token_ids.append(token_id)
+        self.data.cumulative_logprobs += logprobs[token_id]

     def get_len(self) -> int:
-        return sum(block.num_tokens for block in self.logical_token_blocks)
+        return self.data.get_len()

     def get_token_ids(self) -> List[int]:
-        token_ids: List[int] = []
-        for block in self.logical_token_blocks:
-            token_ids.extend(block.get_token_ids())
-        return token_ids
+        return self.data.get_token_ids()

     def get_last_token_id(self) -> int:
-        return self.logical_token_blocks[-1].get_last_token_id()
+        return self.data.get_last_token_id()

     def fork(self, child_seq: 'Sequence') -> 'Sequence':
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
-        child_seq.cumulative_logprobs = self.cumulative_logprobs
+        child_seq.data = copy.deepcopy(self.data)

     def __repr__(self) -> str:
         return (f'Sequence(seq_id={self.seq_id}, '

@@ -128,17 +155,13 @@ class SequenceGroupMetadata:
         self,
         group_id: int,
         is_prompt: bool,
-        input_tokens: Dict[int, List[int]],  # Seq id -> token ids.
-        context_len: int,
-        seq_logprobs: Dict[int, float],      # Seq id -> cumulative logprobs.
+        seq_data: Dict[int, SequenceData],   # Seq id -> sequence data.
         sampling_params: SamplingParams,
-        block_tables: Dict[int, List[int]],  # Seq id -> List of physical block numbers.
+        block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
     ) -> None:
         self.group_id = group_id
         self.is_prompt = is_prompt
-        self.input_tokens = input_tokens
-        self.context_len = context_len
-        self.seq_logprobs = seq_logprobs
+        self.seq_data = seq_data
         self.sampling_params = sampling_params
         self.block_tables = block_tables
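A small sketch of the bookkeeping that SequenceData now centralizes (token ids and the cumulative logprob previously lived on Sequence itself); the values are illustrative only:

from cacheflow.sequence import SequenceData

data = SequenceData(prompt_token_ids=[11, 42, 7])
data.output_token_ids.append(99)     # what Sequence.append_token now records
data.cumulative_logprobs += -0.25    # updated alongside each appended token

assert data.get_len() == 4                       # prompt length + generated tokens
assert data.get_token_ids() == [11, 42, 7, 99]   # prompt followed by output
assert data.get_last_token_id() == 99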
cacheflow/worker/worker.py

-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Tuple

 import torch

@@ -8,8 +8,8 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_all_reduce_launcher,
     get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceGroupMetadata
-from cacheflow.sequence import SequenceOutputs
+from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
+                                SequenceOutputs)
 from cacheflow.worker.cache_engine import CacheEngine

@@ -72,7 +72,6 @@ class Worker:
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache

-
     def init_distributed_environment(self,
                                      distributed_init_method: str,
                                      rank: int,

@@ -96,7 +95,6 @@ class Worker:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        seq_logprobs: Dict[int, float] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []

@@ -107,15 +105,15 @@ class Worker:
             if not seq_group_metadata.is_prompt:
                 continue

-            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(seq_group_metadata.seq_logprobs)

             # Use any sequence in the group.
             seq_id = seq_ids[0]
-            prompt_tokens = seq_group_metadata.input_tokens[seq_id]
+            seq_data = seq_group_metadata.seq_data[seq_id]
+            prompt_tokens = seq_data.get_token_ids()
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)

@@ -141,27 +139,26 @@ class Worker:
             if seq_group_metadata.is_prompt:
                 continue

-            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(seq_group_metadata.seq_logprobs)

             for seq_id in seq_ids:
-                assert len(seq_group_metadata.input_tokens[seq_id]) == 1
-                generation_token = seq_group_metadata.input_tokens[seq_id][0]
+                seq_data = seq_group_metadata.seq_data[seq_id]
+                generation_token = seq_data.get_last_token_id()
                 input_tokens.append(generation_token)

-                position = seq_group_metadata.context_len - 1
+                context_len = seq_data.get_len()
+                position = context_len - 1
                 input_positions.append(position)

                 block_table = seq_group_metadata.block_tables[seq_id]
                 generation_block_tables.append(block_table)

                 max_context_len = max(
-                    max_context_len, seq_group_metadata.context_len)
+                    max_context_len, context_len)
                 max_num_blocks_per_seq = max(
                     max_num_blocks_per_seq, len(block_table))
-                context_lens.append(seq_group_metadata.context_len)
+                context_lens.append(context_len)

                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size

@@ -188,9 +185,13 @@ class Worker:
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')

+        seq_data: Dict[int, SequenceData] = {}
+        for seq_group_metadata in seq_group_metadata_list:
+            seq_data.update(seq_group_metadata.seq_data)
+
         input_metadata = InputMetadata(
             seq_groups=seq_groups,
-            seq_logprobs=seq_logprobs,
+            seq_data=seq_data,
             prompt_lens=prompt_lens,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
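In short, the worker no longer receives pre-extracted token lists, context lengths, or cumulative logprobs; it recovers all of them from the shared SequenceData objects. A condensed sketch of the generation-phase bookkeeping above (simplified, not a literal excerpt):

# For each decoding sequence, everything the worker needs comes from SequenceData:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()   # replaces input_tokens[seq_id][0]
context_len = seq_data.get_len()                  # replaces seq_group_metadata.context_len
position = context_len - 1                        # position of the token being generated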
simple_server.py

@@ -11,8 +11,8 @@ def main(args: argparse.Namespace):
     # Test the following inputs.
     test_inputs = [
         ("A robot may not injure a human being", {}),  # Use default parameters.
-        ("To be or not to be,", {"temperature": 0.8, "top_k": 5}),
-        ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95}),
+        ("To be or not to be,", {"temperature": 0.8, "top_k": 5, "presence_penalty": 0.2}),
+        ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95, "frequency_penalty": 0.1}),
         ("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}),
     ]
     while True: