Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1a2aef3e
Unverified
Commit
1a2aef3e
authored
Sep 23, 2024
by
Alexander Matveev
Committed by
GitHub
Sep 23, 2024
Browse files
Add output streaming support to multi-step + async while ensuring RequestOutput obj reuse (#8335)
parent
5f7bb584
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
142 additions
and
42 deletions
+142
-42
tests/entrypoints/openai/test_accuracy.py
tests/entrypoints/openai/test_accuracy.py
+5
-1
vllm/config.py
vllm/config.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+6
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+28
-9
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+8
-1
vllm/outputs.py
vllm/outputs.py
+74
-22
vllm/sequence.py
vllm/sequence.py
+19
-9
No files found.
tests/entrypoints/openai/test_accuracy.py
View file @
1a2aef3e
...
@@ -19,7 +19,11 @@ FILTER = "exact_match,strict-match"
...
@@ -19,7 +19,11 @@ FILTER = "exact_match,strict-match"
RTOL
=
0.03
RTOL
=
0.03
EXPECTED_VALUE
=
0.58
EXPECTED_VALUE
=
0.58
DEFAULT_ARGS
=
[
"--max-model-len"
,
"4096"
,
"--disable-log-requests"
]
DEFAULT_ARGS
=
[
"--max-model-len"
,
"4096"
,
"--disable-log-requests"
]
MORE_ARGS_LIST
=
[[
"--enable-chunked-prefill"
],
[
"--num-scheduler-steps"
,
"8"
]]
MORE_ARGS_LIST
=
[
[
"--enable-chunked-prefill"
],
# Chunked
[
"--num-scheduler-steps"
,
"8"
],
# MS
[
"--num-scheduler-steps"
,
"8"
,
"--multi-step-stream-outputs"
]
# MS+Stream
]
@
pytest
.
mark
.
parametrize
(
"more_args"
,
MORE_ARGS_LIST
)
@
pytest
.
mark
.
parametrize
(
"more_args"
,
MORE_ARGS_LIST
)
...
...
vllm/config.py
View file @
1a2aef3e
...
@@ -960,6 +960,7 @@ class SchedulerConfig:
...
@@ -960,6 +960,7 @@ class SchedulerConfig:
is_multimodal_model
:
bool
=
False
,
is_multimodal_model
:
bool
=
False
,
preemption_mode
:
Optional
[
str
]
=
None
,
preemption_mode
:
Optional
[
str
]
=
None
,
num_scheduler_steps
:
int
=
1
,
num_scheduler_steps
:
int
=
1
,
multi_step_stream_outputs
:
bool
=
False
,
send_delta_data
:
bool
=
False
)
->
None
:
send_delta_data
:
bool
=
False
)
->
None
:
if
max_num_batched_tokens
is
None
:
if
max_num_batched_tokens
is
None
:
if
enable_chunked_prefill
:
if
enable_chunked_prefill
:
...
@@ -1000,6 +1001,7 @@ class SchedulerConfig:
...
@@ -1000,6 +1001,7 @@ class SchedulerConfig:
self
.
embedding_mode
=
embedding_mode
self
.
embedding_mode
=
embedding_mode
self
.
preemption_mode
=
preemption_mode
self
.
preemption_mode
=
preemption_mode
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
multi_step_stream_outputs
=
multi_step_stream_outputs
self
.
send_delta_data
=
send_delta_data
self
.
send_delta_data
=
send_delta_data
self
.
_verify_args
()
self
.
_verify_args
()
...
...
vllm/engine/arg_utils.py
View file @
1a2aef3e
...
@@ -145,6 +145,7 @@ class EngineArgs:
...
@@ -145,6 +145,7 @@ class EngineArgs:
max_cpu_loras
:
Optional
[
int
]
=
None
max_cpu_loras
:
Optional
[
int
]
=
None
device
:
str
=
'auto'
device
:
str
=
'auto'
num_scheduler_steps
:
int
=
1
num_scheduler_steps
:
int
=
1
multi_step_stream_outputs
:
bool
=
False
ray_workers_use_nsight
:
bool
=
False
ray_workers_use_nsight
:
bool
=
False
num_gpu_blocks_override
:
Optional
[
int
]
=
None
num_gpu_blocks_override
:
Optional
[
int
]
=
None
num_lookahead_slots
:
int
=
0
num_lookahead_slots
:
int
=
0
...
@@ -595,6 +596,10 @@ class EngineArgs:
...
@@ -595,6 +596,10 @@ class EngineArgs:
help
=
(
'Maximum number of forward steps per '
help
=
(
'Maximum number of forward steps per '
'scheduler call.'
))
'scheduler call.'
))
parser
.
add_argument
(
'--multi-step-stream-outputs'
,
action
=
'store_true'
,
help
=
'If True, then multi-step will stream outputs for every step'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--scheduler-delay-factor'
,
'--scheduler-delay-factor'
,
type
=
float
,
type
=
float
,
...
@@ -999,6 +1004,7 @@ class EngineArgs:
...
@@ -999,6 +1004,7 @@ class EngineArgs:
is_multimodal_model
=
model_config
.
is_multimodal_model
,
is_multimodal_model
=
model_config
.
is_multimodal_model
,
preemption_mode
=
self
.
preemption_mode
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
multi_step_stream_outputs
=
self
.
multi_step_stream_outputs
,
send_delta_data
=
(
envs
.
VLLM_USE_RAY_SPMD_WORKER
send_delta_data
=
(
envs
.
VLLM_USE_RAY_SPMD_WORKER
and
parallel_config
.
use_ray
),
and
parallel_config
.
use_ray
),
)
)
...
...
vllm/engine/llm_engine.py
View file @
1a2aef3e
...
@@ -95,7 +95,7 @@ class OutputData(NamedTuple):
...
@@ -95,7 +95,7 @@ class OutputData(NamedTuple):
class
SchedulerContext
:
class
SchedulerContext
:
def
__init__
(
self
):
def
__init__
(
self
,
multi_step_stream_outputs
:
bool
=
False
):
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
EmbeddingRequestOutput
]]
=
[]
...
@@ -103,6 +103,8 @@ class SchedulerContext:
...
@@ -103,6 +103,8 @@ class SchedulerContext:
List
[
SequenceGroupMetadata
]]
=
None
List
[
SequenceGroupMetadata
]]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
self
.
multi_step_stream_outputs
:
bool
=
multi_step_stream_outputs
def
append_output
(
self
,
outputs
:
List
[
SamplerOutput
],
def
append_output
(
self
,
outputs
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduler_outputs
:
SchedulerOutputs
,
is_async
:
bool
,
scheduler_outputs
:
SchedulerOutputs
,
is_async
:
bool
,
...
@@ -219,6 +221,7 @@ class LLMEngine:
...
@@ -219,6 +221,7 @@ class LLMEngine:
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
)
->
None
:
)
->
None
:
logger
.
info
(
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
"Initializing an LLM engine (v%s) with config: "
...
@@ -234,8 +237,9 @@ class LLMEngine:
...
@@ -234,8 +237,9 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"use_async_output_proc=%s, mm_processor_kwargs=%s)"
,
"enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)"
,
VLLM_VERSION
,
VLLM_VERSION
,
model_config
.
model
,
model_config
.
model
,
speculative_config
,
speculative_config
,
...
@@ -266,8 +270,10 @@ class LLMEngine:
...
@@ -266,8 +270,10 @@ class LLMEngine:
model_config
.
served_model_name
,
model_config
.
served_model_name
,
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
num_scheduler_steps
,
scheduler_config
.
num_scheduler_steps
,
scheduler_config
.
multi_step_stream_outputs
,
cache_config
.
enable_prefix_caching
,
cache_config
.
enable_prefix_caching
,
model_config
.
use_async_output_proc
,
model_config
.
use_async_output_proc
,
use_cached_outputs
,
model_config
.
mm_processor_kwargs
,
model_config
.
mm_processor_kwargs
,
)
)
# TODO(woosuk): Print more configs in debug mode.
# TODO(woosuk): Print more configs in debug mode.
...
@@ -287,6 +293,7 @@ class LLMEngine:
...
@@ -287,6 +293,7 @@ class LLMEngine:
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
)
)
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
if
not
self
.
model_config
.
skip_tokenizer_init
:
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
tokenizer
=
self
.
_init_tokenizer
()
...
@@ -379,7 +386,8 @@ class LLMEngine:
...
@@ -379,7 +386,8 @@ class LLMEngine:
]
]
self
.
scheduler_contexts
=
[
self
.
scheduler_contexts
=
[
SchedulerContext
()
SchedulerContext
(
multi_step_stream_outputs
=
self
.
scheduler_config
.
multi_step_stream_outputs
)
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
]
...
@@ -998,7 +1006,8 @@ class LLMEngine:
...
@@ -998,7 +1006,8 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
...
@@ -1019,8 +1028,8 @@ class LLMEngine:
...
@@ -1019,8 +1028,8 @@ class LLMEngine:
for
scheduler
in
self
.
scheduler
:
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
scheduler
.
free_finished_seq_groups
()
# For multi-step, do not create outputs each iteration
# For multi-step without streaming, don't create outputs each iteration
if
not
is_last_step
:
if
not
is_last_step
and
not
ctx
.
multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
# Immediately process request outputs here (if callback is given)
if
(
finished_now
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
and
self
.
process_request_outputs_callback
is
not
None
):
...
@@ -1037,17 +1046,27 @@ class LLMEngine:
...
@@ -1037,17 +1046,27 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
# For multi-step with streaming, create outputs each iteration
if
not
is_last_step
and
ctx
.
multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
if
self
.
process_request_outputs_callback
is
not
None
:
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
return
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
params
=
seq_group
.
sampling_params
params
=
seq_group
.
sampling_params
if
params
is
not
None
and
params
.
output_kind
==
(
if
params
is
not
None
and
params
.
output_kind
==
(
RequestOutputKind
.
DELTA
)
and
not
seq_group
.
is_finished
():
RequestOutputKind
.
DELTA
)
and
not
seq_group
.
is_finished
():
continue
continue
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
...
...
vllm/engine/multiprocessing/engine.py
View file @
1a2aef3e
...
@@ -66,7 +66,14 @@ class MQLLMEngine:
...
@@ -66,7 +66,14 @@ class MQLLMEngine:
*
args
,
*
args
,
log_requests
:
bool
=
True
,
log_requests
:
bool
=
True
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
)
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and sent over the socket, which frees
# the python object to be reused again.
use_cached_outputs
=
True
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
,
use_cached_outputs
=
use_cached_outputs
)
self
.
log_requests
=
log_requests
self
.
log_requests
=
log_requests
self
.
use_async_sockets
=
use_async_sockets
self
.
use_async_sockets
=
use_async_sockets
...
...
vllm/outputs.py
View file @
1a2aef3e
...
@@ -114,17 +114,28 @@ class RequestOutput:
...
@@ -114,17 +114,28 @@ class RequestOutput:
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
@
classmethod
@
classmethod
def
from_seq_group
(
cls
,
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
,
se
q_group
:
SequenceGroup
)
->
Optional
[
"RequestOutput"
]:
u
se
_cache
:
bool
)
->
Optional
[
"RequestOutput"
]:
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
is
None
:
if
sampling_params
is
None
:
raise
ValueError
(
raise
ValueError
(
"Sampling parameters are missing for a CompletionRequest."
)
"Sampling parameters are missing for a CompletionRequest."
)
finished
=
seq_group
.
is_finished
()
finished
=
seq_group
.
is_finished
()
if
sampling_params
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
and
(
if
sampling_params
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
and
(
not
finished
):
not
finished
):
return
None
return
None
# Init cache (if needed)
if
use_cache
and
seq_group
.
cached_request_output
is
None
:
seq_group
.
cached_request_output
=
RequestOutput
(
# type: ignore
request_id
=
""
,
prompt
=
None
,
prompt_token_ids
=
[],
prompt_logprobs
=
None
,
outputs
=
[],
finished
=
False
)
seqs
=
seq_group
.
get_seqs
()
seqs
=
seq_group
.
get_seqs
()
if
len
(
seqs
)
==
1
:
if
len
(
seqs
)
==
1
:
top_n_seqs
=
seqs
top_n_seqs
=
seqs
...
@@ -149,29 +160,66 @@ class RequestOutput:
...
@@ -149,29 +160,66 @@ class RequestOutput:
outputs
=
[]
outputs
=
[]
include_prompt
=
True
include_prompt
=
True
for
seq
in
top_n_seqs
:
for
i
,
seq
in
enumerate
(
top_n_seqs
)
:
output_text
=
seq
.
get_output_text_to_return
(
output_text
=
seq
.
get_output_text_to_return
(
text_buffer_length
,
delta
)
text_buffer_length
,
delta
)
output_token_ids
=
seq
.
get_output_token_ids_to_return
(
delta
)
output_token_ids
=
seq
.
get_output_token_ids_to_return
(
delta
)
num_output_tokens
=
1
if
isinstance
(
output_token_ids
,
int
)
else
len
(
output_token_ids
)
output_logprobs
=
seq
.
output_logprobs
if
include_logprobs
else
None
output_logprobs
=
seq
.
output_logprobs
if
include_logprobs
else
None
if
delta
:
if
delta
:
# Slice logprobs delta if applicable
# Slice logprobs delta if applicable
if
output_logprobs
:
if
output_logprobs
:
output_logprobs
=
output_logprobs
[
-
len
(
output_token
_ids
)
:]
output_logprobs
=
output_logprobs
[
-
num_
output_token
s
:]
# Don't include prompt if this is after the first output
# Don't include prompt if this is after the first output
# containing decode token ids
# containing decode token ids
if
include_prompt
and
seq
.
get_output_len
()
>
len
(
if
include_prompt
and
seq
.
get_output_len
()
>
num_output_tokens
:
output_token_ids
):
include_prompt
=
False
include_prompt
=
False
outputs
.
append
(
if
use_cache
:
CompletionOutput
(
# Get cached output object
seqs
.
index
(
seq
),
output_text
,
output_token_ids
,
cached_outputs
=
seq_group
.
cached_request_output
.
outputs
# type: ignore
if
i
>=
len
(
cached_outputs
):
cached_outputs
.
append
(
CompletionOutput
(
index
=
i
,
text
=
""
,
token_ids
=
[],
cumulative_logprob
=
None
,
logprobs
=
None
,
finish_reason
=
None
,
stop_reason
=
None
))
output
=
cached_outputs
[
i
]
# Init cached output object
assert
output
.
index
==
i
output
.
text
=
output_text
if
isinstance
(
output_token_ids
,
int
):
output
.
token_ids
.
clear
()
output
.
token_ids
.
append
(
output_token_ids
)
else
:
output
.
token_ids
=
output_token_ids
output
.
cumulative_logprob
=
seq
.
get_cumulative_logprob
()
\
if
include_logprobs
else
None
output
.
logprobs
=
output_logprobs
output
.
finish_reason
=
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
output
.
stop_reason
=
seq
.
stop_reason
else
:
output
=
CompletionOutput
(
seqs
.
index
(
seq
),
output_text
,
[
output_token_ids
]
if
isinstance
(
output_token_ids
,
int
)
else
output_token_ids
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
output_logprobs
,
output_logprobs
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
))
seq
.
stop_reason
)
outputs
.
append
(
output
)
# Every sequence in the sequence group should have the same prompt.
# Every sequence in the sequence group should have the same prompt.
if
include_prompt
:
if
include_prompt
:
...
@@ -188,16 +236,20 @@ class RequestOutput:
...
@@ -188,16 +236,20 @@ class RequestOutput:
prompt_logprobs
=
None
prompt_logprobs
=
None
finished_time
=
time
.
time
()
if
finished
else
None
finished_time
=
time
.
time
()
if
finished
else
None
seq_group
.
set_finished_time
(
finished_time
)
seq_group
.
set_finished_time
(
finished_time
)
return
cls
(
seq_group
.
request_id
,
prompt
,
init_args
=
(
seq_group
.
request_id
,
prompt
,
prompt_token_ids
,
prompt_token_ids
,
prompt_logprobs
,
outputs
,
finished
,
seq_group
.
metrics
,
prompt_logprobs
,
seq_group
.
lora_request
,
encoder_prompt
,
outputs
,
encoder_prompt_token_ids
)
finished
,
seq_group
.
metrics
,
if
use_cache
:
lora_request
=
seq_group
.
lora_request
,
request_output
=
seq_group
.
cached_request_output
encoder_prompt
=
encoder_prompt
,
request_output
.
__init__
(
*
init_args
)
# type: ignore
encoder_prompt_token_ids
=
encoder_prompt_token_ids
)
else
:
request_output
=
cls
(
*
init_args
)
return
request_output
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"RequestOutput(request_id=
{
self
.
request_id
}
, "
return
(
f
"RequestOutput(request_id=
{
self
.
request_id
}
, "
...
@@ -261,10 +313,10 @@ class EmbeddingRequestOutput:
...
@@ -261,10 +313,10 @@ class EmbeddingRequestOutput:
class
RequestOutputFactory
:
class
RequestOutputFactory
:
@
staticmethod
@
staticmethod
def
create
(
seq_group
):
def
create
(
seq_group
:
SequenceGroup
,
use_cache
:
bool
=
False
):
# Determine the type based on a condition, for example:
# Determine the type based on a condition, for example:
if
hasattr
(
seq_group
,
if
hasattr
(
seq_group
,
'embeddings'
)
and
seq_group
.
embeddings
is
not
None
:
'embeddings'
)
and
seq_group
.
embeddings
is
not
None
:
return
EmbeddingRequestOutput
.
from_seq_group
(
seq_group
)
return
EmbeddingRequestOutput
.
from_seq_group
(
seq_group
)
else
:
else
:
return
RequestOutput
.
from_seq_group
(
seq_group
)
return
RequestOutput
.
from_seq_group
(
seq_group
,
use_cache
)
vllm/sequence.py
View file @
1a2aef3e
...
@@ -436,7 +436,7 @@ class Sequence:
...
@@ -436,7 +436,7 @@ class Sequence:
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
# These are used to keep track of delta outputs
# These are used to keep track of delta outputs
self
.
_last_token_ids_offset
:
int
=
0
self
.
_last_
output_
token_ids_offset
:
int
=
0
self
.
_last_output_text_offset
:
int
=
0
self
.
_last_output_text_offset
:
int
=
0
# Used for incremental detokenization
# Used for incremental detokenization
...
@@ -499,18 +499,26 @@ class Sequence:
...
@@ -499,18 +499,26 @@ class Sequence:
return
self
.
output_text
[
last_offset
:
length
]
return
self
.
output_text
[
last_offset
:
length
]
return
""
return
""
def
get_output_token_ids_to_return
(
self
,
def
get_output_token_ids_to_return
(
delta
:
bool
)
->
GenericSequence
[
int
]:
self
,
delta
:
bool
)
->
Union
[
GenericSequence
[
int
],
int
]:
"""If delta is True, only new tokens since the last call to
"""If delta is True, only new tokens since the last call to
this method are returned"""
this method are returned"""
if
not
delta
:
if
not
delta
:
return
self
.
get_output_token_ids
()
return
self
.
get_output_token_ids
()
length
=
self
.
get_output_len
()
last_offset
=
self
.
_last_token_ids_offset
output_len
=
self
.
get_output_len
()
if
last_offset
<
length
:
self
.
_last_token_ids_offset
=
length
# Get the number of new tokens
return
self
.
data
.
_output_token_ids
[
last_offset
:]
num_new_tokens
=
output_len
-
self
.
_last_output_token_ids_offset
return
()
self
.
_last_output_token_ids_offset
=
output_len
# Return new tokens
if
num_new_tokens
==
1
:
# Optimization for single decode token case
# (which is what we have most of the time)
return
self
.
data
.
_cached_all_token_ids
[
-
1
]
return
self
.
data
.
_cached_all_token_ids
[
-
num_new_tokens
:]
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
# TODO This can produce incorrect hash when block size > prompt size
# TODO This can produce incorrect hash when block size > prompt size
...
@@ -671,6 +679,8 @@ class SequenceGroup:
...
@@ -671,6 +679,8 @@ class SequenceGroup:
self
.
encoder_seq
=
encoder_seq
self
.
encoder_seq
=
encoder_seq
self
.
trace_headers
=
trace_headers
self
.
trace_headers
=
trace_headers
self
.
cached_request_output
=
None
@
property
@
property
def
prompt
(
self
)
->
Optional
[
str
]:
def
prompt
(
self
)
->
Optional
[
str
]:
# All sequences in the group should have the same prompt.
# All sequences in the group should have the same prompt.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment