"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d2df40c6f39b7ffde309961c67bdf2c4e2913b8c"
Unverified commit 02c777c0, authored by Aryan, committed by GitHub

[tests] Refactor TorchAO serialization fast tests (#10271)

Refactor: replace the per-configuration TorchAoSerialization*Test subclasses with explicit test_int_* methods on TorchAoSerializationTest, passing quant_method, quant_method_kwargs, and device to shared underscore-prefixed helpers.
parent 6a970a45
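For context before the diff below: unittest collects every test_* method on every TestCase subclass, including a shared base class, which is why the old layout needed the "not to be run as a test by itself" comment. Moving the configuration into arguments of underscore-prefixed helpers (which unittest does not collect) and defining one explicit test method per configuration removes that hazard. A minimal, hypothetical sketch of the pitfall, with names invented for illustration (this is not the diffusers test code):

import unittest

class Base(unittest.TestCase):
    config = None  # placeholder; subclasses were expected to override this

    def test_runs_everywhere(self):
        # unittest also collects this method on Base itself, where config
        # is still None -- the hazard the refactor removes.
        self.assertIsNotNone(self.config)

class IntA8W8(Base):
    config = "int8_dynamic_activation_int8_weight"

# Running this module executes test_runs_everywhere twice: once on IntA8W8
# (passes) and once on Base (fails), mirroring the old layout's pitfall.
if __name__ == "__main__":
    unittest.main()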
@@ -447,21 +447,19 @@ class TorchAoTest(unittest.TestCase):
             self.get_dummy_components(TorchAoConfig("int42"))
 
 
-# This class is not to be run as a test by itself. See the tests that follow this class
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
-    quant_method, quant_method_kwargs = None, None
-    device = "cuda"
 
     def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_dummy_model(self, device=None):
-        quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
+    def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
+        quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
         quantized_model = FluxTransformer2DModel.from_pretrained(
             self.model_name,
             subfolder="transformer",
@@ -497,15 +495,15 @@ class TorchAoSerializationTest(unittest.TestCase):
             "timestep": timestep,
         }
 
-    def test_original_model_expected_slice(self):
-        quantized_model = self.get_dummy_model(torch_device)
+    def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
+        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
         inputs = self.get_dummy_tensor_inputs(torch_device)
+
         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-        self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3))
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
-    def check_serialization_expected_slice(self, expected_slice):
-        quantized_model = self.get_dummy_model(self.device)
+    def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
+        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
         with tempfile.TemporaryDirectory() as tmp_dir:
             quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
@@ -524,36 +522,33 @@ class TorchAoSerializationTest(unittest.TestCase):
         )
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
-    def test_serialization_expected_slice(self):
-        self.check_serialization_expected_slice(self.serialized_expected_slice)
-
-
-class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cuda"
-
-
-class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cuda"
-
-
-class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cpu"
-
-
-class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cpu"
+    def test_int_a8w8_cuda(self):
+        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a16w8_cuda(self):
+        quant_method, quant_method_kwargs = "int8_weight_only", {}
+        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a8w8_cpu(self):
+        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cpu"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a16w8_cpu(self):
+        quant_method, quant_method_kwargs = "int8_weight_only", {}
+        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+        device = "cpu"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
...
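Taken together, the refactored helpers exercise the round trip sketched below. This is an illustration only: the model id, subfolder, quantization method, and safe_serialization=False come straight from the visible hunks, while torch_dtype, the device placement, and the reload arguments (handled in an elided hunk of the actual test) are assumptions.

import tempfile

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Quantize the tiny Flux transformer at load time, as get_dummy_model does.
quantization_config = TorchAoConfig("int8_weight_only")
model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,  # assumption: dtype is not shown in the visible hunks
).to("cuda")

with tempfile.TemporaryDirectory() as tmp_dir:
    # TorchAO tensor subclasses cannot be stored as safetensors, hence
    # safe_serialization=False, matching the test above.
    model.save_pretrained(tmp_dir, safe_serialization=False)
    # Reload from the serialized checkpoint; the exact arguments here are
    # an assumption based on the visible save_pretrained call.
    reloaded = FluxTransformer2DModel.from_pretrained(
        tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
    ).to("cuda")
    # The test then runs both models on dummy inputs and compares the last
    # nine output values against expected_slice with np.allclose.

With the refactor, running something like pytest tests/quantization/torchao/ -k "int_a8w8 or int_a16w8" (path assumed from the diffusers repository layout) selects the four new test methods directly instead of relying on subclass collection.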