Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db5a29ba
Unverified
Commit
db5a29ba
authored
May 22, 2025
by
Jee Jee Li
Committed by
GitHub
May 21, 2025
Browse files
[Bugfix] Fix LoRA test (#18518)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
51797775
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
65 deletions
+73
-65
tests/lora/test_lora_functions.py
tests/lora/test_lora_functions.py
+1
-1
tests/v1/sample/test_topk_topp_sampler.py
tests/v1/sample/test_topk_topp_sampler.py
+72
-64
No files found.
tests/lora/test_lora_functions.py
View file @
db5a29ba
...
@@ -69,7 +69,7 @@ def test_lora_functions_sync():
...
@@ -69,7 +69,7 @@ def test_lora_functions_sync():
run_check
(
llm
.
add_lora
,
make_lora_request
(
12
),
[
12
,
9
,
10
,
11
])
run_check
(
llm
.
add_lora
,
make_lora_request
(
12
),
[
12
,
9
,
10
,
11
])
run_check
(
llm
.
add_lora
,
make_lora_request
(
13
),
[
12
,
13
,
10
,
11
])
run_check
(
llm
.
add_lora
,
make_lora_request
(
13
),
[
12
,
13
,
10
,
11
])
# Remove all LoRAs
# Remove all LoRAs
.
run_check
(
llm
.
remove_lora
,
13
,
[
12
,
10
,
11
])
run_check
(
llm
.
remove_lora
,
13
,
[
12
,
10
,
11
])
run_check
(
llm
.
remove_lora
,
12
,
[
10
,
11
])
run_check
(
llm
.
remove_lora
,
12
,
[
10
,
11
])
run_check
(
llm
.
remove_lora
,
11
,
[
10
])
run_check
(
llm
.
remove_lora
,
11
,
[
10
])
...
...
tests/v1/sample/test_topk_topp_sampler.py
View file @
db5a29ba
...
@@ -16,31 +16,40 @@ VOCAB_SIZE = 128 * 1024
...
@@ -16,31 +16,40 @@ VOCAB_SIZE = 128 * 1024
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
@
pytest
.
fixture
(
autouse
=
True
)
def
reset_default_device
():
"""
Explicitly set the default device, which can affect subsequent tests.
Adding this fixture helps avoid this problem.
"""
original_device
=
torch
.
get_default_device
()
yield
torch
.
set_default_device
(
original_device
)
def
test_topk_impl_equivalance
():
def
test_topk_impl_equivalance
():
with
torch
.
device
(
DEVICE
)
:
torch
.
set_default_
device
(
DEVICE
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
33
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
33
)
logits
=
torch
.
rand
((
BATCH_SIZE
,
VOCAB_SIZE
),
generator
=
generator
)
logits
=
torch
.
rand
((
BATCH_SIZE
,
VOCAB_SIZE
),
generator
=
generator
)
# Random top-k values between 1 and 9.
# Random top-k values between 1 and 9.
k
=
torch
.
randint
(
1
,
10
,
(
BATCH_SIZE
,
),
generator
=
generator
)
k
=
torch
.
randint
(
1
,
10
,
(
BATCH_SIZE
,
),
generator
=
generator
)
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k
.
masked_fill_
(
k
.
masked_fill_
(
torch
.
randint
(
0
,
torch
.
randint
(
0
,
2
,
(
BATCH_SIZE
,
),
generator
=
generator
,
dtype
=
bool
),
2
,
(
BATCH_SIZE
,
),
VOCAB_SIZE
)
generator
=
generator
,
dtype
=
bool
),
VOCAB_SIZE
)
# Top-k only implementation
# Top-k only implementation
result1
=
apply_top_k_top_p
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
None
)
result1
=
apply_top_k_top_p
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
None
)
# Top-p + top-k
# Top-p + top-k
no_op_top_p
=
torch
.
tensor
([
1.0
])
no_op_top_p
=
torch
.
tensor
([
1.0
])
result2
=
apply_top_k_top_p
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
no_op_top_p
)
result2
=
apply_top_k_top_p
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
no_op_top_p
)
assert
torch
.
allclose
(
result1
,
result2
)
assert
torch
.
allclose
(
result1
,
result2
)
def
test_flashinfer_sampler
():
def
test_flashinfer_sampler
():
...
@@ -58,50 +67,49 @@ def test_flashinfer_sampler():
...
@@ -58,50 +67,49 @@ def test_flashinfer_sampler():
pytest
.
skip
(
pytest
.
skip
(
"FlashInfer not installed or not available on this platform."
)
"FlashInfer not installed or not available on this platform."
)
with
torch
.
device
(
DEVICE
):
torch
.
set_default_device
(
DEVICE
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
42
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
42
)
# Generate random logits
# Generate random logits
logits
=
torch
.
rand
((
BATCH_SIZE
,
VOCAB_SIZE
),
generator
=
generator
)
logits
=
torch
.
rand
((
BATCH_SIZE
,
VOCAB_SIZE
),
generator
=
generator
)
# Generate various top-k and top-p values
# Generate various top-k and top-p values
k_values
=
torch
.
randint
(
1
,
1000
,
(
BATCH_SIZE
,
),
generator
=
generator
)
k_values
=
torch
.
randint
(
1
,
1000
,
(
BATCH_SIZE
,
),
generator
=
generator
)
p_values
=
torch
.
rand
(
p_values
=
torch
.
rand
(
(
BATCH_SIZE
,
),
(
BATCH_SIZE
,
),
generator
=
generator
)
*
0.5
+
0.5
# range in [0.5, 1.0]
generator
=
generator
)
*
0.5
+
0.5
# range in [0.5, 1.0]
# Sometimes disable top-k (k=vocab_size)
# Sometimes disable top-k (k=vocab_size)
k_values
.
masked_fill_
(
k_values
.
masked_fill_
(
torch
.
randint
(
0
,
torch
.
randint
(
0
,
2
,
(
BATCH_SIZE
,
),
2
,
(
BATCH_SIZE
,
),
generator
=
generator
,
generator
=
generator
,
dtype
=
torch
.
bool
),
VOCAB_SIZE
)
dtype
=
torch
.
bool
),
VOCAB_SIZE
)
# Sometimes disable top-p (p=1.0)
# Sometimes disable top-p (p=1.0)
p_values
.
masked_fill_
(
p_values
.
masked_fill_
(
torch
.
randint
(
0
,
torch
.
randint
(
0
,
2
,
(
BATCH_SIZE
,
),
2
,
(
BATCH_SIZE
,
),
generator
=
generator
,
generator
=
generator
,
dtype
=
torch
.
bool
),
1.0
)
dtype
=
torch
.
bool
),
1.0
)
python_logits
=
apply_top_k_top_p
(
python_logits
=
apply_top_k_top_p
(
logits
=
logits
.
clone
(),
logits
=
logits
.
clone
(),
k
=
k_values
,
k
=
k_values
,
p
=
p_values
,
p
=
p_values
,
)
)
python_probs
=
torch
.
softmax
(
python_logits
,
dim
=-
1
)
python_probs
=
torch
.
softmax
(
python_logits
,
dim
=-
1
)
# FlashInfer only exposed renorm interfaces for probs so convert first
# FlashInfer only exposed renorm interfaces for probs so convert first
flashinfer_probs
=
torch
.
softmax
(
logits
.
clone
(),
dim
=-
1
)
flashinfer_probs
=
torch
.
softmax
(
logits
.
clone
(),
dim
=-
1
)
flashinfer_probs
=
top_k_renorm_probs
(
flashinfer_probs
=
top_k_renorm_probs
(
probs
=
flashinfer_probs
,
probs
=
flashinfer_probs
,
top_k
=
k_values
,
top_k
=
k_values
,
)
)
flashinfer_probs
=
top_p_renorm_probs
(
flashinfer_probs
=
top_p_renorm_probs
(
probs
=
flashinfer_probs
,
probs
=
flashinfer_probs
,
top_p
=
p_values
,
top_p
=
p_values
,
)
)
# Compare the results
# Compare the results
assert
torch
.
allclose
(
python_probs
,
flashinfer_probs
,
atol
=
2e-2
),
\
assert
torch
.
allclose
(
python_probs
,
flashinfer_probs
,
atol
=
2e-2
),
\
"FlashInfer and Python sampling implementations do not match!"
"FlashInfer and Python sampling implementations do not match!"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment