Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
86c2d8fd
Unverified
Commit
86c2d8fd
authored
Dec 20, 2024
by
Wallas Henrique
Committed by
GitHub
Dec 20, 2024
Browse files
[Bugfix] Fix spec decoding when seed is none in a batch (#10863)
Signed-off-by:
Wallas Santos
<
wallashss@ibm.com
>
parent
b880ffb8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
66 additions
and
7 deletions
+66
-7
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+63
-0
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+3
-7
No files found.
tests/samplers/test_rejection_sampler.py
View file @
86c2d8fd
...
...
@@ -200,6 +200,69 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
assert
torch
.
equal
(
results
[
j
][
i
],
results
[
0
][
i
])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
                            device: str, use_flashinfer: bool) -> None:
    """Check that a partially seeded batch matches per-sequence runs.

    The batch gives every sequence except index 0 its own
    ``torch.Generator``; index 0 has no generator (seed ``None``). The
    sampler is first called once on the whole batch, then once per
    sequence (as a batch of size 1, with that sequence's generator
    re-created from the same seed). Every row of the batched output must
    be exactly equal to the corresponding single-sequence output.

    Args:
        k: Number of draft tokens per sequence.
        vocab_size: Size of the last dimension of the probability tensors.
        batch_size: Number of sequences in the batched call.
        device: Device used as the torch default and for the generators.
        use_flashinfer: Forwarded to ``RejectionSampler``.
    """
    torch.set_default_device(device)
    set_random_seed(0)

    # NOTE: the creation order of these four tensors determines the random
    # inputs drawn after set_random_seed(0); keep it stable.
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    # Independent batch-of-one copies of every sequence, stored in the
    # positional order expected by the sampler call below.
    single_batches = [(
        target_probs[i].clone().unsqueeze(0),
        bonus_token_ids[i].clone().unsqueeze(0),
        draft_probs[i].clone().unsqueeze(0),
        draft_token_ids[i].clone().unsqueeze(0),
    ) for i in range(batch_size)]

    # --- Batched run: sequence 0 unseeded, all others seeded. ---
    set_random_seed(0)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    seeded_seqs = {
        i: torch.Generator(device=device).manual_seed(i)
        for i in range(1, batch_size)  # 0 is seed None
    }
    batch_result = rejection_sampler(target_probs.clone(),
                                     bonus_token_ids.clone(),
                                     draft_probs.clone(),
                                     draft_token_ids.clone(), seeded_seqs)

    # --- Per-sequence runs with a fresh sampler. ---
    # Reset the global RNG so this pass starts from the same state as the
    # batched pass did.
    set_random_seed(0)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    results = []
    for i, (single_target_probs, single_bonus_token_ids, single_draft_probs,
            single_draft_token_ids) in enumerate(single_batches):
        # A batch of one always addresses its only sequence as index 0.
        # The generator is re-created (not reused) because the batched run
        # above already advanced the ones stored in `seeded_seqs`.
        request_seeded_seqs = ({
            0: torch.Generator(device=device).manual_seed(i)
        } if i in seeded_seqs else None)
        results.append(
            rejection_sampler(single_target_probs, single_bonus_token_ids,
                              single_draft_probs, single_draft_token_ids,
                              request_seeded_seqs))

    for i in range(batch_size):
        assert torch.equal(batch_result[i], results[i].squeeze(0))
@
pytest
.
mark
.
parametrize
(
"k"
,
[
1
,
3
,
6
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
30_000
,
50_000
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
8
,
32
,
128
])
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
86c2d8fd
from
functools
import
cached_property
from
importlib.util
import
find_spec
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
Optional
,
Tuple
import
torch
import
torch.jit
...
...
@@ -386,16 +386,12 @@ def _multinomial(
if
not
seeded_seqs
:
q
.
exponential_
(
1.0
)
else
:
non_seeded_indices
:
List
[
int
]
=
[]
start
=
0
for
idx
in
range
(
len
(
q
)
//
k
):
end
=
start
+
k
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
extend
(
list
(
range
(
start
,
end
)))
else
:
# Note: generator might be None for non seeded
q
[
start
:
end
].
exponential_
(
1.0
,
generator
=
generator
)
start
=
end
q
[
non_seeded_indices
].
exponential_
(
1.0
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment