Commit 8290fce4 authored Feb 22, 2023 by Woosuk Kwon

Add Worker class

Parent: 7b6844e5
Showing 1 changed file with 169 additions and 0 deletions.

cacheflow/worker/worker.py  (new file, 0 → 100644; +169, -0)
from typing import Dict, List, Tuple

import torch

from cacheflow.models import get_model
from cacheflow.models import InputMetadata
from cacheflow.worker.cache_engine import CacheEngine
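
# A Worker hosts one model replica on a single GPU together with its paged
# KV cache (managed by CacheEngine). execute_stage() issues the scheduled
# cache operations and then runs one forward pass of the model.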
class Worker:

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.block_size = block_size
        self.device = torch.device('cuda', index=gpu_id)

        # Initialize the model.
        # FIXME(woosuk): This is a hack.
        self.model = get_model(model_name).to(device=gpu_id)
        self.num_layers = self.model.config.num_hidden_layers
        self.num_heads = self.model.config.num_attention_heads
        self.head_size = self.model.config.hidden_size // self.num_heads
        self.dtype = self.model.dtype

        self.cache_engine = CacheEngine(
            worker_id=worker_id,
            gpu_id=gpu_id,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            head_size=self.head_size,
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            dtype=self.dtype,
        )
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def prepare_inputs(
        self,
        prompt_tokens: Dict[int, List[int]],  # Seq id -> List of input token ids.
        generation_tokens: Dict[int, int],  # Seq id -> Input token id.
        context_lens: Dict[int, int],  # Seq id -> Number of tokens participating in attention.
        block_tables: Dict[int, List[int]],  # Seq id -> List of physical block numbers.
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
        # TODO(woosuk): Support interactive generation.
        # Add the prompt tokens.
        prompt_lens: List[int] = []
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        prompt_seq_ids = sorted(prompt_tokens.keys())
        for seq_id in prompt_seq_ids:
            prompt_len = len(prompt_tokens[seq_id])
            prompt_lens.append(prompt_len)
            input_tokens.extend(prompt_tokens[seq_id])
            input_positions.extend(range(prompt_len))

            block_table = block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
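
        # Note: a token's physical KV-cache slot comes from its sequence's
        # block table. E.g., with block_size=4 and block_table=[7, 2], token
        # i=5 maps to block_table[5 // 4] = 2 at offset 5 % 4 = 1, so its
        # slot is 2 * 4 + 1 = 9.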

        # Add the generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        generation_block_tables: List[List[int]] = []

        generation_seq_ids = sorted(generation_tokens.keys())
        for seq_id in generation_seq_ids:
            input_tokens.append(generation_tokens[seq_id])
            input_positions.append(context_lens[seq_id] - 1)
            generation_block_tables.append(block_tables[seq_id])
            max_context_len = max(max_context_len, context_lens[seq_id])
            max_num_blocks_per_seq = max(
                max_num_blocks_per_seq, len(block_tables[seq_id]))

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)

        # Convert to tensors.
        tokens_tensor = torch.tensor(
            input_tokens, dtype=torch.long, device=self.device)
        positions_tensor = torch.tensor(
            input_positions, dtype=torch.long, device=self.device)
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device=self.device)
        context_lens_tensor = torch.tensor(
            [context_lens[seq_id] for seq_id in generation_seq_ids],
            dtype=torch.int, device=self.device)
        # Pad every block table to the longest one so they form a
        # rectangular tensor.
        block_tables_tensor = torch.tensor(
            [_pad_to_max(block_table, max_num_blocks_per_seq)
             for block_table in generation_block_tables],
            dtype=torch.int, device=self.device)

        input_metadata = InputMetadata(
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata
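
    # torch.inference_mode() below disables autograd tracking for the whole
    # forward pass; no gradients are needed during generation.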
    @torch.inference_mode()
    def execute_stage(
        self,
        prompt_tokens: Dict[int, List[int]],  # Seq id -> List of input token ids.
        generation_tokens: Dict[int, int],  # Seq id -> Input token id.
        context_lens: Dict[int, int],  # Seq id -> Number of tokens participating in attention.
        block_tables: Dict[int, List[int]],  # Seq id -> List of physical block numbers.
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, int],
    ) -> torch.Tensor:
        # Issue cache operations.
        command_issued = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            command_issued = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            command_issued = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            command_issued = True

        if command_issued:
            cache_events = self.cache_events
        else:
            cache_events = None

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self.prepare_inputs(
            prompt_tokens, generation_tokens, context_lens, block_tables)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
    # Append zeros until len(x) is a multiple of `multiple_of`;
    # (-len(x)) % multiple_of is the pad count (0 if already aligned).
    return x + [0] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
    # Right-pad x with zeros to exactly max_len entries.
    return x + [0] * (max_len - len(x))
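
For context, here is a minimal sketch of how a driver loop might call this Worker for a prefill step followed by one decoding step. It is not part of the commit: the model name, sequence ids, token ids, and block numbers are hypothetical placeholders, and it assumes the cacheflow package and a CUDA GPU are available.

from cacheflow.worker.worker import Worker

# Hypothetical configuration; real values depend on the model and GPU memory.
worker = Worker(
    worker_id=0,
    gpu_id=0,
    model_name='facebook/opt-125m',  # assumed to be accepted by get_model()
    block_size=8,
    num_gpu_blocks=1024,
    num_cpu_blocks=256,
)

# Prefill: sequence 0 has a 4-token prompt stored in physical block 17.
output = worker.execute_stage(
    prompt_tokens={0: [5, 91, 203, 7]},
    generation_tokens={},
    context_lens={},
    block_tables={0: [17]},
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},
)

# Decode: feed one new token for sequence 0; its context is now 5 tokens.
output = worker.execute_stage(
    prompt_tokens={},
    generation_tokens={0: 42},  # hypothetical token id produced after prefill
    context_lens={0: 5},
    block_tables={0: [17]},
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},
)

Prompt and generation sequences may also be mixed in a single call: prepare_inputs packs all prompt tokens first, then one token per generation sequence, and pads the flattened batch to a multiple of 8.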