"...text-generation-inference.git" did not exist on "686cc6671705c666b767fffe71b2ed9c9b6fccd1"
Unverified commit 7a935a0b, authored by Sayak Paul, committed by GitHub

[tests] Unify compilation + offloading tests in quantization (#11910)

* unify the quant compile + offloading tests.

* fix

* update
parent 941b7fc0
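The diff below applies one pattern across every quantization backend: `QuantCompileTests` stops subclassing `unittest.TestCase` and becomes a plain mixin whose helpers read `self.quantization_config` instead of taking it as an argument, while each backend-specific test class mixes `unittest.TestCase` back in and only overrides what actually differs. A minimal sketch of that structure (simplified; `ExampleBackendCompileTests` is a made-up subclass, not from the repository):

import unittest

import torch


class QuantCompileTests:  # mixin only; unittest does not collect it on its own
    @property
    def quantization_config(self):
        # every concrete subclass must supply its backend-specific config
        raise NotImplementedError

    def _test_torch_compile(self, torch_dtype=torch.bfloat16):
        # shared body: build a pipeline from self.quantization_config,
        # compile the transformer, and run a tiny inference
        ...

    def test_torch_compile(self):
        self._test_torch_compile()


class ExampleBackendCompileTests(QuantCompileTests, unittest.TestCase):
    @property
    def quantization_config(self):
        # would return the backend's real PipelineQuantizationConfig
        return ...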
@@ -873,11 +873,11 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
 @require_torch_version_greater("2.7.1")
 @require_bitsandbytes_version_greater("0.45.5")
-class Bnb4BitCompileTests(QuantCompileTests):
+class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
-            quant_backend="bitsandbytes_8bit",
+            quant_backend="bitsandbytes_4bit",
             quant_kwargs={
                 "load_in_4bit": True,
                 "bnb_4bit_quant_type": "nf4",
@@ -888,12 +888,7 @@ class Bnb4BitCompileTests(QuantCompileTests):
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile()
 
     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=True)
@@ -838,7 +838,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
 @require_torch_version_greater_equal("2.6.0")
 @require_bitsandbytes_version_greater("0.45.5")
-class Bnb8BitCompileTests(QuantCompileTests):
+class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
@@ -849,15 +849,11 @@ class Bnb8BitCompileTests(QuantCompileTests):
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
+        super()._test_torch_compile(torch_dtype=torch.float16)
 
     def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
-        )
+        super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
 
     @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
@@ -654,7 +654,7 @@ class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
 @require_torch_version_greater("2.7.1")
-class GGUFCompileTests(QuantCompileTests):
+class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
     torch_dtype = torch.bfloat16
     gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
@@ -662,15 +662,6 @@ class GGUFCompileTests(QuantCompileTests):
     def quantization_config(self):
         return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
 
-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
-
     def _init_pipeline(self, *args, **kwargs):
         transformer = FluxTransformer2DModel.from_single_file(
             self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import gc
-import unittest
+import inspect
 
 import torch
@@ -23,7 +23,7 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu
 @require_torch_gpu
 @slow
-class QuantCompileTests(unittest.TestCase):
+class QuantCompileTests:
     @property
     def quantization_config(self):
         raise NotImplementedError(
@@ -50,30 +50,26 @@ class QuantCompileTests(unittest.TestCase):
         )
         return pipe
 
-    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
-        # import to ensure fullgraph True
+    def _test_torch_compile(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
+        # `fullgraph=True` ensures no graph breaks
         pipe.transformer.compile(fullgraph=True)
 
-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
-    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+    def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
         pipe.transformer.compile()
 
-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
-    def _test_torch_compile_with_group_offload_leaf(
-        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
-    ):
-        torch._dynamo.config.cache_size_limit = 10000
+    def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
+        torch._dynamo.config.cache_size_limit = 1000
 
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         group_offload_kwargs = {
             "onload_device": torch.device("cuda"),
             "offload_device": torch.device("cpu"),
@@ -87,6 +83,17 @@ class QuantCompileTests(unittest.TestCase):
             if torch.device(component.device).type == "cpu":
                 component.to("cuda")
 
-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+    def test_torch_compile(self):
+        self._test_torch_compile()
+
+    def test_torch_compile_with_cpu_offload(self):
+        self._test_torch_compile_with_cpu_offload()
+
+    def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
+        for cls in inspect.getmro(self.__class__):
+            if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
+                return
+        self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
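The `inspect.getmro` loop added to the base `test_torch_compile_with_group_offload_leaf` makes the inherited test step aside whenever any class in the subclass's MRO other than `QuantCompileTests` defines its own version of that test (the TorchAo class below does, via `@parameterized.expand`), so the plain variant is not run as a duplicate. A small standalone illustration of that lookup pattern (hypothetical `Base`/`Child`/`probe` names, not from the repository):

import inspect


class Base:
    def check(self):
        # walk the method resolution order; if any class other than Base
        # (re)defines "probe", report it and bail out early
        for cls in inspect.getmro(self.__class__):
            if "probe" in cls.__dict__ and cls is not Base:
                return f"probe overridden in {cls.__name__}"
        return "no override found"

    def probe(self):
        pass


class Child(Base):
    def probe(self):  # redefined here, so Base.check() bails out
        pass


print(Base().check())   # no override found
print(Child().check())  # probe overridden in Child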
@@ -630,7 +630,7 @@ class TorchAoSerializationTest(unittest.TestCase):
 @require_torchao_version_greater_or_equal("0.7.0")
-class TorchAoCompileTest(QuantCompileTests):
+class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
@@ -639,17 +639,15 @@ class TorchAoCompileTest(QuantCompileTests):
             },
         )
 
-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
     @unittest.skip(
         "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
         "when compiling."
     )
     def test_torch_compile_with_cpu_offload(self):
         # RuntimeError: _apply(): Couldn't swap Linear.weight
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile_with_cpu_offload()
 
+    @parameterized.expand([False, True])
     @unittest.skip(
         """
         For `use_stream=False`:
@@ -659,8 +657,7 @@ class TorchAoCompileTest(QuantCompileTests):
         Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
         """
     )
-    @parameterized.expand([False, True])
-    def test_torch_compile_with_group_offload_leaf(self):
+    def test_torch_compile_with_group_offload_leaf(self, use_stream):
         # For use_stream=False:
         # If we run group offloading without compilation, we will see:
         # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
@@ -673,7 +670,7 @@ class TorchAoCompileTest(QuantCompileTests):
         # For use_stream=True:
         # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
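In the TorchAo tests, `@parameterized.expand([False, True])` now sits directly on the group-offload test and threads `use_stream` through to the shared helper, so both the streamed and non-streamed paths are expressed even while the test remains skipped. A toy version of that decorator usage (hypothetical test class; assumes the same `parameterized` package these tests already use):

import unittest

from parameterized import parameterized


class StreamFlagTests(unittest.TestCase):
    @parameterized.expand([False, True])
    def test_offload(self, use_stream):
        # parameterized generates one test method per value in the list,
        # passing that value in as the `use_stream` argument
        self.assertIn(use_stream, (False, True))


if __name__ == "__main__":
    unittest.main()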