renzhc/diffusers_dcu, commit 02c777c0 (unverified). Authored Dec 23, 2024 by Aryan; committed by GitHub, Dec 23, 2024.
[tests] Refactor TorchAO serialization fast tests (#10271)
refactor
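This refactor removes the four per-configuration subclasses of TorchAoSerializationTest (the INTA8W8/INTA16W8 CUDA and CPU variants) in favor of four explicit test methods on the base class, and converts the slice-checking helpers into private parameterized methods that receive the quantization method, its kwargs, and the expected slice as arguments instead of reading them from class attributes.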
parent 6a970a45
Showing 1 changed file with 35 additions and 40 deletions.
tests/quantization/torchao/test_torchao.py (+35, -40), view file @ 02c777c0
@@ -447,21 +447,19 @@ class TorchAoTest(unittest.TestCase):
             self.get_dummy_components(TorchAoConfig("int42"))


-# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
+# This class is not to be run as a test by itself. See the tests that follow this class
 @require_torch
 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
-    quant_method, quant_method_kwargs = None, None
-    device = "cuda"

     def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()

-    def get_dummy_model(self, device=None):
-        quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
+    def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
+        quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
         quantized_model = FluxTransformer2DModel.from_pretrained(
             self.model_name,
             subfolder="transformer",
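The core change in this hunk: get_dummy_model no longer pulls quant_method, quant_method_kwargs, and device from class attributes, so a single class can exercise several quantization configurations. A minimal, self-contained sketch of the same pattern; only get_dummy_model's new signature and the two int8 method names come from the diff, and the stand-in model dict is illustrative:

    import unittest

    class ParameterizedHelperPattern(unittest.TestCase):
        # Before the refactor, each configuration needed its own subclass
        # carrying quant_method/quant_method_kwargs/device as class attributes.
        # After it, the helper takes the configuration directly:
        def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
            # Stand-in for TorchAoConfig(...) followed by
            # FluxTransformer2DModel.from_pretrained(...) in the real test.
            return {"method": quant_method, **quant_method_kwargs, "device": device}

        def test_two_configs_in_one_class(self):
            a8w8 = self.get_dummy_model("int8_dynamic_activation_int8_weight", {}, "cpu")
            a16w8 = self.get_dummy_model("int8_weight_only", {}, "cpu")
            self.assertNotEqual(a8w8["method"], a16w8["method"])

    if __name__ == "__main__":
        unittest.main()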
@@ -497,15 +495,15 @@ class TorchAoSerializationTest(unittest.TestCase):
             "timestep": timestep,
         }

-    def test_original_model_expected_slice(self):
-        quantized_model = self.get_dummy_model(torch_device)
+    def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
+        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
         inputs = self.get_dummy_tensor_inputs(torch_device)
         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-        self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3))
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

-    def check_serialization_expected_slice(self, expected_slice):
-        quantized_model = self.get_dummy_model(self.device)
+    def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
+        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
         with tempfile.TemporaryDirectory() as tmp_dir:
             quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
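Both helpers end by comparing the last nine flattened output values against a reference slice, and the serialization helper saves with safe_serialization=False (pickle-based saving rather than safetensors, presumably because torchao's quantized weights are not plain tensors). A sketch of just the numeric check, with invented values; the real expected slices are the np.array literals recorded in the test file:

    import numpy as np

    # Invented stand-in for the transformer's forward output on dummy inputs.
    rng = np.random.default_rng(0)
    output = rng.standard_normal((2, 32)).astype(np.float32)

    # The tests compare only the last nine flattened values...
    output_slice = output.flatten()[-9:]
    expected_slice = output_slice + 5e-4  # deviation well within tolerance

    # ...using the same tolerances the diff keeps: atol=1e-3, rtol=1e-3.
    assert np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)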
@@ -524,36 +522,33 @@ class TorchAoSerializationTest(unittest.TestCase):
         )
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

-    def test_serialization_expected_slice(self):
-        self.check_serialization_expected_slice(self.serialized_expected_slice)
-
-
-class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cuda"
-
-
-class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cuda"
-
-
-class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cpu"
-
-
-class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cpu"
+    def test_int_a8w8_cuda(self):
+        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a16w8_cuda(self):
+        quant_method, quant_method_kwargs = "int8_weight_only", {}
+        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a8w8_cpu(self):
+        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cpu"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a16w8_cpu(self):
+        quant_method, quant_method_kwargs = "int8_weight_only", {}
+        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+        device = "cpu"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
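The fast-test matrix is now explicit in the method names, so a single configuration can be selected with, for example, pytest tests/quantization/torchao/test_torchao.py -k test_int_a8w8_cpu. A small sketch that lists the refactored test surface, assuming it is run from the repository root with the tests package importable:

    import unittest

    # Standard unittest discovery by dotted name; the path is an assumption
    # about how the diffusers checkout is laid out, not part of this diff.
    suite = unittest.TestLoader().loadTestsFromName(
        "tests.quantization.torchao.test_torchao.TorchAoSerializationTest"
    )
    for case in suite:
        print(case.id())  # one id per configuration, e.g. ...test_int_a8w8_cuda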