Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cbc53b6b
Unverified
Commit
cbc53b6b
authored
Jun 26, 2024
by
Woosuk Kwon
Committed by
GitHub
Jun 26, 2024
Browse files
[Hardware][TPU] Support parallel sampling & Swapping (#5855)
parent
c54269d9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
147 additions
and
56 deletions
+147
-56
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+22
-8
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+50
-26
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+75
-22
No files found.
vllm/attention/backends/pallas.py
View file @
cbc53b6b
...
...
@@ -28,21 +28,35 @@ class PallasAttentionBackend(AttentionBackend):
)
->
Tuple
[
int
,
...]:
return
(
num_kv_heads
,
num_blocks
,
block_size
,
head_size
)
@
torch
.
compile
(
backend
=
"openxla"
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
],
src_kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
dst_kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
src_to_dst
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
)
->
None
:
raise
NotImplementedError
(
"swap_blocks is not implemented."
)
src_k_cache
,
src_v_cache
=
src_kv_cache
dst_k_cache
,
dst_v_cache
=
dst_kv_cache
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
dst_k_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
dst_v_cache
,
True
)
device
=
dst_k_cache
.
device
src_indices
,
dst_indices
=
src_to_dst
dst_k_cache
[:,
dst_indices
]
=
src_k_cache
[:,
src_indices
].
to
(
device
)
dst_v_cache
[:,
dst_indices
]
=
src_v_cache
[:,
src_indices
].
to
(
device
)
@
torch
.
compile
(
backend
=
"openxla"
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]
],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
],
src_to_dists
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
)
->
None
:
# TODO(woosuk): Implement this.
raise
NotImplementedError
(
"copy_blocks is not implemented."
)
src_indices
,
dst_indices
=
src_to_dists
for
k_cache
,
v_cache
in
kv_caches
:
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
k_cache
,
True
)
k_cache
[:,
dst_indices
]
=
k_cache
[:,
src_indices
]
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
v_cache
,
True
)
v_cache
[:,
dst_indices
]
=
v_cache
[:,
src_indices
]
@
dataclass
...
...
vllm/worker/tpu_model_runner.py
View file @
cbc53b6b
...
...
@@ -22,6 +22,9 @@ logger = init_logger(__name__)
_PAD_SLOT_ID
=
0
# FIXME(woosuk)
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P
=
False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES
=
128
class
TPUModelRunner
:
...
...
@@ -143,8 +146,9 @@ class TPUModelRunner:
p
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
# Dummy run.
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
self
.
model
(
token_ids
,
position_ids
,
kv_caches
,
attn_metadata
,
input_lens
,
t
,
p
)
input_lens
,
t
,
p
,
num_samples
)
def
warmup_model
(
self
,
...
...
@@ -268,14 +272,11 @@ class TPUModelRunner:
input_positions
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
context_lens
:
List
[
int
]
=
[]
num_seq_groups
=
len
(
seq_group_metadata_list
)
batch_size
=
_get_padded_batch_size
(
num_seq_groups
)
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
batch_idx
=
0
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
not
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
...
...
@@ -288,14 +289,16 @@ class TPUModelRunner:
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
self
.
block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
self
.
block_tables
[
batch_idx
,
:
len
(
block_table
)]
=
block_table
batch_idx
+=
1
block_number
=
block_table
[
position
//
self
.
block_size
]
block_offset
=
position
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
([
slot
])
num_paddings
=
batch_size
-
num_seq_groups
batch_size
=
_get_padded_batch_size
(
batch_idx
)
num_paddings
=
batch_size
-
batch_idx
input_tokens
=
input_tokens
+
[[
0
]]
*
num_paddings
input_positions
=
input_positions
+
[[
0
]]
*
num_paddings
slot_mapping
=
slot_mapping
+
[[
_PAD_SLOT_ID
]]
*
num_paddings
...
...
@@ -333,14 +336,13 @@ class TPUModelRunner:
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
padded_batch_size
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
List
[
int
]
]:
assert
len
(
seq_group_metadata_list
)
>
0
t
=
[]
p
=
[]
best_of
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
sampling_params
is
not
None
sampling_params
=
seq_group_metadata
.
sampling_params
# NOTE(woosuk): Here we mimic argmax sampling by applying a very
# low temperature. This is not accurate.
t
.
append
(
sampling_params
.
temperature
...
...
@@ -354,10 +356,11 @@ class TPUModelRunner:
raise
NotImplementedError
(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues."
)
if
sampling_params
.
best_of
>
1
:
if
sampling_params
.
best_of
>
_MAX_NUM_SAMPLES
:
raise
NotImplementedError
(
"b
est
_
of >
1 is not currently
supported by the TPU "
f
"B
est
of >
{
_MAX_NUM_SAMPLES
}
is not
supported by the TPU "
"backend."
)
best_of
.
append
(
sampling_params
.
best_of
)
if
sampling_params
.
use_beam_search
:
raise
NotImplementedError
(
"Beam search is not supported by the TPU backend."
)
...
...
@@ -369,13 +372,19 @@ class TPUModelRunner:
"prompt_logprobs is not currently supported by the TPU "
"backend."
)
num_paddings
=
padded_batch_size
-
len
(
seq_group_metadata_list
)
# Repeat the sampling params if the seq group has multiple seqs.
num_seqs
=
len
(
seq_group_metadata
.
seq_data
)
t
+=
[
t
[
-
1
]]
*
(
num_seqs
-
1
)
p
+=
[
p
[
-
1
]]
*
(
num_seqs
-
1
)
best_of
+=
[
best_of
[
-
1
]]
*
(
num_seqs
-
1
)
num_paddings
=
padded_batch_size
-
len
(
t
)
t
+=
[
1.0
]
*
num_paddings
p
+=
[
1.0
]
*
num_paddings
t
=
torch
.
tensor
(
t
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
return
t
,
p
return
t
,
p
,
best_of
def
_execute_model
(
self
,
...
...
@@ -392,28 +401,41 @@ class TPUModelRunner:
else
:
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
padded_batch_size
=
inputs
[
0
].
shape
[
0
]
t
,
p
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
padded_batch_size
)
t
,
p
,
best_of
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
padded_batch_size
)
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
# Execute the model.
next_token_ids
=
self
.
model
(
inputs
[
0
],
inputs
[
1
],
kv_caches
,
*
inputs
[
2
:],
t
,
p
)
*
inputs
[
2
:],
t
,
p
,
num_samples
)
# Retrieve the outputs to CPU.
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support the advanced sampling parameters such as logprobs.
i
=
0
zero_logprob
=
Logprob
(
0.0
)
batch_idx
=
0
sampler_outputs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_outputs
=
[]
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
if
is_prompt
:
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
for
i
in
range
(
best_of
[
batch_idx
]):
next_token_id
=
next_token_ids
[
batch_idx
][
i
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
{
next_token_id
:
zero_logprob
}))
batch_idx
+=
1
else
:
for
seq_id
in
seq_ids
:
next_token_id
=
next_token_ids
[
i
]
next_token_id
=
next_token_ids
[
batch_idx
][
0
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
{
next_token_id
:
L
ogprob
(
0.0
)
}))
i
+=
1
{
next_token_id
:
zero_l
ogprob
}))
batch_idx
+=
1
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
sampler_outputs
...
...
@@ -458,6 +480,7 @@ class ModelWrapper(nn.Module):
input_lens
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
num_samples
:
int
,
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model and samples the next token.
...
...
@@ -520,8 +543,9 @@ class ModelWrapper(nn.Module):
if
_ENABLE_TOP_P
:
logits
=
_apply_top_p
(
logits
,
p
.
unsqueeze
(
dim
=
1
))
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
# FIXME(woosuk): best_of > 1 is not supported.
next_token_ids
=
torch
.
multinomial
(
probs
,
num_samples
=
1
).
squeeze
(
dim
=
1
)
next_token_ids
=
torch
.
multinomial
(
probs
,
num_samples
,
replacement
=
True
)
return
next_token_ids
...
...
vllm/worker/tpu_worker.py
View file @
cbc53b6b
import
os
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch_xla.core.xla_model
as
xm
...
...
@@ -117,19 +117,26 @@ class TPUWorker(LoraNotSupportedWorkerBase):
# Synchronize before measuring the memory usage.
xm
.
wait_device_ops
()
dtype_btyes
=
get_dtype_size
(
self
.
cache_dtype
)
block_size
=
self
.
cache_config
.
block_size
block_size_bytes
=
(
dtype_btyes
*
block_size
*
num_layers
*
2
*
head_size
*
num_kv_heads
)
# Calculate the TPU KV cache size based on profiling.
m
=
xm
.
get_memory_info
(
self
.
device
)
total_memory_size
=
m
[
"bytes_limit"
]
usable_memory_size
=
int
(
total_memory_size
*
self
.
cache_config
.
gpu_memory_utilization
)
profiled
=
m
[
"bytes_used"
]
# Weights + intermediate activations.
kv_cache_bytes
=
max
(
usable_memory_size
-
profiled
,
0
)
dtype_btyes
=
get_dtype_size
(
self
.
cache_dtype
)
block_size
=
self
.
cache_config
.
block_size
num_tpu_blocks
=
(
kv_cache_bytes
//
(
dtype_btyes
*
block_size
*
num_layers
*
2
*
head_size
*
num_kv_heads
))
tpu_kv_cache_bytes
=
max
(
usable_memory_size
-
profiled
,
0
)
num_tpu_blocks
=
tpu_kv_cache_bytes
//
block_size_bytes
num_tpu_blocks
=
(
num_tpu_blocks
//
8
)
*
8
# Round down to 8.
return
num_tpu_blocks
,
0
# Calculate the CPU KV cache size based on the config.
num_cpu_blocks
=
(
self
.
cache_config
.
swap_space_bytes
//
block_size_bytes
)
num_cpu_blocks
=
(
num_cpu_blocks
//
8
)
*
8
# Round down to 8.
return
num_tpu_blocks
,
num_cpu_blocks
def
initialize_cache
(
self
,
...
...
@@ -145,15 +152,19 @@ class TPUWorker(LoraNotSupportedWorkerBase):
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
head_size
=
self
.
model_config
.
get_head_size
()
self
.
cpu_cache
=
[]
self
.
tpu_cache
=
[]
tpu_cache_shape
=
self
.
model_runner
.
attn_backend
.
get_kv_cache_shape
(
num_gpu_blocks
,
self
.
block_size
,
num_kv_heads
,
head_size
)
for
_
in
range
(
num_layers
):
key
_cache
=
torch
.
zeros
(
tpu_cache_shape
,
tpu_k
_cache
=
torch
.
zeros
(
tpu_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
)
value_cache
=
torch
.
zeros_like
(
key_cache
)
self
.
tpu_cache
.
append
((
key_cache
,
value_cache
))
tpu_v_cache
=
torch
.
zeros_like
(
tpu_k_cache
)
self
.
tpu_cache
.
append
((
tpu_k_cache
,
tpu_v_cache
))
cpu_k_cache
=
torch
.
zeros_like
(
tpu_k_cache
,
device
=
"cpu"
)
cpu_v_cache
=
torch
.
zeros_like
(
tpu_v_cache
,
device
=
"cpu"
)
self
.
cpu_cache
.
append
((
cpu_k_cache
,
cpu_v_cache
))
self
.
_warmup_model
()
def
_warmup_model
(
self
)
->
None
:
...
...
@@ -187,22 +198,48 @@ class TPUWorker(LoraNotSupportedWorkerBase):
if
not
self
.
is_driver_worker
:
self
.
_execute_model_non_driver
()
return
[]
assert
execute_model_req
is
not
None
# Currently, TPUWorker does not support swapping.
# TODO(woosuk): Support block copying.
assert
len
(
execute_model_req
.
blocks_to_swap_in
)
==
0
,
(
"Swapping is not supported for the TPU backend."
)
assert
len
(
execute_model_req
.
blocks_to_swap_out
)
==
0
,
(
"Swapping is not supported for the TPU backend."
)
assert
len
(
execute_model_req
.
blocks_to_copy
)
==
0
# Issue cache operations.
self
.
cache_swap
(
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
,
)
# Run the model.
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
assert
len
(
seq_group_metadata_list
)
>
0
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
self
.
tpu_cache
)
return
[
output
]
def
cache_swap
(
self
,
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]],
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
attn_backend
=
self
.
model_runner
.
attn_backend
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
if
blocks_to_swap_in
:
# Swap from CPU to TPU.
src_to_dst
=
_make_src_to_dst
(
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
for
i
in
range
(
num_layers
):
attn_backend
.
swap_blocks
(
self
.
cpu_cache
[
i
],
self
.
tpu_cache
[
i
],
src_to_dst
)
if
blocks_to_swap_out
:
# Swap from TPU to CPU.
src_to_dst
=
_make_src_to_dst
(
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
for
i
in
range
(
num_layers
):
attn_backend
.
swap_blocks
(
self
.
tpu_cache
[
i
],
self
.
cpu_cache
[
i
],
src_to_dst
)
if
blocks_to_copy
:
src_to_dst
=
_make_src_to_dst
(
blocks_to_copy
,
self
.
device
,
self
.
device
)
attn_backend
.
copy_blocks
(
self
.
tpu_cache
,
src_to_dst
)
def
start_worker_execution_loop
(
self
)
->
None
:
while
self
.
_execute_model_non_driver
():
pass
...
...
@@ -210,3 +247,19 @@ class TPUWorker(LoraNotSupportedWorkerBase):
def
_execute_model_non_driver
(
self
)
->
bool
:
self
.
model_runner
.
execute_model
(
None
,
self
.
tpu_cache
)
return
True
def
_make_src_to_dst
(
mapping
:
List
[
Tuple
[
int
,
int
]],
src_device
:
Union
[
torch
.
device
,
str
],
dst_device
:
Union
[
torch
.
device
,
str
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
src_indices
=
[
i
for
i
,
_
in
mapping
]
dst_indices
=
[
i
for
_
,
i
in
mapping
]
src_indices
=
torch
.
tensor
(
src_indices
,
device
=
src_device
,
dtype
=
torch
.
int64
)
dst_indices
=
torch
.
tensor
(
dst_indices
,
device
=
dst_device
,
dtype
=
torch
.
int64
)
return
src_indices
,
dst_indices
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment