"vscode:/vscode.git/clone" did not exist on "31058cdaef63ca660a1a045281d156239fba8192"
Unverified Commit e7e6d852 authored by Sayak Paul, committed by GitHub

[Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)

* memory usage tests

* fixes

* gguf
parent 8eefed65
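This commit threads an extra `**kwargs` through the quantized-parameter creation method of each quantizer (bitsandbytes 4-/8-bit, GGUF, torchao), moves the shared `LoRALayer` test helper into a common test utilities module, and extends the quantization tests so they also measure peak inference memory: each suite runs one forward pass on an unquantized copy and a quantized copy of the same model and asserts a minimum memory-savings ratio. Below is a minimal sketch of that measurement pattern; the toy `peak_inference_memory` helper and the `torch.nn.Linear` stand-ins are illustrative only, while the real tests use the `get_memory_consumption_stat` helper and the per-backend `expected_memory_saving_ratio` thresholds added in this diff.

import torch


@torch.no_grad()
def peak_inference_memory(model, inputs):
    # Reset the CUDA peak-memory counter, run one forward pass, and read back the peak.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    model(**inputs)
    return torch.cuda.max_memory_allocated()


if torch.cuda.is_available():
    # Toy stand-ins for the fp16 baseline and its quantized counterpart used in the real tests.
    baseline = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
    quantized = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
    inputs = {"input": torch.randn(8, 1024, device="cuda", dtype=torch.float16)}

    ratio = peak_inference_memory(baseline, inputs) / peak_inference_memory(quantized, inputs)
    # The real tests assert ratio >= expected_memory_saving_ratio (0.8 for bnb 4-bit, 0.7 for bnb 8-bit).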
@@ -135,6 +135,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         import bitsandbytes as bnb
@@ -445,6 +446,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         import bitsandbytes as bnb
...
@@ -108,6 +108,7 @@ class GGUFQuantizer(DiffusersQuantizer):
         target_device: "torch.device",
         state_dict: Optional[Dict[str, Any]] = None,
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         module, tensor_name = get_module_from_name(model, param_name)
         if tensor_name not in module._parameters and tensor_name not in module._buffers:
...
@@ -215,6 +215,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: List[str],
+        **kwargs,
     ):
         r"""
         Each nn.Linear layer that needs to be quantized is processed here. First, we set the value of the weight tensor,
...
@@ -54,29 +54,8 @@ if is_transformers_available():
 if is_torch_available():
     import torch
-    import torch.nn as nn
-
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+    from ..utils import LoRALayer, get_memory_consumption_stat

 if is_bitsandbytes_available():
@@ -96,6 +75,8 @@ class Base4bitTests(unittest.TestCase):
     # This was obtained on audace so the number might slightly change
     expected_rel_difference = 3.69
+    expected_memory_saving_ratio = 0.8

     prompt = "a beautiful sunset amidst the mountains."
     num_inference_steps = 10
     seed = 0
@@ -140,7 +121,9 @@ class BnB4BitBasicTests(Base4bitTests):
         )

     def tearDown(self):
-        del self.model_fp16
-        del self.model_4bit
+        if hasattr(self, "model_fp16"):
+            del self.model_fp16
+        if hasattr(self, "model_4bit"):
+            del self.model_4bit

         gc.collect()
@@ -180,6 +163,32 @@ class BnB4BitBasicTests(Base4bitTests):
         linear = get_some_linear_layer(self.model_4bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)

+    def test_model_memory_usage(self):
+        # Delete to not let anything interfere.
+        del self.model_4bit, self.model_fp16
+
+        # Re-instantiate.
+        inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
+        model_fp16 = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", torch_dtype=torch.float16
+        ).to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+        del model_fp16
+
+        nf4_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        model_4bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
+        )
+        quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
     def test_original_dtype(self):
         r"""
         A simple test to check if the model successfully stores the original dtype
...
@@ -60,29 +60,8 @@ if is_transformers_available():
 if is_torch_available():
     import torch
-    import torch.nn as nn
-
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+    from ..utils import LoRALayer, get_memory_consumption_stat

 if is_bitsandbytes_available():
@@ -102,6 +81,8 @@ class Base8bitTests(unittest.TestCase):
     # This was obtained on audace so the number might slightly change
     expected_rel_difference = 1.94
+    expected_memory_saving_ratio = 0.7

     prompt = "a beautiful sunset amidst the mountains."
     num_inference_steps = 10
     seed = 0
@@ -142,7 +123,9 @@ class BnB8bitBasicTests(Base8bitTests):
         )

     def tearDown(self):
-        del self.model_fp16
-        del self.model_8bit
+        if hasattr(self, "model_fp16"):
+            del self.model_fp16
+        if hasattr(self, "model_8bit"):
+            del self.model_8bit

         gc.collect()
@@ -182,6 +165,28 @@ class BnB8bitBasicTests(Base8bitTests):
         linear = get_some_linear_layer(self.model_8bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)

+    def test_model_memory_usage(self):
+        # Delete to not let anything interfere.
+        del self.model_8bit, self.model_fp16
+
+        # Re-instantiate.
+        inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
+        model_fp16 = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", torch_dtype=torch.float16
+        ).to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+        del model_fp16
+
+        config = BitsAndBytesConfig(load_in_8bit=True)
+        model_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16
+        )
+        quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
     def test_original_dtype(self):
         r"""
         A simple test to check if the model successfully stores the original dtype
@@ -248,7 +253,7 @@ class BnB8bitBasicTests(Base8bitTests):
         self.assertTrue(linear.weight.dtype == torch.int8)
         self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))
-        self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear))
+        self.assertTrue(isinstance(model_8bit.proj_out, torch.nn.Linear))
         self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)

     def test_config_from_pretrained(self):
...
@@ -19,29 +19,8 @@ if is_optimum_quanto_available():
 if is_torch_available():
     import torch
-    import torch.nn as nn
-
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+    from ..utils import LoRALayer, get_memory_consumption_stat

 @nightly
@@ -85,20 +64,20 @@ class QuantoBaseTesterMixin:
         assert isinstance(module, QLinear)

     def test_quanto_memory_usage(self):
-        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
-        unquantized_model_memory = unquantized_model.get_memory_footprint() / 1024**3
-
-        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
         inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
+        }

-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
-
-        model.to(torch_device)
-        with torch.no_grad():
-            model(**inputs)
-        max_memory = torch.cuda.max_memory_allocated() / 1024**3
-
-        assert (1.0 - (max_memory / unquantized_model_memory)) >= self.expected_memory_reduction
+        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
+        unquantized_model.to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)
+
+        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+        quantized_model.to(torch_device)
+        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)
+
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction

     def test_keep_modules_in_fp32(self):
         r"""
@@ -318,14 +297,14 @@ class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):

 class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
-    expected_memory_reduction = 0.3
+    expected_memory_reduction = 0.6

     def get_dummy_init_kwargs(self):
         return {"weights_dtype": "float8"}


 class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
-    expected_memory_reduction = 0.3
+    expected_memory_reduction = 0.6
     _test_torch_compile = True

     def get_dummy_init_kwargs(self):
...
@@ -50,27 +50,7 @@ if is_torch_available():
     import torch
     import torch.nn as nn
-
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+    from ..utils import LoRALayer, get_memory_consumption_stat

 if is_torchao_available():
@@ -503,6 +483,22 @@ class TorchAoTest(unittest.TestCase):
         # there is additional overhead of scales and zero points
         self.assertTrue(total_bf16 < total_int4wo)

+    def test_model_memory_usage(self):
+        model_id = "hf-internal-testing/tiny-flux-pipe"
+        expected_memory_saving_ratio = 2.0
+
+        inputs = self.get_dummy_tensor_inputs(device=torch_device)
+
+        transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
+        transformer_bf16.to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
+        del transformer_bf16
+
+        transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
+        transformer_int8wo.to(torch_device)
+        quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
+
     def test_wrong_config(self):
         with self.assertRaises(ValueError):
             self.get_dummy_components(TorchAoConfig("int42"))
...
from diffusers.utils import is_torch_available

if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only

        Taken from
        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module
            self.adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
            nn.init.normal_(self.adapter[0].weight, std=small_std)
            nn.init.zeros_(self.adapter[1].weight)
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            return self.module(input, *args, **kwargs) + self.adapter(input)

    @torch.no_grad()
    @torch.inference_mode()
    def get_memory_consumption_stat(model, inputs):
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        model(**inputs)
        max_memory_mem_allocated = torch.cuda.max_memory_allocated()
        return max_memory_mem_allocated
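The helper resets the CUDA peak-memory counter, runs a single forward pass under `torch.no_grad()`/`torch.inference_mode()`, and returns `torch.cuda.max_memory_allocated()`, so the reported number covers the weights resident on the device plus the activations of that forward pass, not just the checkpoint footprint. A small usage sketch, assuming a CUDA device and a toy `torch.nn.Linear` with dummy inputs in place of the diffusers models the tests construct:

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(512, 512, device="cuda", dtype=torch.float16)
    inputs = {"input": torch.randn(4, 512, device="cuda", dtype=torch.float16)}
    peak_bytes = get_memory_consumption_stat(model, inputs)
    # The quantization tests call this once for the fp16 model and once for the quantized model,
    # then assert (unquantized peak / quantized peak) >= expected_memory_saving_ratio.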