Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8cef6e02
Unverified
Commit
8cef6e02
authored
Dec 23, 2024
by
Dipika Sikka
Committed by
GitHub
Dec 23, 2024
Browse files
[Misc] add w8a8 asym models (#11075)
parent
b866cdbd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
6 deletions
+10
-6
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+10
-6
No files found.
tests/quantization/test_compressed_tensors.py
View file @
8cef6e02
...
@@ -79,12 +79,12 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
...
@@ -79,12 +79,12 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
assert
output
assert
output
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model_path"
,
[
"
model_path
"
,
"
neuralmagic/Llama-3.2-1B-quantized.w8a8
"
,
[
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym"
,
"n
euralmagic/Llama-3.2-1B-quantized.w8a8"
"n
m-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym"
,
# TODO static & asymmetric
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
])
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
def
test_compressed_tensors_w8a8_logprobs
(
hf_runner
,
vllm_runner
,
def
test_compressed_tensors_w8a8_logprobs
(
hf_runner
,
vllm_runner
,
...
@@ -92,6 +92,10 @@ def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
...
@@ -92,6 +92,10 @@ def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
max_tokens
,
num_logprobs
):
max_tokens
,
num_logprobs
):
dtype
=
"bfloat16"
dtype
=
"bfloat16"
# skip language translation prompt for the static per tensor asym model
if
model_path
==
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
:
# noqa: E501
example_prompts
=
example_prompts
[
0
:
-
1
]
with
hf_runner
(
model_path
,
dtype
=
dtype
)
as
hf_model
:
with
hf_runner
(
model_path
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment