chenpangpang / transformers

Commit 9c772ac8 (Unverified)
Authored May 06, 2024 by Younes Belkada; committed by GitHub on May 06, 2024
Parent: a45c5148

Quantization / HQQ: Fix HQQ tests on our runner (#30668)

Update test_hqq.py
Showing 1 changed file with 3 additions and 3 deletions:

tests/quantization/hqq/test_hqq.py (+3, -3)
@@ -35,7 +35,7 @@ if is_hqq_available():
 
 
 class HQQLLMRunner:
-    def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir):
+    def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir=None):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=compute_dtype,
@@ -118,7 +118,7 @@ class HQQTest(unittest.TestCase):
         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
         check_forward(self, hqq_runner.model)
 
-    def test_bfp16_quantized_model_with_offloading(self):
+    def test_f16_quantized_model_with_offloading(self):
         """
         Simple LLM model testing bfp16 with meta-data offloading
         """
@@ -137,7 +137,7 @@ class HQQTest(unittest.TestCase):
         )
 
         hqq_runner = HQQLLMRunner(
-            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
         )
 
         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
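In short, the commit makes two adjustments so the suite passes on the project's CI runner: cache_dir becomes an optional argument of the HQQLLMRunner test helper (defaulting to None), and the offloading test quantizes with torch.float16 instead of torch.bfloat16, presumably because the runner's GPU lacks bfloat16 support. Below is a minimal sketch, not the commit's code, of what the patched helper does under the hood, assuming transformers' HqqConfig API; the stand-in model id and quantization settings are illustrative assumptions, and the suite's actual MODEL_ID and quant_config may differ.

# Minimal sketch, not the commit's code: load an HQQ-quantized causal LM
# with float16 as the compute dtype, mirroring what HQQLLMRunner does in
# tests/quantization/hqq/test_hqq.py after this change.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # stand-in; the suite's MODEL_ID may differ
quant_config = HqqConfig(nbits=4, group_size=64)  # assumed 4-bit settings

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # was torch.bfloat16 in the offloading test before this commit
    device_map="cuda",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)  # cache_dir can now simply be omitted

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))

With the default in place, call sites such as HQQLLMRunner(model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device) no longer have to pass cache_dir, and the renamed test can be selected with pytest tests/quantization/hqq/test_hqq.py -k test_f16_quantized_model_with_offloading (prefixed with RUN_SLOW=1 if the suite is gated as a slow test).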