Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6e063ea3
"vscode:/vscode.git/clone" did not exist on "9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08"
Unverified
Commit
6e063ea3
authored
Jul 30, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 30, 2024
Browse files
[TPU] Fix greedy decoding (#6933)
parent
af647fb8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
9 deletions
+18
-9
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+18
-9
No files found.
vllm/worker/tpu_model_runner.py
View file @
6e063ea3
...
@@ -28,7 +28,9 @@ if TYPE_CHECKING:
...
@@ -28,7 +28,9 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_PAD_SLOT_ID
=
-
1
# NOTE(woosuk): In PyTorch XLA, index -1 is ignored.
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID
=
1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P
=
False
_ENABLE_TOP_P
=
False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# FIXME(woosuk): A temporary hack to support `n > 1`.
...
@@ -414,10 +416,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -414,10 +416,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
best_of
=
[]
best_of
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
sampling_params
=
seq_group_metadata
.
sampling_params
sampling_params
=
seq_group_metadata
.
sampling_params
# NOTE(woosuk): Here we mimic argmax sampling by applying a very
t
.
append
(
sampling_params
.
temperature
)
# low temperature. This is not accurate.
t
.
append
(
sampling_params
.
temperature
if
sampling_params
.
temperature
>=
1e-5
else
1e-5
)
if
sampling_params
.
top_p
!=
1
and
not
_ENABLE_TOP_P
:
if
sampling_params
.
top_p
!=
1
and
not
_ENABLE_TOP_P
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Top-p sampling is currently disabled for the TPU backend "
"Top-p sampling is currently disabled for the TPU backend "
...
@@ -678,13 +677,23 @@ class ModelWrapper(nn.Module):
...
@@ -678,13 +677,23 @@ class ModelWrapper(nn.Module):
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
logits
/
t
.
unsqueeze
(
dim
=
1
)
# Argmax sampling.
argmax_token_ids
=
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
)
argmax_token_ids
=
argmax_token_ids
.
repeat
(
1
,
num_samples
)
# Zero temperature means greedy decoding. Avoid division by zero.
nonzero_t
=
torch
.
where
(
t
!=
0
,
t
,
1.0
)
logits
=
logits
/
nonzero_t
.
unsqueeze
(
dim
=
1
)
if
_ENABLE_TOP_P
:
if
_ENABLE_TOP_P
:
logits
=
_apply_top_p
(
logits
,
p
.
unsqueeze
(
dim
=
1
))
logits
=
_apply_top_p
(
logits
,
p
.
unsqueeze
(
dim
=
1
))
# Random sampling.
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
next_token_ids
=
torch
.
multinomial
(
probs
,
sampled_token_ids
=
torch
.
multinomial
(
probs
,
num_samples
,
num_samples
,
replacement
=
True
)
replacement
=
True
)
next_token_ids
=
torch
.
where
(
t
!=
0
,
sampled_token_ids
,
argmax_token_ids
)
return
next_token_ids
return
next_token_ids
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment