Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f178e56c
Unverified
Commit
f178e56c
authored
Jun 25, 2024
by
Woosuk Kwon
Committed by
GitHub
Jun 25, 2024
Browse files
[Hardware][TPU] Raise errors for unsupported sampling params (#5850)
parent
dd793d1d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
19 deletions
+44
-19
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+44
-19
No files found.
vllm/worker/tpu_model_runner.py
View file @
f178e56c
...
@@ -20,6 +20,8 @@ from vllm.utils import make_tensor_with_pad
...
@@ -20,6 +20,8 @@ from vllm.utils import make_tensor_with_pad
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_PAD_SLOT_ID
=
0
# FIXME(woosuk)
_PAD_SLOT_ID
=
0
# FIXME(woosuk)
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P
=
False
class
TPUModelRunner
:
class
TPUModelRunner
:
...
@@ -339,9 +341,34 @@ class TPUModelRunner:
...
@@ -339,9 +341,34 @@ class TPUModelRunner:
assert
seq_group_metadata
.
sampling_params
is
not
None
assert
seq_group_metadata
.
sampling_params
is
not
None
sampling_params
=
seq_group_metadata
.
sampling_params
sampling_params
=
seq_group_metadata
.
sampling_params
# NOTE(woosuk): Here we mimic argmax sampling by applying a very
# low temperature. This is not accurate.
t
.
append
(
sampling_params
.
temperature
t
.
append
(
sampling_params
.
temperature
if
sampling_params
.
temperature
>=
1e-5
else
1e-5
)
if
sampling_params
.
temperature
>=
1e-5
else
1e-5
)
if
sampling_params
.
top_p
!=
1
and
not
_ENABLE_TOP_P
:
raise
NotImplementedError
(
"Top-p sampling is currently disabled for the TPU backend "
"due to performance issues."
)
p
.
append
(
sampling_params
.
top_p
)
p
.
append
(
sampling_params
.
top_p
)
if
sampling_params
.
top_k
!=
-
1
:
raise
NotImplementedError
(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues."
)
if
sampling_params
.
best_of
>
1
:
raise
NotImplementedError
(
"best_of > 1 is not currently supported by the TPU "
"backend."
)
if
sampling_params
.
use_beam_search
:
raise
NotImplementedError
(
"Beam search is not supported by the TPU backend."
)
if
sampling_params
.
logprobs
is
not
None
:
raise
NotImplementedError
(
"logprobs is not currently supported by the TPU backend."
)
if
sampling_params
.
prompt_logprobs
is
not
None
:
raise
NotImplementedError
(
"prompt_logprobs is not currently supported by the TPU "
"backend."
)
num_paddings
=
padded_batch_size
-
len
(
seq_group_metadata_list
)
num_paddings
=
padded_batch_size
-
len
(
seq_group_metadata_list
)
t
+=
[
1.0
]
*
num_paddings
t
+=
[
1.0
]
*
num_paddings
p
+=
[
1.0
]
*
num_paddings
p
+=
[
1.0
]
*
num_paddings
...
@@ -350,35 +377,32 @@ class TPUModelRunner:
...
@@ -350,35 +377,32 @@ class TPUModelRunner:
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
return
t
,
p
return
t
,
p
def
prepare_inputs
(
def
_execute_model
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
):
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
assert
seq_group_metadata_list
is
not
None
)
->
List
[
CompletionSequenceGroupOutput
]:
# Prepare inputs.
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
# NOTE: We assume that all sequences in the group are all prompts or
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
# all decodes.
if
seq_group_metadata_list
[
0
].
is_prompt
:
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
if
is_prompt
:
inputs
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
inputs
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
else
:
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
padded_batch_size
=
inputs
[
0
].
shape
[
0
]
padded_batch_size
=
inputs
[
0
].
shape
[
0
]
sample_inputs
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
t
,
p
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
padded_batch_size
)
padded_batch_size
)
return
inputs
+
sample_inputs
def
_execute_model
(
# Execute the model.
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
List
[
CompletionSequenceGroupOutput
]:
inputs
=
self
.
prepare_inputs
(
seq_group_metadata_list
)
next_token_ids
=
self
.
model
(
inputs
[
0
],
inputs
[
1
],
kv_caches
,
next_token_ids
=
self
.
model
(
inputs
[
0
],
inputs
[
1
],
kv_caches
,
*
inputs
[
2
:])
*
inputs
[
2
:],
t
,
p
)
if
not
self
.
is_driver_worker
:
# Retrieve the outputs to CPU.
return
[]
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support the advanced sampling parameters such as logprobs.
i
=
0
i
=
0
sampler_outputs
=
[]
sampler_outputs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
...
@@ -400,6 +424,7 @@ class TPUModelRunner:
...
@@ -400,6 +424,7 @@ class TPUModelRunner:
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
SamplerOutput
:
)
->
SamplerOutput
:
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
assert
len
(
seq_group_metadata_list
)
>
0
if
seq_group_metadata_list
[
0
].
is_prompt
:
if
seq_group_metadata_list
[
0
].
is_prompt
:
# NOTE(woosuk): To reduce the compilation time, we only compile the
# NOTE(woosuk): To reduce the compilation time, we only compile the
# prefill inputs with batch size 1. Because the scheduler is not
# prefill inputs with batch size 1. Because the scheduler is not
...
@@ -492,8 +517,8 @@ class ModelWrapper(nn.Module):
...
@@ -492,8 +517,8 @@ class ModelWrapper(nn.Module):
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
logits
/
t
.
unsqueeze
(
dim
=
1
)
logits
=
logits
/
t
.
unsqueeze
(
dim
=
1
)
# FIXME(woosuk): Disabled top-p sampling since it's too slow.
if
_ENABLE_TOP_P
:
#
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
logits
=
_apply_top_p
(
logits
,
p
.
unsqueeze
(
dim
=
1
))
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
# FIXME(woosuk): best_of > 1 is not supported.
# FIXME(woosuk): best_of > 1 is not supported.
next_token_ids
=
torch
.
multinomial
(
probs
,
num_samples
=
1
).
squeeze
(
dim
=
1
)
next_token_ids
=
torch
.
multinomial
(
probs
,
num_samples
=
1
).
squeeze
(
dim
=
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment