Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db5a29ba
Unverified
Commit
db5a29ba
authored
May 22, 2025
by
Jee Jee Li
Committed by
GitHub
May 21, 2025
Browse files
[Bugfix] Fix LoRA test (#18518)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
51797775
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
65 deletions
+73
-65
tests/lora/test_lora_functions.py
tests/lora/test_lora_functions.py
+1
-1
tests/v1/sample/test_topk_topp_sampler.py
tests/v1/sample/test_topk_topp_sampler.py
+72
-64
No files found.
tests/lora/test_lora_functions.py
View file @
db5a29ba
...
@@ -69,7 +69,7 @@ def test_lora_functions_sync():
...
@@ -69,7 +69,7 @@ def test_lora_functions_sync():
run_check
(
llm
.
add_lora
,
make_lora_request
(
12
),
[
12
,
9
,
10
,
11
])
run_check
(
llm
.
add_lora
,
make_lora_request
(
12
),
[
12
,
9
,
10
,
11
])
run_check
(
llm
.
add_lora
,
make_lora_request
(
13
),
[
12
,
13
,
10
,
11
])
run_check
(
llm
.
add_lora
,
make_lora_request
(
13
),
[
12
,
13
,
10
,
11
])
# Remove all LoRAs
# Remove all LoRAs
.
run_check
(
llm
.
remove_lora
,
13
,
[
12
,
10
,
11
])
run_check
(
llm
.
remove_lora
,
13
,
[
12
,
10
,
11
])
run_check
(
llm
.
remove_lora
,
12
,
[
10
,
11
])
run_check
(
llm
.
remove_lora
,
12
,
[
10
,
11
])
run_check
(
llm
.
remove_lora
,
11
,
[
10
])
run_check
(
llm
.
remove_lora
,
11
,
[
10
])
...
...
tests/v1/sample/test_topk_topp_sampler.py
View file @
db5a29ba
...
@@ -16,9 +16,20 @@ VOCAB_SIZE = 128 * 1024
...
@@ -16,9 +16,20 @@ VOCAB_SIZE = 128 * 1024
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
@
pytest
.
fixture
(
autouse
=
True
)
def
reset_default_device
():
"""
Explicitly set the default device, which can affect subsequent tests.
Adding this fixture helps avoid this problem.
"""
original_device
=
torch
.
get_default_device
()
yield
torch
.
set_default_device
(
original_device
)
def
test_topk_impl_equivalance
():
def
test_topk_impl_equivalance
():
with
torch
.
device
(
DEVICE
)
:
torch
.
set_default_
device
(
DEVICE
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
33
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
33
)
logits
=
torch
.
rand
((
BATCH_SIZE
,
VOCAB_SIZE
),
generator
=
generator
)
logits
=
torch
.
rand
((
BATCH_SIZE
,
VOCAB_SIZE
),
generator
=
generator
)
...
@@ -28,10 +39,8 @@ def test_topk_impl_equivalance():
...
@@ -28,10 +39,8 @@ def test_topk_impl_equivalance():
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k
.
masked_fill_
(
k
.
masked_fill_
(
torch
.
randint
(
0
,
torch
.
randint
(
0
,
2
,
(
BATCH_SIZE
,
),
generator
=
generator
,
dtype
=
bool
),
2
,
(
BATCH_SIZE
,
),
VOCAB_SIZE
)
generator
=
generator
,
dtype
=
bool
),
VOCAB_SIZE
)
# Top-k only implementation
# Top-k only implementation
result1
=
apply_top_k_top_p
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
None
)
result1
=
apply_top_k_top_p
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
None
)
...
@@ -58,7 +67,7 @@ def test_flashinfer_sampler():
...
@@ -58,7 +67,7 @@ def test_flashinfer_sampler():
pytest
.
skip
(
pytest
.
skip
(
"FlashInfer not installed or not available on this platform."
)
"FlashInfer not installed or not available on this platform."
)
with
torch
.
device
(
DEVICE
)
:
torch
.
set_default_
device
(
DEVICE
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
42
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
42
)
# Generate random logits
# Generate random logits
...
@@ -67,8 +76,7 @@ def test_flashinfer_sampler():
...
@@ -67,8 +76,7 @@ def test_flashinfer_sampler():
# Generate various top-k and top-p values
# Generate various top-k and top-p values
k_values
=
torch
.
randint
(
1
,
1000
,
(
BATCH_SIZE
,
),
generator
=
generator
)
k_values
=
torch
.
randint
(
1
,
1000
,
(
BATCH_SIZE
,
),
generator
=
generator
)
p_values
=
torch
.
rand
(
p_values
=
torch
.
rand
(
(
BATCH_SIZE
,
),
(
BATCH_SIZE
,
),
generator
=
generator
)
*
0.5
+
0.5
# range in [0.5, 1.0]
generator
=
generator
)
*
0.5
+
0.5
# range in [0.5, 1.0]
# Sometimes disable top-k (k=vocab_size)
# Sometimes disable top-k (k=vocab_size)
k_values
.
masked_fill_
(
k_values
.
masked_fill_
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment