Unverified Commit e8246604 authored by Sayak Paul, committed by GitHub

fix: caching allocator behaviour for quantization. (#12172)



* fix: caching allocator behaviour for quantization.

* up

* Update src/diffusers/models/model_loading_utils.py
Co-authored-by: Aryan <aryan@huggingface.co>

---------
Co-authored-by: Aryan <aryan@huggingface.co>
parent 03be15e8
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -726,23 +726,29 @@ def _caching_allocator_warmup(
     very large margin.
     """
     factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
-    # Remove disk and cpu devices, and cast to proper torch.device
+    # Keep only accelerator devices
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    total_byte_count = defaultdict(lambda: 0)
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
     for param_name, device in accelerator_device_map.items():
         try:
-            param = model.get_parameter(param_name)
-        except AttributeError:
-            param = model.get_buffer(param_name)
-        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = param.numel() * param.element_size()
+            p = model.get_parameter(param_name)
+        except AttributeError:
+            try:
+                p = model.get_buffer(param_name)
+            except AttributeError:
+                raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
         # TODO: account for TP when needed.
-        total_byte_count[device] += param_byte_count
+        elements_per_device[device] += p.numel()
 
     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in total_byte_count.items():
-        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
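For readers skimming the diff, here is a minimal, self-contained sketch of the warmup strategy the new code follows: count the number of elements mapped to each accelerator device, then make one dummy allocation of elem_count // factor elements of the target dtype per device so the CUDA caching allocator reserves a block up front. The helper name warmup_device_allocations and the toy defaults are hypothetical illustrations, not part of diffusers; only torch.nn.Module.get_parameter/get_buffer and torch.empty are real APIs, used the same way as in the patch.

# Illustrative sketch only, not the diffusers implementation.
from collections import defaultdict

import torch


def warmup_device_allocations(model, expanded_device_map, dtype=torch.float16, factor=2):
    # Keep only accelerator devices; "cpu" and "disk" entries need no warmup.
    accelerator_device_map = {
        name: torch.device(device)
        for name, device in expanded_device_map.items()
        if str(device) not in ["cpu", "disk"]
    }
    if not accelerator_device_map:
        return

    # Count elements (not bytes) per device, mirroring the patched code.
    elements_per_device = defaultdict(int)
    for name, device in accelerator_device_map.items():
        try:
            p = model.get_parameter(name)
        except AttributeError:
            p = model.get_buffer(name)
        elements_per_device[device] += p.numel()

    # A single dummy tensor per device primes the caching allocator so that
    # later weight copies can reuse the cached block instead of fresh cudaMallocs.
    for device, elem_count in elements_per_device.items():
        _ = torch.empty(max(1, elem_count // factor), dtype=dtype, device=device, requires_grad=False)

With factor=2 this reserves roughly half of the total element count per device; when a quantizer is present, the factor instead comes from hf_quantizer.get_cuda_warm_up_factor(), as shown in the unchanged first line of the hunk.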