norm / vllm · Commits

Commit 55f8b0a5 (unverified)
Authored May 10, 2023 by Woosuk Kwon; committed by GitHub on May 10, 2023
Parent: 9f88db35

Implement presence and frequency penalties (#95)
Showing 9 changed files with 215 additions and 82 deletions.
cacheflow/core/scheduler.py                  +7   -16
cacheflow/frontend/fastapi_frontend.py       +1   -1
cacheflow/frontend/simple_frontend.py        +3   -2
cacheflow/model_executor/input_metadata.py   +6   -4
cacheflow/model_executor/layers/sampler.py   +106 -8
cacheflow/sampling_params.py                 +26  -9
cacheflow/sequence.py                        +46  -23
cacheflow/worker/worker.py                   +18  -17
simple_server.py                             +2   -2
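For orientation before the per-file diffs: after this change a request can opt into the new penalties through its sampling-parameter dict, which SamplingParams.from_dict now parses (both penalties default to 0.0). A minimal sketch, assuming the cacheflow package at this commit is on the import path; the parameter values are made up:

    from cacheflow.sampling_params import SamplingParams

    # Both penalties default to 0.0 and must lie in [-2, 2].
    params = SamplingParams.from_dict({
        "temperature": 0.8,
        "top_k": 5,
        "presence_penalty": 0.2,   # flat penalty for any token that has already appeared
        "frequency_penalty": 0.1,  # penalty that scales with how often it appeared
    })
    print(params)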
cacheflow/core/scheduler.py

@@ -3,11 +3,12 @@ import time
 from typing import Dict, List, Optional, Tuple

 from cacheflow.core.block_manager import BlockSpaceManager
-from cacheflow.logger import init_logger
 from cacheflow.core.policy import PolicyFactory
+from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
-                                SequenceOutputs, SequenceStatus)
+from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
+                                SequenceGroupMetadata, SequenceOutputs,
+                                SequenceStatus)

 logger = init_logger(__name__)

@@ -246,27 +247,17 @@ class Scheduler:
             group_id = seq_group.group_id
             is_prompt = group_id in prompt_group_ids

-            input_tokens: Dict[int, List[int]] = {}
-            seq_logprobs: Dict[int, float] = {}
+            seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
+                seq_data[seq_id] = seq.data
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
-                if is_prompt:
-                    input_tokens[seq_id] = seq.get_token_ids()
-                else:
-                    input_tokens[seq_id] = [seq.get_last_token_id()]
-                seq_logprobs[seq_id] = seq.cumulative_logprobs
-                # NOTE(woosuk): Sequences in the same group have the same
-                # sequence length
-                seq_len = seq.get_len()

             seq_group_metadata = SequenceGroupMetadata(
                 group_id=group_id,
                 is_prompt=is_prompt,
-                input_tokens=input_tokens,
-                context_len=seq_len,
-                seq_logprobs=seq_logprobs,
+                seq_data=seq_data,
                 sampling_params=self.sampling_params[group_id],
                 block_tables=block_tables,
             )
cacheflow/frontend/fastapi_frontend.py

@@ -96,7 +96,7 @@ class FastAPIServer:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
+            seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
             seqs.append(seq)

         arrival_time = time.time()
cacheflow/frontend/simple_frontend.py

@@ -35,10 +35,11 @@ class SimpleFrontend:
         sampling_params: SamplingParams,
     ) -> None:
         token_ids = self.tokenizer.encode(prompt)
-        self._add_query(token_ids, sampling_params)
+        self._add_query(prompt, token_ids, sampling_params)

     def _add_query(
         self,
+        prompt: str,
         token_ids: List[int],
         sampling_params: SamplingParams,
         arrival_time: Optional[float] = None,

@@ -48,7 +49,7 @@ class SimpleFrontend:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
+            seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
             seqs.append(seq)

         group_id = next(self.seq_group_counter)
cacheflow/model_executor/input_metadata.py

-from typing import List, Dict, Tuple
+from typing import Dict, List, Tuple

 import torch
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceData


 class InputMetadata:

     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_logprobs: Dict[int, float],     # Seq id -> cumulative logprobs.
+        seq_groups: List[Tuple[List[int], SamplingParams]],  # List of (seq_ids, sampling_params).
+        seq_data: Dict[int, SequenceData],  # Seq_id -> SequenceData.
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,

@@ -19,7 +20,7 @@ class InputMetadata:
         block_tables: torch.Tensor,
     ) -> None:
         self.seq_groups = seq_groups
-        self.seq_logprobs = seq_logprobs
+        self.seq_data = seq_data
         self.prompt_lens = prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens

@@ -39,6 +40,7 @@ class InputMetadata:
         assert context_lens.shape[0] == self.num_generation_tokens

     def __repr__(self) -> str:
+        # Print only useful metadata.
         return (f'InputMetadata('
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
cacheflow/model_executor/layers/sampler.py

 from typing import Dict, List, Tuple

+import numpy as np
 import torch
 import torch.nn as nn

@@ -31,6 +32,16 @@ class Sampler(nn.Module):
         # Remove paddings in vocab (if any).
         logits = logits[:, :self.vocab_size]

+        # Apply presence and frequency penalties.
+        output_tokens = _get_output_tokens(input_metadata)
+        assert len(output_tokens) == logits.shape[0]
+        presence_penalties, frequency_penalties = _get_penalties(
+            input_metadata)
+        assert len(presence_penalties) == logits.shape[0]
+        assert len(frequency_penalties) == logits.shape[0]
+        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+                                  frequency_penalties, self.vocab_size)
+
         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
         assert len(temperatures) == logits.shape[0]

@@ -43,16 +54,14 @@ class Sampler(nn.Module):
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities (before applying top-p).
+        # Compute the log probabilities (before applying top-p and top-k).
         logprobs = torch.log(probs)

         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == probs.shape[0]
-        if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
-            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-            k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
-            probs = _apply_top_p_top_k(probs, p, k)
+        probs = _apply_top_p_top_k(probs, top_ps, top_ks)

         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)

@@ -72,6 +81,93 @@ def _prune_hidden_states(
     return hidden_states[last_token_indicies]


+def _get_penalties(
+    input_metadata: InputMetadata,
+) -> Tuple[List[float], List[float]]:
+    # Collect the presence and frequency penalties.
+    presence_penalties: List[float] = []
+    frequency_penalties: List[float] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        p = sampling_params.presence_penalty
+        f = sampling_params.frequency_penalty
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            presence_penalties.append(p)
+            frequency_penalties.append(f)
+        else:
+            # A generation token.
+            presence_penalties += [p] * len(seq_ids)
+            frequency_penalties += [f] * len(seq_ids)
+    return presence_penalties, frequency_penalties
+
+
+def _get_output_tokens(
+    input_metadata: InputMetadata,
+) -> List[List[int]]:
+    output_tokens: List[List[int]] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, _ = seq_group
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            # NOTE: While the prompt input usually has no output tokens,
+            # it may have output tokens in the case of recomputation.
+            seq_id = seq_ids[0]
+            seq_data = input_metadata.seq_data[seq_id]
+            output_tokens.append(seq_data.output_token_ids)
+        else:
+            # A generation token.
+            for seq_id in seq_ids:
+                seq_data = input_metadata.seq_data[seq_id]
+                output_tokens.append(seq_data.output_token_ids)
+    return output_tokens
+
+
+def _apply_penalties(
+    logits: torch.Tensor,
+    output_tokens: List[List[int]],
+    presence_penalties: List[float],
+    frequency_penalties: List[float],
+    vocab_size: int,
+) -> torch.Tensor:
+    num_seqs = logits.shape[0]
+    # Collect the indices of sequences that have non-zero penalties.
+    indices = []
+    for i in range(num_seqs):
+        if not output_tokens[i]:
+            continue
+        p = presence_penalties[i]
+        f = frequency_penalties[i]
+        if p == 0.0 and f == 0.0:
+            continue
+        indices.append(i)
+    # Return early if all sequences have zero penalties.
+    if not indices:
+        return logits
+
+    bin_counts = []
+    for i in indices:
+        bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
+    bin_counts = np.stack(bin_counts, axis=0)
+    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
+                                                 device=logits.device)
+
+    frequency_penalties = [frequency_penalties[i] for i in indices]
+    frequency_penalties = torch.tensor(frequency_penalties, dtype=logits.dtype,
+                                       device=logits.device)
+    presence_penalties = [presence_penalties[i] for i in indices]
+    presence_penalties = torch.tensor(presence_penalties, dtype=logits.dtype,
+                                      device=logits.device)
+
+    # We follow the definition in OpenAI API.
+    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
+    logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
+    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
+    logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
+    return logits
+
+
 def _get_temperatures(
     input_metadata: InputMetadata,
 ) -> List[float]:

@@ -121,10 +217,11 @@ def _get_top_p_top_k(
 def _apply_top_p_top_k(
     probs: torch.Tensor,
-    p: torch.Tensor,
-    k: torch.Tensor,
+    top_ps: List[float],
+    top_ks: List[int],
 ) -> torch.Tensor:
+    # TODO(woosuk): Optimize.
+    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
+    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     # Apply top-p.

@@ -286,7 +383,8 @@ def _sample(
             # Sample the next tokens.
             seq_logprobs = [
-                input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
+                input_metadata.seq_data[seq_id].cumulative_logprobs
+                for seq_id in seq_ids]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)
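The core of the new _apply_penalties above is the OpenAI-style update: subtract frequency_penalty times the count of each already-generated token, plus a flat presence_penalty for every token that has appeared at least once. A standalone sketch with made-up numbers (a 4-token vocabulary; not part of the commit):

    import torch

    vocab_size = 4
    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
    output_tokens = [0, 0, 2]        # token 0 was generated twice, token 2 once
    presence_penalty, frequency_penalty = 0.2, 0.1

    counts = torch.bincount(torch.tensor(output_tokens), minlength=vocab_size).float()
    logits -= frequency_penalty * counts               # scales with occurrence count
    logits -= presence_penalty * (counts > 0).float()  # flat hit for any seen token
    print(logits)  # tensor([ 1.6000,  1.0000,  0.2000, -1.0000])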
cacheflow/sampling_params.py

@@ -6,6 +6,8 @@ class SamplingParams:
     def __init__(
         self,
         n: int,
+        presence_penalty: float,
+        frequency_penalty: float,
         temperature: float,
         top_p: float,
         top_k: int,

@@ -16,6 +18,12 @@ class SamplingParams:
     ) -> None:
         if n < 1:
             raise ValueError(f"n must be at least 1, got {n}.")
+        if not -2.0 <= presence_penalty <= 2.0:
+            raise ValueError(
+                f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
+        if not -2.0 <= frequency_penalty <= 2.0:
+            raise ValueError(
+                f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
         if temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {temperature}.")

@@ -57,6 +65,8 @@ class SamplingParams:
                 "top_k must be -1 when using greedy sampling.")

         self.n = n
+        self.presence_penalty = presence_penalty
+        self.frequency_penalty = frequency_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k

@@ -67,6 +77,8 @@ class SamplingParams:
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
+                f"presence_penalty={self.presence_penalty}, "
+                f"frequency_penalty={self.frequency_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k},"

@@ -77,13 +89,18 @@ class SamplingParams:
     @classmethod
     def from_dict(cls, d: Dict) -> "SamplingParams":
-        return cls(
-            n=d.get("n", 1),
-            temperature=d.get("temperature", 1.0),
-            top_p=d.get("top_p", 1.0),
-            top_k=d.get("top_k", -1),
-            use_beam_search=d.get("use_beam_search", False),
-            stop_token_ids=set(d.get("stop_token_ids", set())),
-            max_num_steps=d.get("max_num_steps", 16),
-            num_logprobs=d.get("num_logprobs", 0),
+        sampling_params = cls(
+            n=d.pop("n", 1),
+            presence_penalty=d.pop("presence_penalty", 0.0),
+            frequency_penalty=d.pop("frequency_penalty", 0.0),
+            temperature=d.pop("temperature", 1.0),
+            top_p=d.pop("top_p", 1.0),
+            top_k=d.pop("top_k", -1),
+            use_beam_search=d.pop("use_beam_search", False),
+            stop_token_ids=set(d.pop("stop_token_ids", set())),
+            max_num_steps=d.pop("max_num_steps", 16),
+            num_logprobs=d.pop("num_logprobs", 0),
         )
+        if d:
+            raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
+        return sampling_params
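Note the switch from d.get to d.pop in from_dict: anything left in the dict after parsing is now rejected, and the new penalty values are range-checked in __init__. A small sketch of the resulting behavior, again assuming this commit's cacheflow is importable:

    from cacheflow.sampling_params import SamplingParams

    try:
        SamplingParams.from_dict({"presence_pnalty": 0.2})    # misspelled key
    except ValueError as e:
        print(e)  # Unrecognized keys in dict: dict_keys(['presence_pnalty'])

    try:
        SamplingParams.from_dict({"frequency_penalty": 3.0})  # outside [-2, 2]
    except ValueError as e:
        print(e)  # frequency_penalty must be in [-2, 2], got 3.0.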
cacheflow/sequence.py

@@ -13,26 +13,55 @@ class SequenceStatus(enum.Enum):
     FINISHED = enum.auto()


+class SequenceData:
+
+    def __init__(
+        self,
+        prompt_token_ids: List[int],
+    ) -> None:
+        self.prompt_token_ids = prompt_token_ids
+        self.output_token_ids: List[int] = []
+        self.cumulative_logprobs = 0.0
+
+    def get_len(self) -> int:
+        return len(self.output_token_ids) + len(self.prompt_token_ids)
+
+    def get_token_ids(self) -> List[int]:
+        return self.prompt_token_ids + self.output_token_ids
+
+    def get_last_token_id(self) -> int:
+        if not self.output_token_ids:
+            return self.prompt_token_ids[-1]
+        return self.output_token_ids[-1]
+
+    def __repr__(self) -> str:
+        return (f"SequenceData("
+                f"prompt={self.prompt}, "
+                f"prompt_token_ids={self.prompt_token_ids}, "
+                f"output_token_ids={self.output_token_ids})")
+
+
 class Sequence:

     def __init__(
         self,
         seq_id: int,
+        prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
     ) -> None:
         self.seq_id = seq_id
+        self.prompt = prompt
         self.block_size = block_size
-        self.prompt_len = len(prompt_token_ids)
+
+        self.data = SequenceData(prompt_token_ids)
+        self.output_logprobs: List[Dict[int, float]] = []

         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens(prompt_token_ids)
+        self._append_tokens_to_blocks(prompt_token_ids)

         self.status = SequenceStatus.WAITING

-        # Used for beam search.
-        self.output_logprobs: List[Dict[int, float]] = []
-        self.cumulative_logprobs = 0.0
-
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(

@@ -41,7 +70,7 @@ class Sequence:
         )
         self.logical_token_blocks.append(block)

-    def _append_tokens(self, token_ids: List[int]) -> None:
+    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
                 self._append_logical_block()

@@ -57,26 +86,24 @@ class Sequence:
     def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
         assert token_id in logprobs
-        self._append_tokens([token_id])
+        self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.cumulative_logprobs += logprobs[token_id]
+        self.data.output_token_ids.append(token_id)
+        self.data.cumulative_logprobs += logprobs[token_id]

     def get_len(self) -> int:
-        return sum(block.num_tokens for block in self.logical_token_blocks)
+        return self.data.get_len()

     def get_token_ids(self) -> List[int]:
-        token_ids: List[int] = []
-        for block in self.logical_token_blocks:
-            token_ids.extend(block.get_token_ids())
-        return token_ids
+        return self.data.get_token_ids()

     def get_last_token_id(self) -> int:
-        return self.logical_token_blocks[-1].get_last_token_id()
+        return self.data.get_last_token_id()

     def fork(self, child_seq: 'Sequence') -> 'Sequence':
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
-        child_seq.cumulative_logprobs = self.cumulative_logprobs
+        child_seq.data = copy.deepcopy(self.data)

     def __repr__(self) -> str:
         return (f'Sequence(seq_id={self.seq_id}, '

@@ -128,17 +155,13 @@ class SequenceGroupMetadata:
         self,
         group_id: int,
         is_prompt: bool,
-        input_tokens: Dict[int, List[int]],  # Seq id -> token ids.
-        context_len: int,
-        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
+        seq_data: Dict[int, SequenceData],  # Seq id -> sequence data.
         sampling_params: SamplingParams,
-        block_tables: Dict[int, List[int]],  # Seq id -> List of physical block numbers.
+        block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
     ) -> None:
         self.group_id = group_id
         self.is_prompt = is_prompt
-        self.input_tokens = input_tokens
-        self.context_len = context_len
-        self.seq_logprobs = seq_logprobs
+        self.seq_data = seq_data
         self.sampling_params = sampling_params
         self.block_tables = block_tables
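SequenceData is the new per-sequence record that the scheduler, worker, and sampler now share in place of the old input_tokens/seq_logprobs/context_len plumbing. A small illustration of its bookkeeping, with made-up token ids (again assuming this commit's cacheflow is importable):

    from cacheflow.sequence import SequenceData

    data = SequenceData(prompt_token_ids=[11, 22, 33])
    data.output_token_ids.append(44)   # appended by Sequence.append_token()
    data.cumulative_logprobs += -0.5   # updated alongside each new token

    print(data.get_len())            # 4 (prompt + generated tokens)
    print(data.get_token_ids())      # [11, 22, 33, 44]
    print(data.get_last_token_id())  # 44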
cacheflow/worker/worker.py

-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Tuple

 import torch

@@ -8,8 +8,8 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_all_reduce_launcher, get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceGroupMetadata
-from cacheflow.sequence import SequenceOutputs
+from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
+                                SequenceOutputs)
 from cacheflow.worker.cache_engine import CacheEngine

@@ -72,7 +72,6 @@ class Worker:
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache

     def init_distributed_environment(self,
                                      distributed_init_method: str,
                                      rank: int,

@@ -96,7 +95,6 @@ class Worker:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        seq_logprobs: Dict[int, float] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []

@@ -107,15 +105,15 @@ class Worker:
             if not seq_group_metadata.is_prompt:
                 continue

-            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(seq_group_metadata.seq_logprobs)

             # Use any sequence in the group.
             seq_id = seq_ids[0]
-            prompt_tokens = seq_group_metadata.input_tokens[seq_id]
+            seq_data = seq_group_metadata.seq_data[seq_id]
+            prompt_tokens = seq_data.get_token_ids()
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)

@@ -141,27 +139,26 @@ class Worker:
             if seq_group_metadata.is_prompt:
                 continue

-            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(seq_group_metadata.seq_logprobs)

             for seq_id in seq_ids:
-                assert len(seq_group_metadata.input_tokens[seq_id]) == 1
-                generation_token = seq_group_metadata.input_tokens[seq_id][0]
+                seq_data = seq_group_metadata.seq_data[seq_id]
+                generation_token = seq_data.get_last_token_id()
                 input_tokens.append(generation_token)

-                position = seq_group_metadata.context_len - 1
+                context_len = seq_data.get_len()
+                position = context_len - 1
                 input_positions.append(position)

                 block_table = seq_group_metadata.block_tables[seq_id]
                 generation_block_tables.append(block_table)

-                max_context_len = max(
-                    max_context_len, seq_group_metadata.context_len)
+                max_context_len = max(max_context_len, context_len)
                 max_num_blocks_per_seq = max(max_num_blocks_per_seq,
                                              len(block_table))
-                context_lens.append(seq_group_metadata.context_len)
+                context_lens.append(context_len)

                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size

@@ -188,9 +185,13 @@ class Worker:
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')

+        seq_data: Dict[int, SequenceData] = {}
+        for seq_group_metadata in seq_group_metadata_list:
+            seq_data.update(seq_group_metadata.seq_data)
+
         input_metadata = InputMetadata(
             seq_groups=seq_groups,
-            seq_logprobs=seq_logprobs,
+            seq_data=seq_data,
             prompt_lens=prompt_lens,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
simple_server.py

@@ -11,8 +11,8 @@ def main(args: argparse.Namespace):
     # Test the following inputs.
     test_inputs = [
         ("A robot may not injure a human being", {}),  # Use default parameters.
-        ("To be or not to be,", {"temperature": 0.8, "top_k": 5}),
-        ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95}),
+        ("To be or not to be,", {"temperature": 0.8, "top_k": 5, "presence_penalty": 0.2}),
+        ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95, "frequency_penalty": 0.1}),
         ("It is only with the heart that one can see rightly",
          {"n": 3, "use_beam_search": True, "temperature": 0.0}),
     ]
     while True: