Unverified Commit cdaf84a7 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

TorchAO compile + offloading tests (#11697)

* update

* update

* update

* update

* update

* user property instead
parent e8e44a51
...@@ -866,15 +866,17 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests): ...@@ -866,15 +866,17 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
@require_torch_version_greater("2.7.1") @require_torch_version_greater("2.7.1")
class Bnb4BitCompileTests(QuantCompileTests): class Bnb4BitCompileTests(QuantCompileTests):
quantization_config = PipelineQuantizationConfig( @property
quant_backend="bitsandbytes_8bit", def quantization_config(self):
quant_kwargs={ return PipelineQuantizationConfig(
"load_in_4bit": True, quant_backend="bitsandbytes_8bit",
"bnb_4bit_quant_type": "nf4", quant_kwargs={
"bnb_4bit_compute_dtype": torch.bfloat16, "load_in_4bit": True,
}, "bnb_4bit_quant_type": "nf4",
components_to_quantize=["transformer", "text_encoder_2"], "bnb_4bit_compute_dtype": torch.bfloat16,
) },
components_to_quantize=["transformer", "text_encoder_2"],
)
def test_torch_compile(self): def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._dynamo.config.capture_dynamic_output_shape_ops = True
...@@ -883,5 +885,7 @@ class Bnb4BitCompileTests(QuantCompileTests): ...@@ -883,5 +885,7 @@ class Bnb4BitCompileTests(QuantCompileTests):
def test_torch_compile_with_cpu_offload(self): def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config) super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
def test_torch_compile_with_group_offload(self): def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config) super()._test_torch_compile_with_group_offload_leaf(
quantization_config=self.quantization_config, use_stream=True
)
...@@ -831,11 +831,13 @@ class BaseBnb8bitSerializationTests(Base8bitTests): ...@@ -831,11 +831,13 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
@require_torch_version_greater_equal("2.6.0") @require_torch_version_greater_equal("2.6.0")
class Bnb8BitCompileTests(QuantCompileTests): class Bnb8BitCompileTests(QuantCompileTests):
quantization_config = PipelineQuantizationConfig( @property
quant_backend="bitsandbytes_8bit", def quantization_config(self):
quant_kwargs={"load_in_8bit": True}, return PipelineQuantizationConfig(
components_to_quantize=["transformer", "text_encoder_2"], quant_backend="bitsandbytes_8bit",
) quant_kwargs={"load_in_8bit": True},
components_to_quantize=["transformer", "text_encoder_2"],
)
def test_torch_compile(self): def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._dynamo.config.capture_dynamic_output_shape_ops = True
...@@ -847,7 +849,7 @@ class Bnb8BitCompileTests(QuantCompileTests): ...@@ -847,7 +849,7 @@ class Bnb8BitCompileTests(QuantCompileTests):
) )
@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.") @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
def test_torch_compile_with_group_offload(self): def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload( super()._test_torch_compile_with_group_offload_leaf(
quantization_config=self.quantization_config, torch_dtype=torch.float16 quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
) )
...@@ -24,7 +24,11 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu ...@@ -24,7 +24,11 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu
@require_torch_gpu @require_torch_gpu
@slow @slow
class QuantCompileTests(unittest.TestCase): class QuantCompileTests(unittest.TestCase):
quantization_config = None @property
def quantization_config(self):
raise NotImplementedError(
"This property should be implemented in the subclass to return the appropriate quantization config."
)
def setUp(self): def setUp(self):
super().setUp() super().setUp()
...@@ -64,7 +68,9 @@ class QuantCompileTests(unittest.TestCase): ...@@ -64,7 +68,9 @@ class QuantCompileTests(unittest.TestCase):
# small resolutions to ensure speedy execution. # small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16): def _test_torch_compile_with_group_offload_leaf(
self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
):
torch._dynamo.config.cache_size_limit = 10000 torch._dynamo.config.cache_size_limit = 10000
pipe = self._init_pipeline(quantization_config, torch_dtype) pipe = self._init_pipeline(quantization_config, torch_dtype)
...@@ -72,8 +78,7 @@ class QuantCompileTests(unittest.TestCase): ...@@ -72,8 +78,7 @@ class QuantCompileTests(unittest.TestCase):
"onload_device": torch.device("cuda"), "onload_device": torch.device("cuda"),
"offload_device": torch.device("cpu"), "offload_device": torch.device("cpu"),
"offload_type": "leaf_level", "offload_type": "leaf_level",
"use_stream": True, "use_stream": use_stream,
"non_blocking": True,
} }
pipe.transformer.enable_group_offload(**group_offload_kwargs) pipe.transformer.enable_group_offload(**group_offload_kwargs)
pipe.transformer.compile() pipe.transformer.compile()
......
...@@ -19,6 +19,7 @@ import unittest ...@@ -19,6 +19,7 @@ import unittest
from typing import List from typing import List
import numpy as np import numpy as np
from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import ( from diffusers import (
...@@ -29,6 +30,7 @@ from diffusers import ( ...@@ -29,6 +30,7 @@ from diffusers import (
TorchAoConfig, TorchAoConfig,
) )
from diffusers.models.attention_processor import Attention from diffusers.models.attention_processor import Attention
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
backend_empty_cache, backend_empty_cache,
backend_synchronize, backend_synchronize,
...@@ -44,6 +46,8 @@ from diffusers.utils.testing_utils import ( ...@@ -44,6 +46,8 @@ from diffusers.utils.testing_utils import (
torch_device, torch_device,
) )
from ..test_torch_compile_utils import QuantCompileTests
enable_full_determinism() enable_full_determinism()
...@@ -625,6 +629,53 @@ class TorchAoSerializationTest(unittest.TestCase): ...@@ -625,6 +629,53 @@ class TorchAoSerializationTest(unittest.TestCase):
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
},
)
def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)
@unittest.skip(
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
"when compiling."
)
def test_torch_compile_with_cpu_offload(self):
# RuntimeError: _apply(): Couldn't swap Linear.weight
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
@unittest.skip(
"""
For `use_stream=False`:
- Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
For `use_stream=True`:
Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
"""
)
@parameterized.expand([False, True])
def test_torch_compile_with_group_offload_leaf(self):
# For use_stream=False:
# If we run group offloading without compilation, we will see:
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
# When running with compilation, the error ends up being different:
# Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
# requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
# Looks like something that will have to be looked into upstream.
# for linear layers, weight.tensor_impl shows cuda... but:
# weight.tensor_impl.{data,scale,zero_point}.device will be cpu
# For use_stream=True:
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch @require_torch
@require_torch_accelerator @require_torch_accelerator
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment