Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc34937d
Unverified
Commit
bc34937d
authored
Jun 25, 2024
by
Woosuk Kwon
Committed by
GitHub
Jun 25, 2024
Browse files
[Hardware][TPU] Refactor TPU backend (#5831)
parent
dd248f76
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
32 deletions
+65
-32
vllm/executor/tpu_executor.py
vllm/executor/tpu_executor.py
+37
-21
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+4
-0
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+24
-11
No files found.
vllm/executor/tpu_executor.py
View file @
bc34937d
from
typing
import
List
,
Set
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
import
torch
...
...
@@ -26,29 +26,45 @@ class TPUExecutor(ExecutorBase):
self
.
model_config
.
dtype
=
torch
.
bfloat16
# Instantiate the worker and load the model to the device.
self
.
_init_worker
()
def
_init_worker
(
self
):
from
vllm.worker.tpu_worker
import
TPUWorker
self
.
driver_worker
=
self
.
_create_worker
()
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
load_model
()
assert
self
.
parallel_config
.
world_size
==
1
,
(
"TPUExecutor currently only supports a single TPU chip."
)
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
self
.
driver_worker
=
TPUWorker
(
self
.
model_config
,
self
.
parallel_config
,
self
.
scheduler_config
,
self
.
device_config
,
self
.
cache_config
,
self
.
load_config
,
self
.
vision_language_config
,
local_rank
=
0
,
rank
=
0
,
def
_get_worker_kwargs
(
self
,
local_rank
:
int
=
0
,
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
,
)
->
Dict
[
str
,
Any
]:
"""Return worker init args for a given rank."""
if
distributed_init_method
is
None
:
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
return
dict
(
model_config
=
self
.
model_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
device_config
=
self
.
device_config
,
cache_config
=
self
.
cache_config
,
load_config
=
self
.
load_config
,
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
vision_language_config
=
self
.
vision_language_config
,
is_driver_worker
=
rank
==
0
,
)
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
load_model
()
def
_create_worker
(
self
,
local_rank
:
int
=
0
,
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
,
):
from
vllm.worker.tpu_worker
import
TPUWorker
worker
=
TPUWorker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
return
worker
def
initialize_cache
(
self
,
...
...
vllm/worker/tpu_model_runner.py
View file @
bc34937d
...
...
@@ -33,6 +33,7 @@ class TPUModelRunner:
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
):
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
...
...
@@ -41,6 +42,7 @@ class TPUModelRunner:
self
.
cache_config
=
cache_config
self
.
load_config
=
load_config
self
.
vision_language_config
=
vision_language_config
self
.
is_driver_worker
=
is_driver_worker
self
.
block_size
=
self
.
cache_config
.
block_size
self
.
max_num_blocks_per_seq
=
(
self
.
model_config
.
max_model_len
//
...
...
@@ -373,6 +375,8 @@ class TPUModelRunner:
inputs
=
self
.
prepare_inputs
(
seq_group_metadata_list
)
next_token_ids
=
self
.
model
(
inputs
[
0
],
inputs
[
1
],
kv_caches
,
*
inputs
[
2
:])
if
not
self
.
is_driver_worker
:
return
[]
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
i
=
0
...
...
vllm/worker/tpu_worker.py
View file @
bc34937d
...
...
@@ -34,6 +34,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
local_rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
is_driver_worker
:
bool
,
)
->
None
:
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
...
...
@@ -45,6 +46,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self
.
local_rank
=
local_rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
is_driver_worker
=
is_driver_worker
assert
self
.
device_config
.
device_type
==
"tpu"
if
self
.
cache_config
.
cache_dtype
==
"auto"
:
...
...
@@ -53,10 +55,14 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self
.
cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_config
.
cache_dtype
]
self
.
model_runner
=
TPUModelRunner
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
cache_config
,
load_config
,
vision_language_config
)
self
.
model_runner
=
TPUModelRunner
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
cache_config
,
load_config
,
vision_language_config
,
is_driver_worker
=
is_driver_worker
)
def
init_device
(
self
)
->
None
:
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
...
...
@@ -175,16 +181,13 @@ class TPUWorker(LoraNotSupportedWorkerBase):
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
)
->
List
[
SamplerOutput
]:
if
execute_model_req
is
None
:
return
[]
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
num_seq_groups
=
len
(
seq_group_metadata_list
)
if
num_seq_groups
==
0
:
if
not
self
.
is_driver_worker
:
self
.
_execute_model_non_driver
()
return
[]
assert
execute_model_req
is
not
None
# Currently, TPUWorker does not support swapping.
# TODO(woosuk): Support block copying.
assert
len
(
execute_model_req
.
blocks_to_swap_in
)
==
0
,
(
...
...
@@ -193,6 +196,16 @@ class TPUWorker(LoraNotSupportedWorkerBase):
"Swapping is not supported for the TPU backend."
)
assert
len
(
execute_model_req
.
blocks_to_copy
)
==
0
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
assert
len
(
seq_group_metadata_list
)
>
0
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
self
.
tpu_cache
)
return
[
output
]
def
start_worker_execution_loop
(
self
)
->
None
:
while
self
.
_execute_model_non_driver
():
pass
def
_execute_model_non_driver
(
self
)
->
bool
:
self
.
model_runner
.
execute_model
(
None
,
self
.
tpu_cache
)
return
True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment