norm / vllm · Commits

Commit d1744376 (unverified)
Authored Aug 16, 2023 by Abraham-Xu; committed by GitHub on Aug 15, 2023
Parent: 805de738

Align with huggingface Top K sampling (#753)
Showing 1 changed file with 23 additions and 24 deletions.

vllm/model_executor/layers/sampler.py (+23, -24)
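For reference, the HuggingFace behavior the commit title points at: transformers' logits warpers truncate by setting rejected logits to -inf (their default filter_value) rather than by zeroing probabilities. A hedged sketch, assuming the transformers package and toy score values:

import torch
from transformers import TopKLogitsWarper  # HF's top-k logits warper

scores = torch.tensor([[2.0, 1.0, 0.5, -1.0]])     # toy logits, vocab of 4
input_ids = torch.zeros((1, 1), dtype=torch.long)  # dummy; unused by top-k

# Rejected positions become -inf, not zero probability.
print(TopKLogitsWarper(top_k=2)(input_ids, scores))
# tensor([[2., 1., -inf, -inf]])

The diff below moves vLLM's sampler onto the same convention.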
@@ -71,19 +71,19 @@ class Sampler(nn.Module):
         # Use in-place division to avoid creating a new tensor.
         logits.div_(t.unsqueeze(dim=1))
 
-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities (before applying top-p and top-k).
-        logprobs = torch.log(probs)
-
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
-        assert len(top_ps) == len(top_ks) == probs.shape[0]
+        assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
         if do_top_p or do_top_k:
-            probs = _apply_top_p_top_k(probs, top_ps, top_ks)
+            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
+
+        # We use float32 for probabilities and log probabilities.
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities (before applying top-p and top-k).
+        logprobs = torch.log(probs)
 
         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)
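The net effect of this hunk: truncation now happens in logit space, before the softmax, so the surviving tokens renormalize automatically. A minimal self-contained sketch of that property (toy tensors, plain torch only; not the vLLM code itself):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # one sequence, toy vocab of 4
top_k = 2

# Keep the top-k logits; everything else becomes -inf. (This simple mask
# keeps ties at the k-th value, unlike the positional mask in the diff.)
kth_value = logits.topk(top_k, dim=-1).values[..., -1, None]
truncated = logits.masked_fill(logits < kth_value, -float("inf"))

# softmax maps -inf to exactly 0, so the kept tokens renormalize to 1
# with no explicit division step.
probs = torch.softmax(truncated, dim=-1, dtype=torch.float)
logprobs = torch.log(probs)

print(probs)        # tensor([[0.7311, 0.2689, 0.0000, 0.0000]])
print(probs.sum())  # tensor(1.)

Under the old order, probs and logprobs were computed first and truncation zeroed entries of an already-normalized distribution; the new order matches the -inf filter_value that HuggingFace's warpers use.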
@@ -235,31 +235,32 @@ def _get_top_p_top_k(
 def _apply_top_p_top_k(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     top_ps: List[float],
     top_ks: List[int],
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
+    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
 
     # Apply top-p.
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
     top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    probs_sort[top_p_mask] = 0.0
+    logits_sort[top_p_mask] = -float("inf")
 
     # Apply top-k.
     # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
-    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
+    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
+    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
     top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    probs_sort[top_k_mask] = 0.0
+    logits_sort[top_k_mask] = -float("inf")
 
     # Re-sort the probabilities.
-    probs = torch.gather(probs_sort,
-                         dim=-1,
-                         index=torch.argsort(probs_idx, dim=-1))
-    return probs
+    logits = torch.gather(logits_sort,
+                          dim=-1,
+                          index=torch.argsort(logits_idx, dim=-1))
+    return logits
 
 
 def _get_topk_logprobs(
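Because the rewritten _apply_top_p_top_k depends only on torch, it can be exercised standalone. A hedged sketch that mirrors the new function, with example top_p/top_k values rather than anything read from real SamplingParams, showing the -inf entries vanish after the final softmax:

from typing import List
import torch

def apply_top_p_top_k(logits: torch.Tensor, top_ps: List[float],
                      top_ks: List[int]) -> torch.Tensor:
    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Top-p: drop a token once the cumulative mass strictly before it
    # already exceeds p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    logits_sort[top_p_mask] = -float("inf")

    # Top-k: drop every sorted position at index >= k.
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
    logits_sort[top_k_mask] = -float("inf")

    # Undo the sort so entries line up with the original vocabulary again.
    return torch.gather(logits_sort, dim=-1,
                        index=torch.argsort(logits_idx, dim=-1))

logits = torch.tensor([[1.0, 3.0, 2.0, 0.0]])
out = apply_top_p_top_k(logits.clone(), top_ps=[1.0], top_ks=[2])
print(out)                  # tensor([[-inf, 3., 2., -inf]])
print(out.softmax(dim=-1))  # renormalized over the two surviving tokens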
@@ -301,9 +302,7 @@ def _sample_from_prompt(
         # Random sampling.
         # Sample `best_of` tokens for the prompt.
         num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(prob,
-                                           num_samples=num_seqs,
-                                           replacement=True)
+        next_token_ids = torch.multinomial(prob, num_samples=num_seqs)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids
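The prompt sampler now relies on torch.multinomial's default replacement=False, so the best_of draws come back distinct. A small illustration with made-up probabilities:

import torch

torch.manual_seed(0)
# Made-up distribution: after truncation, token 3 has exactly 0 probability.
prob = torch.tensor([0.7, 0.2, 0.1, 0.0])

# Default replacement=False: the two sampled ids are guaranteed distinct,
# and index 3 can never be drawn. num_samples must stay <= the number of
# nonzero entries (3 here), or multinomial raises an error.
print(torch.multinomial(prob, num_samples=2).tolist())  # e.g. [0, 1]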