Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4634c872
Unverified
Commit
4634c872
authored
Jul 18, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 18, 2024
Browse files
[TPU] Refactor TPU worker & model runner (#6506)
parent
c8a7d51c
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
272 additions
and
166 deletions
+272
-166
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+200
-97
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+72
-69
No files found.
vllm/worker/tpu_model_runner.py
View file @
4634c872
This diff is collapsed.
Click to expand it.
vllm/worker/tpu_worker.py
View file @
4634c872
...
@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
...
@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
LoraNotSupportedWorkerBase
,
WorkerInput
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
TPUWorker
(
LoraNotSupportedWorkerBase
):
class
TPUWorker
(
LoraNotSupportedWorkerBase
,
LocalOrDistributedWorkerBase
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
...
@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self
.
cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_config
.
cache_dtype
]
self
.
cache_config
.
cache_dtype
]
self
.
model_runner
=
TPUModelRunner
(
model_config
,
self
.
model_runner
:
TPUModelRunner
=
TPUModelRunner
(
parallel_config
,
model_config
,
scheduler_config
,
parallel_config
,
device_config
,
scheduler_config
,
cache_config
,
device_config
,
load_config
,
cache_config
,
multimodal_config
,
load_config
,
is_driver_worker
=
is_driver_worker
)
multimodal_config
,
is_driver_worker
=
is_driver_worker
)
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
...
@@ -196,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
...
@@ -196,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
dtype_size
=
get_dtype_size
(
self
.
cache_dtype
)
dtype_size
=
get_dtype_size
(
self
.
cache_dtype
)
return
dtype_size
*
total
return
dtype_size
*
total
def
execute_model
(
@
property
def
do_metadata_broadcast
(
self
)
->
bool
:
# TODO(woosuk): Support TP.
return
False
@
property
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return
[
self
.
tpu_cache
]
def
prepare_worker_input
(
self
,
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
)
->
WorkerInput
:
if
not
self
.
is_driver_worker
:
virtual_engine
=
execute_model_req
.
virtual_engine
self
.
_execute_model_non_driver
()
num_seq_groups
=
len
(
execute_model_req
.
seq_group_metadata_list
)
return
[]
blocks_to_swap_in
=
_make_src_to_dst
(
assert
execute_model_req
is
not
None
execute_model_req
.
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
# Issue cache operations.
blocks_to_swap_out
=
_make_src_to_dst
(
self
.
cache_swap
(
execute_model_req
.
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
execute_model_req
.
blocks_to_swap_in
,
blocks_to_copy
=
_make_src_to_dst
(
execute_model_req
.
blocks_to_copy
,
execute_model_req
.
blocks_to_swap_out
,
self
.
device
,
self
.
device
)
execute_model_req
.
blocks_to_copy
,
return
WorkerInput
(
num_seq_groups
=
num_seq_groups
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
virtual_engine
=
virtual_engine
,
)
)
# Run the model.
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
def
execute_worker
(
self
,
worker_input
:
WorkerInput
)
->
None
:
assert
len
(
seq_group_metadata_list
)
>
0
virtual_engine
=
worker_input
.
virtual_engine
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
assert
virtual_engine
==
0
self
.
tpu_cache
)
return
output
def
cache_swap
(
self
,
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]],
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
attn_backend
=
self
.
model_runner
.
attn_backend
attn_backend
=
self
.
model_runner
.
attn_backend
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
if
blocks_to_swap_in
:
# Issue cache operations.
# Swap from CPU to TPU.
if
worker_input
.
blocks_to_swap_in
is
not
None
:
src_indices
,
dst_indices
=
_make_src_to_dst
(
src_indices
,
dst_indices
=
worker_input
.
blocks_to_swap_in
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
if
src_indices
.
numel
()
>
0
:
for
i
in
range
(
num_layers
):
# Swap from CPU to TPU.
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
for
i
in
range
(
num_layers
):
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
k
=
cpu_k_cache
[:,
src_indices
].
to
(
self
.
device
)
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
v
=
cpu_v_cache
[:,
src_indices
].
to
(
self
.
device
)
k
=
cpu_k_cache
[:,
src_indices
].
to
(
self
.
device
)
_insert_kv
(
k
,
v
,
dst_indices
,
tpu_k_cache
,
tpu_v_cache
)
v
=
cpu_v_cache
[:,
src_indices
].
to
(
self
.
device
)
_insert_kv
(
k
,
v
,
dst_indices
,
tpu_k_cache
,
tpu_v_cache
)
if
blocks_to_swap_out
:
# Swap from TPU to CPU.
if
worker_input
.
blocks_to_swap_out
is
not
None
:
src_indices
,
dst_indices
=
_make_src_to_dst
(
src_indices
,
dst_indices
=
worker_input
.
blocks_to_swap_out
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
if
src_indices
.
numel
()
>
0
:
for
i
in
range
(
num_layers
):
# Swap from TPU to CPU.
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
for
i
in
range
(
num_layers
):
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
[:,
dst_indices
]
=
tpu_k_cache
[:,
src_indices
].
cpu
()
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
cpu_v_cache
[:,
dst_indices
]
=
tpu_v_cache
[:,
src_indices
].
cpu
()
cpu_k_cache
[:,
dst_indices
]
=
tpu_k_cache
[:,
src_indices
]
cpu_v_cache
[:,
dst_indices
]
=
tpu_v_cache
[:,
src_indices
]
if
blocks_to_copy
:
src_to_dst
=
_make_src_to_dst
(
blocks_to_copy
,
self
.
device
,
if
worker_input
.
blocks_to_copy
is
not
None
:
self
.
device
)
src_indices
,
dst_indices
=
worker_input
.
blocks_to_copy
attn_backend
.
copy_blocks
(
self
.
tpu_cache
,
src_to_dst
)
if
src_indices
.
numel
()
>
0
:
attn_backend
.
copy_blocks
(
self
.
tpu_cache
,
def
start_worker_execution_loop
(
self
)
->
None
:
(
src_indices
,
dst_indices
))
while
self
.
_execute_model_non_driver
():
pass
def
_execute_model_non_driver
(
self
)
->
bool
:
self
.
model_runner
.
execute_model
(
None
,
self
.
tpu_cache
)
return
True
def
_make_src_to_dst
(
def
_make_src_to_dst
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment