norm / vllm · Commit 8d66a7b6 (unverified)

Rename variables and methods (#91)

Authored May 10, 2023 by Woosuk Kwon; committed by GitHub on May 10, 2023.
Parent commit: ce26e57f

Showing 7 changed files with 64 additions and 83 deletions:

    cacheflow/block.py               +1   -1
    cacheflow/core/block_manager.py  +4   -4
    cacheflow/core/scheduler.py      +12  -12
    cacheflow/sampling_params.py     +2   -10
    cacheflow/sequence.py            +13  -12
    cacheflow/worker/controller.py   +3   -15
    cacheflow/worker/worker.py       +29  -29
cacheflow/block.py (view file @ 8d66a7b6)

@@ -27,7 +27,7 @@ class LogicalTokenBlock:
     def is_full(self) -> bool:
         return self.num_tokens == self.block_size
 
-    def append(self, token_ids: List[int]) -> None:
+    def append_tokens(self, token_ids: List[int]) -> None:
         assert len(token_ids) <= self.get_num_empty_slots()
         self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
         self.num_tokens += len(token_ids)
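The rename from append() to append_tokens() makes explicit that the block-level method takes a batch of token ids. For context, here is a minimal standalone sketch of the method in action; the constructor and get_num_empty_slots() bodies are inferred from the usage above and are assumptions, not code from this commit.

from typing import List

class LogicalTokenBlock:
    # Sketch only: the real class lives in cacheflow/block.py.
    def __init__(self, block_number: int, block_size: int) -> None:
        self.block_number = block_number
        self.block_size = block_size
        self.token_ids: List[int] = [-1] * block_size  # placeholder slots
        self.num_tokens = 0

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size

    def get_num_empty_slots(self) -> int:
        return self.block_size - self.num_tokens

    def append_tokens(self, token_ids: List[int]) -> None:
        assert len(token_ids) <= self.get_num_empty_slots()
        self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
        self.num_tokens += len(token_ids)

block = LogicalTokenBlock(block_number=0, block_size=4)
block.append_tokens([10, 11, 12])
assert not block.is_full() and block.get_num_empty_slots() == 1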
cacheflow/core/block_manager.py (view file @ 8d66a7b6)

@@ -97,15 +97,15 @@ class BlockSpaceManager:
         for seq in seq_group.seqs:
             self.block_tables[seq.seq_id] = block_table.copy()
 
-    def can_append(self, seq_group: SequenceGroup) -> bool:
+    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
         # Simple heuristic: If there is at least one free block
         # for each sequence, we can append.
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
         return num_seqs <= num_free_gpu_blocks
 
-    def append(self, seq: Sequence) -> Optional[Tuple[int, int]]:
-        """Allocate a physical slot for the new token."""
+    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
+        """Allocate a physical slot for a new token."""
         logical_blocks = seq.logical_token_blocks
         block_table = self.block_tables[seq.seq_id]

@@ -156,7 +156,7 @@ class BlockSpaceManager:
         num_free_blocks = self.gpu_allocator.get_num_free_blocks()
         # NOTE: Conservatively, we assume that every sequence will allocate
         # at least one free block right after the swap-in.
-        # NOTE: This should match the logic in can_append().
+        # NOTE: This should match the logic in can_append_slot().
         num_required_blocks = len(blocks) + num_swapped_seqs
         return num_free_blocks - num_required_blocks >= self.watermark_blocks
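Both renamed methods are simple counting heuristics, as the comments note. The sketch below restates the two admission checks as free functions to make the arithmetic explicit; the signatures are illustrative assumptions, not the BlockSpaceManager API.

def can_append_slot(num_running_seqs: int, num_free_gpu_blocks: int) -> bool:
    # Worst case: every running sequence crosses a block boundary on its
    # next token, so one free block per sequence is enough.
    return num_running_seqs <= num_free_gpu_blocks

def can_swap_in(num_seq_blocks: int, num_swapped_seqs: int,
                num_free_blocks: int, watermark_blocks: int) -> bool:
    # Conservative: each swapped-in sequence may allocate one extra block
    # immediately, and a watermark of free blocks is kept in reserve.
    num_required_blocks = num_seq_blocks + num_swapped_seqs
    return num_free_blocks - num_required_blocks >= watermark_blocks

assert can_append_slot(num_running_seqs=3, num_free_gpu_blocks=5)
# 4 + 2 required blocks leaves 1 of 7 free, below a watermark of 2:
assert not can_swap_in(4, 2, num_free_blocks=7, watermark_blocks=2)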
cacheflow/core/scheduler.py (view file @ 8d66a7b6)

@@ -9,7 +9,7 @@ from cacheflow.core.policy import PolicyFactory
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
-from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceGroupMetadata
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.sequence import SequenceStatus

@@ -105,7 +105,7 @@ class Scheduler:
         preempted: List[SequenceGroup] = []
         while self.running:
             seq_group = self.running.pop(0)
-            while not self.block_manager.can_append(seq_group):
+            while not self.block_manager.can_append_slot(seq_group):
                 if self.running:
                     # Preempt the lowest-priority sequence groups.
                     victim_seq_group = self.running.pop(-1)

@@ -119,7 +119,7 @@ class Scheduler:
                     break
             else:
                 # Append new slots to the sequence group.
-                self._append(seq_group, blocks_to_copy)
+                self._append_slot(seq_group, blocks_to_copy)
                 running.append(seq_group)
 
         self.running = running

@@ -143,7 +143,7 @@ class Scheduler:
             seq_group = self.swapped.pop(0)
             self._swap_in(seq_group, blocks_to_swap_in)
-            self._append(seq_group, blocks_to_copy)
+            self._append_slot(seq_group, blocks_to_copy)
             self.running.append(seq_group)
 
         num_batched_tokens = sum(

@@ -252,7 +252,7 @@ class Scheduler:
         prompt_group_ids = scheduler_output[3]
 
         # Create input data structures.
-        input_seq_groups: List[SequenceGroupInputs] = []
+        seq_group_metadata_list: List[SequenceGroupMetadata] = []
         updated_seq_groups: List[SequenceGroup] = self.running.copy()
 
         for seq_group in self.running:

@@ -274,7 +274,7 @@ class Scheduler:
             # sequence length
             seq_len = seq.get_len()
 
-            input_seq_group = SequenceGroupInputs(
+            seq_group_metadata = SequenceGroupMetadata(
                 group_id=group_id,
                 is_prompt=is_prompt,
                 input_tokens=input_tokens,

@@ -283,14 +283,14 @@ class Scheduler:
                 sampling_params=self.sampling_params[group_id],
                 block_tables=block_tables,
             )
-            input_seq_groups.append(input_seq_group)
+            seq_group_metadata_list.append(seq_group_metadata)
 
         # Execute the first stage of the pipeline.
-        if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
+        if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
             # Swap in and swap out should never happen at the same time.
             assert not (blocks_to_swap_in and blocks_to_swap_out)
             self.controllers[0].execute_stage(
-                input_seq_groups,
+                seq_group_metadata_list,
                 blocks_to_swap_in=blocks_to_swap_in,
                 blocks_to_swap_out=blocks_to_swap_out,
                 blocks_to_copy=blocks_to_copy,

@@ -330,7 +330,7 @@ class Scheduler:
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
-                seq.append(output.output_token, output.logprobs)
+                seq.append_token(output.output_token, output.logprobs)
 
                 # Check if the sequence has generated a stop token.
                 if output.output_token in stop_token_ids:

@@ -360,13 +360,13 @@ class Scheduler:
             if seq_group.group_id not in self.num_steps:
                 self.num_steps[seq_group.group_id] = 0
 
-    def _append(
+    def _append_slot(
         self,
         seq_group: SequenceGroup,
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            ret = self.block_manager.append(seq)
+            ret = self.block_manager.append_slot(seq)
             if ret is not None:
                 src_block, dst_block = ret
                 if src_block in blocks_to_copy:
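The hunks at lines 105 and 119 belong to one preemption loop: the scheduler pops the highest-priority group and, while the block manager cannot append a slot for it, evicts the lowest-priority group (or, with nothing left to evict, preempts the group itself). A simplified sketch of that control flow, with the block manager reduced to callables (an illustration, not the Scheduler class):

from typing import Callable, List

def schedule_running(running: List[str],
                     can_append_slot: Callable[[str], bool],
                     preempt: Callable[[str], None]) -> List[str]:
    scheduled: List[str] = []
    preempted: List[str] = []
    while running:
        seq_group = running.pop(0)           # highest priority first
        while not can_append_slot(seq_group):
            if running:
                # Preempt the lowest-priority sequence group.
                victim = running.pop(-1)
                preempt(victim)
                preempted.append(victim)
            else:
                # No other group to preempt: preempt this one and stop.
                preempt(seq_group)
                preempted.append(seq_group)
                break
        else:
            # The while-else runs only when the loop exits without break,
            # i.e. a slot could be appended for this group.
            scheduled.append(seq_group)
    return scheduled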
cacheflow/sampling_params.py (view file @ 8d66a7b6)

@@ -1,4 +1,4 @@
-from typing import Optional, Set, Dict
+from typing import Dict, Set
 
 
 class SamplingParams:

@@ -12,7 +12,6 @@ class SamplingParams:
         stop_token_ids: Set[int],
         max_num_steps: int,
         num_logprobs: int,
-        context_window_size: Optional[int],
     ) -> None:
         if n < 1:
             raise ValueError(f'n must be at least 1, got {n}.')

@@ -27,10 +26,6 @@ class SamplingParams:
         if num_logprobs < 0:
             raise ValueError(
                 f'num_logprobs must be non-negative, got {num_logprobs}.')
-        if context_window_size is not None and context_window_size < 0:
-            raise ValueError(
-                'context_window_size must be non-negative, '
-                f'got {context_window_size}.')
 
         if use_beam_search:
             if n == 1:

@@ -58,7 +53,6 @@ class SamplingParams:
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs
-        self.context_window_size = context_window_size
 
     def __repr__(self) -> str:
         return (f'SamplingParams(n={self.n}, '

@@ -67,8 +61,7 @@ class SamplingParams:
                 f'use_beam_search={self.use_beam_search}, '
                 f'stop_token_ids={self.stop_token_ids}, '
                 f'max_num_steps={self.max_num_steps}, '
-                f'num_logprobs={self.num_logprobs}, '
-                f'context_window_size={self.context_window_size})')
+                f'num_logprobs={self.num_logprobs})')
 
     @classmethod
     def from_dict(cls, d: Dict) -> 'SamplingParams':

@@ -80,5 +73,4 @@ class SamplingParams:
             stop_token_ids=set(d.get('stop_token_ids', set())),
             max_num_steps=d.get('max_num_steps', 16),
             num_logprobs=d.get('num_logprobs', 0),
-            context_window_size=d.get('context_window_size', None),
         )
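The last hunk shows the optional-key pattern this file uses: from_dict() pulls each field with d.get() and a per-field default, so callers may pass sparse dicts. A self-contained miniature of that pattern (MiniParams is illustrative, not from the repository):

from typing import Dict, Set

class MiniParams:
    def __init__(self, n: int, stop_token_ids: Set[int],
                 max_num_steps: int, num_logprobs: int) -> None:
        if n < 1:
            raise ValueError(f'n must be at least 1, got {n}.')
        if num_logprobs < 0:
            raise ValueError(
                f'num_logprobs must be non-negative, got {num_logprobs}.')
        self.n = n
        self.stop_token_ids = stop_token_ids
        self.max_num_steps = max_num_steps
        self.num_logprobs = num_logprobs

    @classmethod
    def from_dict(cls, d: Dict) -> 'MiniParams':
        # Missing keys fall back to the same defaults as the hunk above.
        return cls(
            n=d.get('n', 1),
            stop_token_ids=set(d.get('stop_token_ids', set())),
            max_num_steps=d.get('max_num_steps', 16),
            num_logprobs=d.get('num_logprobs', 0),
        )

params = MiniParams.from_dict({'max_num_steps': 32})
assert params.n == 1 and params.max_num_steps == 32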
cacheflow/sequence.py (view file @ 8d66a7b6)

@@ -18,45 +18,46 @@ class Sequence:
     def __init__(
         self,
         seq_id: int,
-        token_ids: List[int],
+        prompt_token_ids: List[int],
         block_size: int,
     ) -> None:
         self.seq_id = seq_id
         self.block_size = block_size
+        self.prompt_len = len(prompt_token_ids)
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the given token ids.
-        self.add(token_ids)
-        self.prompt_len = len(token_ids)
+        # Initialize the logical token blocks with the prompt token ids.
+        self._append_tokens(prompt_token_ids)
 
         self.status = SequenceStatus.WAITING
+        # Used for beam search.
         self.output_logprobs: List[Dict[int, float]] = []
         self.cumulative_logprobs = 0.0
 
-    def add_block(self) -> None:
+    def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
             block_size=self.block_size,
         )
         self.logical_token_blocks.append(block)
 
-    def add(self, token_ids: List[int]) -> None:
+    def _append_tokens(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
-                self.add_block()
+                self._append_logical_block()
 
             last_block = self.logical_token_blocks[-1]
             if last_block.is_full():
-                self.add_block()
+                self._append_logical_block()
                 last_block = self.logical_token_blocks[-1]
 
             num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append(token_ids[:num_empty_slots])
+            last_block.append_tokens(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]
 
-    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
+    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
         assert token_id in logprobs
-        self.add([token_id])
+        self._append_tokens([token_id])
         self.output_logprobs.append(logprobs)
         self.cumulative_logprobs += logprobs[token_id]

@@ -121,7 +122,7 @@ class SequenceGroup:
                 f'num_seqs={len(self.seqs)})')
 
 
-class SequenceGroupInputs:
+class SequenceGroupMetadata:
 
     def __init__(
         self,
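The renamed _append_tokens() spills incoming tokens across fixed-size logical blocks, opening a new block whenever the last one is full. A standalone sketch of that spill logic, with plain lists standing in for LogicalTokenBlock (illustrative only):

from typing import List

def append_tokens(blocks: List[List[int]], token_ids: List[int],
                  block_size: int) -> None:
    while token_ids:
        if not blocks or len(blocks[-1]) == block_size:
            blocks.append([])                # _append_logical_block()
        last_block = blocks[-1]
        num_empty_slots = block_size - len(last_block)
        last_block.extend(token_ids[:num_empty_slots])
        token_ids = token_ids[num_empty_slots:]

blocks: List[List[int]] = []
append_tokens(blocks, list(range(10)), block_size=4)
assert blocks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]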
cacheflow/worker/controller.py (view file @ 8d66a7b6)

@@ -1,4 +1,4 @@
-from typing import Dict, List, Union, Tuple, Optional
+from typing import List, Optional, Tuple, Union
 
 try:
     import ray

@@ -6,7 +6,6 @@ except ImportError:
     ray = None
 
 from cacheflow.core.scheduler import Scheduler
-from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.worker.worker import Worker

@@ -81,23 +80,12 @@ class Controller:
         self.next_node = next_node
         self.is_last_stage = isinstance(next_node, Scheduler)
 
-    def execute_stage(
-        self,
-        input_seq_groups: List[SequenceGroupInputs],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-    ) -> None:
+    def execute_stage(self, *args, **kwargs) -> None:
         all_outputs = []
         for worker in self.workers:
             executor = (worker.execute_stage.remote
                         if self.use_ray else worker.execute_stage)
-            output = executor(
-                input_seq_groups,
-                blocks_to_swap_in,
-                blocks_to_swap_out,
-                blocks_to_copy,
-            )
+            output = executor(*args, **kwargs)
             all_outputs.append(output)
 
         if self.use_ray:
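After this change the controller forwards arguments opaquely, so Worker.execute_stage() can change its signature without touching Controller. A minimal sketch of that relay pattern (FakeWorker is a stand-in, not the real Worker class):

from typing import Any, List

class FakeWorker:
    def execute_stage(self, *args: Any, **kwargs: Any) -> dict:
        return {'args': args, 'kwargs': kwargs}

def execute_stage(workers: List[FakeWorker], *args: Any, **kwargs: Any) -> list:
    # Relay positional and keyword arguments unchanged to every worker;
    # under Ray the call would go through worker.execute_stage.remote(...).
    return [worker.execute_stage(*args, **kwargs) for worker in workers]

outputs = execute_stage([FakeWorker()], [], blocks_to_swap_in={1: 2})
assert outputs[0]['kwargs'] == {'blocks_to_swap_in': {1: 2}}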
cacheflow/worker/worker.py (view file @ 8d66a7b6)

@@ -8,10 +8,11 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_all_reduce_launcher,
     get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceGroupMetadata
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
 
+
 class Worker:
 
     def __init__(

@@ -93,30 +94,29 @@ class Worker:
     def prepare_inputs(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         seq_logprobs: Dict[int, float] = {}
-        sampling_params: Dict[int, SamplingParams] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
-        for input_seq_group in input_seq_groups:
-            if not input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if not seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             # Use any sequence in the group.
             seq_id = seq_ids[0]
 
-            prompt_tokens = input_seq_group.input_tokens[seq_id]
+            prompt_tokens = seq_group_metadata.input_tokens[seq_id]
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)

@@ -126,7 +126,7 @@ class Worker:
             input_positions.extend(range(len(prompt_tokens)))
 
             # Compute the slot mapping.
-            block_table = input_seq_group.block_tables[seq_id]
+            block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size

@@ -138,31 +138,31 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        for input_seq_group in input_seq_groups:
-            if input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             for seq_id in seq_ids:
-                assert len(input_seq_group.input_tokens[seq_id]) == 1
-                generation_token = input_seq_group.input_tokens[seq_id][0]
+                assert len(seq_group_metadata.input_tokens[seq_id]) == 1
+                generation_token = seq_group_metadata.input_tokens[seq_id][0]
                 input_tokens.append(generation_token)
 
-                position = input_seq_group.context_len - 1
+                position = seq_group_metadata.context_len - 1
                 input_positions.append(position)
 
-                block_table = input_seq_group.block_tables[seq_id]
+                block_table = seq_group_metadata.block_tables[seq_id]
                 generation_block_tables.append(block_table)
 
                 max_context_len = max(
-                    max_context_len, input_seq_group.context_len)
+                    max_context_len, seq_group_metadata.context_len)
                 max_num_blocks_per_seq = max(
                     max_num_blocks_per_seq, len(block_table))
-                context_lens.append(input_seq_group.context_len)
+                context_lens.append(seq_group_metadata.context_len)
 
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size

@@ -203,30 +203,30 @@ class Worker:
     @torch.inference_mode()
     def execute_stage(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> Dict[int, SequenceOutputs]:
         # Issue cache operations.
-        command_issued = False
+        issued_cache_op = False
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            command_issued = True
+            issued_cache_op = True
 
-        if command_issued:
+        if issued_cache_op:
             cache_events = self.cache_events
         else:
             cache_events = None
 
         # If there is no input, we don't need to execute the model.
-        if not input_seq_groups:
+        if not seq_group_metadata_list:
             if cache_events is not None:
                 for event in cache_events:
                     event.wait()

@@ -234,7 +234,7 @@ class Worker:
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
-            input_seq_groups)
+            seq_group_metadata_list)
 
         # Execute the model.
         output = self.model(
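The slot-mapping loops in prepare_inputs() place token i of a sequence in physical block block_table[i // block_size] at offset i % block_size. The sketch below works one example through that arithmetic; the flat slot index formula (block_number * block_size + block_offset) is an assumption consistent with paged KV caches, since the exact expression falls outside the hunks shown.

from typing import List

def slot_mapping(block_table: List[int], prompt_len: int,
                 block_size: int) -> List[int]:
    slots = []
    for i in range(prompt_len):
        block_number = block_table[i // block_size]  # which physical block
        block_offset = i % block_size                # position inside it
        slots.append(block_number * block_size + block_offset)
    return slots

# A 6-token prompt stored in physical blocks 7 and 3, with block_size=4:
assert slot_mapping([7, 3], 6, 4) == [28, 29, 30, 31, 12, 13]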