Unverified commit 56f74005 authored by Fanli Lin, committed by GitHub

[tests] enable bnb tests on xpu (#11001)

* enable bnb on xpu

* add 2 more cases

* add missing change

* add missing change

* add one more
parent a34d97ce
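
The changes below all follow one pattern: code that previously assumed CUDA now checks for an Intel XPU device first and falls back to CUDA otherwise, so the bitsandbytes (bnb) tests can run on XPU machines. A minimal sketch of that device-selection logic (the helper name is illustrative, not part of the diffusers API; it assumes a PyTorch build that exposes `torch.xpu`):

```python
import torch


def current_accelerator_device() -> str:
    # Prefer Intel XPU when present, otherwise fall back to CUDA,
    # mirroring the update_device_map() changes in the quantizers below.
    if torch.xpu.is_available():
        return f"xpu:{torch.xpu.current_device()}"
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"
    raise RuntimeError("No GPU found. A GPU is needed for quantization.")
```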
@@ -427,7 +427,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
             )
 
-        if device_type == "cuda":
+        if device_type in ["cuda", "xpu"]:
             if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
                 raise ValueError(
                     "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
@@ -440,7 +440,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         # Display a warning in this case (the operation succeeds but the benefits are lost)
         pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
 
-        if pipeline_is_offloaded and device_type == "cuda":
+        if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
             logger.warning(
                 f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
             )
@@ -61,7 +61,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
 
     def validate_environment(self, *args, **kwargs):
-        if not torch.cuda.is_available():
+        if not (torch.cuda.is_available() or torch.xpu.is_available()):
             raise RuntimeError("No GPU found. A GPU is needed for quantization.")
         if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
             raise ImportError(
@@ -238,11 +238,15 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
 
     def update_device_map(self, device_map):
         if device_map is None:
-            device_map = {"": f"cuda:{torch.cuda.current_device()}"}
+            if torch.xpu.is_available():
+                current_device = f"xpu:{torch.xpu.current_device()}"
+            else:
+                current_device = f"cuda:{torch.cuda.current_device()}"
+            device_map = {"": current_device}
             logger.info(
                 "The device_map was not initialized. "
                 "Setting device_map to {"
-                ": f`cuda:{torch.cuda.current_device()}`}. "
+                ": {current_device}}. "
                 "If you want to use the model for inference, please set device_map ='auto' "
             )
         return device_map
@@ -312,7 +316,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
             logger.info(
                 "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
             )
-            model.to(torch.cuda.current_device())
+            if torch.xpu.is_available():
+                model.to(torch.xpu.current_device())
+            else:
+                model.to(torch.cuda.current_device())
 
         model = dequantize_and_replace(
             model, self.modules_to_not_convert, quantization_config=self.quantization_config
@@ -343,7 +350,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
 
     def validate_environment(self, *args, **kwargs):
-        if not torch.cuda.is_available():
+        if not (torch.cuda.is_available() or torch.xpu.is_available()):
             raise RuntimeError("No GPU found. A GPU is needed for quantization.")
         if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
             raise ImportError(
@@ -402,11 +409,15 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
     # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
     def update_device_map(self, device_map):
         if device_map is None:
-            device_map = {"": f"cuda:{torch.cuda.current_device()}"}
+            if torch.xpu.is_available():
+                current_device = f"xpu:{torch.xpu.current_device()}"
+            else:
+                current_device = f"cuda:{torch.cuda.current_device()}"
+            device_map = {"": current_device}
             logger.info(
                 "The device_map was not initialized. "
                 "Setting device_map to {"
-                ": f`cuda:{torch.cuda.current_device()}`}. "
+                ": {current_device}}. "
                 "If you want to use the model for inference, please set device_map ='auto' "
             )
         return device_map
@@ -574,10 +574,10 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
     return arry
 
 
-def load_pt(url: str):
+def load_pt(url: str, map_location: str):
     response = requests.get(url)
     response.raise_for_status()
-    arry = torch.load(BytesIO(response.content))
+    arry = torch.load(BytesIO(response.content), map_location=map_location)
     return arry
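
The new `map_location` parameter matters because `torch.load` otherwise tries to restore tensors onto the device they were saved from (typically CUDA), which fails on CUDA-less XPU hosts. A minimal usage sketch of the updated helper, assuming `load_pt` and `torch_device` are importable from `diffusers.utils.testing_utils` as in the tests below:

```python
from diffusers.utils.testing_utils import load_pt, torch_device

# Fetch a saved tensor artifact and place it directly on the active test
# accelerator ("xpu", "cuda", or "cpu") instead of its original save device.
prompt_embeds = load_pt(
    "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
    map_location=torch_device,
)
print(prompt_embeds.shape, prompt_embeds.device)
```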
@@ -377,9 +377,10 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
         pipeline.set_ip_adapter_scale(0.7)
 
         inputs = self.get_dummy_inputs()
-        id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[
-            0
-        ]
+        id_embeds = load_pt(
+            "https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt",
+            map_location=torch_device,
+        )[0]
         id_embeds = id_embeds.reshape((2, 1, 1, 512))
         inputs["ip_adapter_image_embeds"] = [id_embeds]
         inputs["ip_adapter_image"] = None
@@ -26,6 +26,7 @@ from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DMo
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
     is_bitsandbytes_available,
     is_torch_available,
     is_transformers_available,
@@ -35,7 +36,7 @@ from diffusers.utils.testing_utils import (
     require_bitsandbytes_version_greater,
     require_peft_backend,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     slow,
     torch_device,
@@ -66,7 +67,7 @@ if is_bitsandbytes_available():
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class Base4bitTests(unittest.TestCase):
     # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
@@ -84,13 +85,16 @@ class Base4bitTests(unittest.TestCase):
     def get_dummy_inputs(self):
         prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
+            torch_device,
         )
         pooled_prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
+            torch_device,
         )
         latent_model_input = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
+            torch_device,
         )
 
         input_dict_for_transformer = {
@@ -106,7 +110,7 @@ class Base4bitTests(unittest.TestCase):
 class BnB4BitBasicTests(Base4bitTests):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         # Models
         self.model_fp16 = SD3Transformer2DModel.from_pretrained(
@@ -128,7 +132,7 @@ class BnB4BitBasicTests(Base4bitTests):
         del self.model_4bit
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_quantization_num_parameters(self):
         r"""
@@ -224,7 +228,7 @@ class BnB4BitBasicTests(Base4bitTests):
                 self.assertTrue(module.weight.dtype == torch.uint8)
 
         # test if inference works.
-        with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch.float16):
             input_dict_for_transformer = self.get_dummy_inputs()
             model_inputs = {
                 k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
@@ -266,9 +270,9 @@ class BnB4BitBasicTests(Base4bitTests):
         self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
 
         # Move back to CUDA device
-        for device in [0, "cuda", "cuda:0", "call()"]:
+        for device in [0, f"{torch_device}", f"{torch_device}:0", "call()"]:
             if device == "call()":
-                self.model_4bit.cuda(0)
+                self.model_4bit.to(f"{torch_device}:0")
             else:
                 self.model_4bit.to(device)
             self.assertEqual(self.model_4bit.device, torch.device(0))
@@ -286,7 +290,7 @@ class BnB4BitBasicTests(Base4bitTests):
         with self.assertRaises(ValueError):
             # Tries with a `device` and `dtype`
-            self.model_4bit.to(device="cuda:0", dtype=torch.float16)
+            self.model_4bit.to(device=f"{torch_device}:0", dtype=torch.float16)
 
         with self.assertRaises(ValueError):
             # Tries with a cast
@@ -297,7 +301,7 @@ class BnB4BitBasicTests(Base4bitTests):
             self.model_4bit.half()
 
         # This should work
-        self.model_4bit.to("cuda")
+        self.model_4bit.to(torch_device)
 
         # Test if we did not break anything
         self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
@@ -321,7 +325,7 @@ class BnB4BitBasicTests(Base4bitTests):
         _ = self.model_fp16.float()
 
         # Check that this does not throw an error
-        _ = self.model_fp16.cuda()
+        _ = self.model_fp16.to(torch_device)
 
     def test_bnb_4bit_wrong_config(self):
         r"""
@@ -398,7 +402,7 @@ class BnB4BitTrainingTests(Base4bitTests):
         model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
 
         # Step 4: Check if the gradient is not None
-        with torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.amp.autocast(torch_device, dtype=torch.float16):
             out = self.model_4bit(**model_inputs)[0]
             out.norm().backward()
 
@@ -412,7 +416,7 @@ class BnB4BitTrainingTests(Base4bitTests):
 class SlowBnb4BitTests(Base4bitTests):
     def setUp(self) -> None:
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -431,7 +435,7 @@ class SlowBnb4BitTests(Base4bitTests):
         del self.pipeline_4bit
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_quality(self):
         output = self.pipeline_4bit(
@@ -501,7 +505,7 @@ class SlowBnb4BitTests(Base4bitTests):
         reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
         strict=True,
     )
-    def test_pipeline_cuda_placement_works_with_nf4(self):
+    def test_pipeline_device_placement_works_with_nf4(self):
         transformer_nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
@@ -532,7 +536,7 @@ class SlowBnb4BitTests(Base4bitTests):
             transformer=transformer_4bit,
             text_encoder_3=text_encoder_3_4bit,
             torch_dtype=torch.float16,
-        ).to("cuda")
+        ).to(torch_device)
 
         # Check if inference works.
         _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
@@ -696,7 +700,7 @@ class SlowBnb4BitFluxTests(Base4bitTests):
 class BaseBnb4BitSerializationTests(Base4bitTests):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
         r"""
@@ -31,6 +31,7 @@ from diffusers import (
 from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
     is_bitsandbytes_available,
     is_torch_available,
     is_transformers_available,
@@ -40,7 +41,7 @@ from diffusers.utils.testing_utils import (
     require_bitsandbytes_version_greater,
     require_peft_version_greater,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     slow,
     torch_device,
@@ -71,7 +72,7 @@ if is_bitsandbytes_available():
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class Base8bitTests(unittest.TestCase):
     # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
@@ -111,7 +112,7 @@ class Base8bitTests(unittest.TestCase):
 class BnB8bitBasicTests(Base8bitTests):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         # Models
         self.model_fp16 = SD3Transformer2DModel.from_pretrained(
@@ -129,7 +130,7 @@ class BnB8bitBasicTests(Base8bitTests):
         del self.model_8bit
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_quantization_num_parameters(self):
         r"""
@@ -279,7 +280,7 @@ class BnB8bitBasicTests(Base8bitTests):
         with self.assertRaises(ValueError):
             # Tries with a `device`
-            self.model_8bit.to(torch.device("cuda:0"))
+            self.model_8bit.to(torch.device(f"{torch_device}:0"))
 
         with self.assertRaises(ValueError):
             # Tries with a `device`
@@ -317,7 +318,7 @@ class BnB8bitBasicTests(Base8bitTests):
 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         self.model_8bit = SanaTransformer2DModel.from_pretrained(
@@ -331,7 +332,7 @@ class Bnb8bitDeviceTests(Base8bitTests):
         del self.model_8bit
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_buffers_device_assignment(self):
         for buffer_name, buffer in self.model_8bit.named_buffers():
@@ -345,7 +346,7 @@ class Bnb8bitDeviceTests(Base8bitTests):
 class BnB8bitTrainingTests(Base8bitTests):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         self.model_8bit = SD3Transformer2DModel.from_pretrained(
@@ -389,7 +390,7 @@ class BnB8bitTrainingTests(Base8bitTests):
 class SlowBnb8bitTests(Base8bitTests):
     def setUp(self) -> None:
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         model_8bit = SD3Transformer2DModel.from_pretrained(
@@ -404,7 +405,7 @@ class SlowBnb8bitTests(Base8bitTests):
         del self.pipeline_8bit
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_quality(self):
         output = self.pipeline_8bit(
@@ -616,7 +617,7 @@ class SlowBnb8bitTests(Base8bitTests):
 class SlowBnb8bitFluxTests(Base8bitTests):
     def setUp(self) -> None:
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
         t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
@@ -633,7 +634,7 @@ class SlowBnb8bitFluxTests(Base8bitTests):
         del self.pipeline_8bit
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_quality(self):
         # keep the resolution and max tokens to a lower number for faster execution.
@@ -680,7 +681,7 @@ class SlowBnb8bitFluxTests(Base8bitTests):
 class BaseBnb8bitSerializationTests(Base8bitTests):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         quantization_config = BitsAndBytesConfig(
             load_in_8bit=True,
@@ -693,7 +694,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
         del self.model_0
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_serialization(self):
         r"""