Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a74d1941
Unverified
Commit
a74d1941
authored
Dec 26, 2024
by
Zhizhou Sha
Committed by
GitHub
Dec 26, 2024
Browse files
[unittest] add unit test to test quant args of srt engine (#2574)
parent
3169e66c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
60 additions
and
0 deletions
+60
-0
test/srt/test_srt_engine_with_quant_args.py
test/srt/test_srt_engine_with_quant_args.py
+60
-0
No files found.
test/srt/test_srt_engine_with_quant_args.py
0 → 100644
View file @
a74d1941
import
unittest
import
sglang
as
sgl
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class
TestSRTEngineWithQuantArgs
(
unittest
.
TestCase
):
def
test_1_quantization_args
(
self
):
# we only test fp8 because other methods are currenly depend on vllm. We can add other methods back to test after vllm depency is resolved.
quantization_args_list
=
[
# "awq",
"fp8"
,
# "gptq",
# "marlin",
# "gptq_marlin",
# "awq_marlin",
# "bitsandbytes",
# "gguf",
]
prompt
=
"Today is a sunny day and I like"
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
sampling_params
=
{
"temperature"
:
0
,
"max_new_tokens"
:
8
}
for
quantization_args
in
quantization_args_list
:
engine
=
sgl
.
Engine
(
model_path
=
model_path
,
random_seed
=
42
,
quantization
=
quantization_args
)
engine
.
generate
(
prompt
,
sampling_params
)
engine
.
shutdown
()
def
test_2_torchao_args
(
self
):
# we don't test int8dq because currently there is conflict between int8dq and capture cuda graph
torchao_args_list
=
[
# "int8dq",
"int8wo"
,
"fp8wo"
,
"fp8dq-per_tensor"
,
"fp8dq-per_row"
,
]
+
[
f
"int4wo-
{
group_size
}
"
for
group_size
in
[
32
,
64
,
128
,
256
]]
prompt
=
"Today is a sunny day and I like"
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
sampling_params
=
{
"temperature"
:
0
,
"max_new_tokens"
:
8
}
for
torchao_config
in
torchao_args_list
:
engine
=
sgl
.
Engine
(
model_path
=
model_path
,
random_seed
=
42
,
torchao_config
=
torchao_config
)
engine
.
generate
(
prompt
,
sampling_params
)
engine
.
shutdown
()
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment