Unverified Commit 58bf2682 authored by Sayak Paul, committed by GitHub

support `hf_quantizer` in cache warmup. (#12043)

* support hf_quantizer in cache warmup.

* reviewer feedback

* up

* up
parent 1b48db4c
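
To make the effect of this change concrete, here is a hedged usage sketch: with this commit, loading a quantized checkpoint with a `device_map` goes through the caching-allocator warmup instead of skipping it. The model ID, quant type, and `device_map` value below are illustrative assumptions, not taken from the diff, and assume a diffusers build where `TorchAoConfig` and per-model `device_map` loading are available.

    import torch
    from diffusers import FluxTransformer2DModel, TorchAoConfig

    # Illustrative only: any ModelMixin subclass with a TorchAO config follows the same path.
    quantization_config = TorchAoConfig("int8_weight_only")
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        device_map="cuda",  # a non-None device_map now triggers _caching_allocator_warmup with hf_quantizer
    )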
@@ -17,7 +17,6 @@
 import functools
 import importlib
 import inspect
-import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict
@@ -717,27 +716,33 @@ def _expand_device_map(device_map, param_names):
 # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
-def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+def _caching_allocator_warmup(
+    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
     """
     This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
     device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
     which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
     very large margin.
     """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
     # Remove disk and cpu devices, and cast to proper torch.device
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    parameter_count = defaultdict(lambda: 0)
+    total_byte_count = defaultdict(lambda: 0)
     for param_name, device in accelerator_device_map.items():
         try:
             param = model.get_parameter(param_name)
         except AttributeError:
             param = model.get_buffer(param_name)
-        parameter_count[device] += math.prod(param.shape)
+        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        # TODO: account for TP when needed.
+        total_byte_count[device] += param_byte_count

     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, param_count in parameter_count.items():
-        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+    for device, byte_count in total_byte_count.items():
+        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
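
A quick arithmetic check of the pre-allocation above (a standalone sketch, not part of the diff; sizes are illustrative): `torch.empty(byte_count // factor, dtype=dtype, ...)` allocates `byte_count // factor` elements of `dtype`, so with a 2-byte `dtype` and the default factor of 2 the full byte footprint is reserved, while a quantizer-reported factor of 4 reserves half of it.

    # Standalone sketch of the sizing arithmetic (no GPU needed).
    numel = 1_000_000                      # parameters recorded for one device
    byte_count = numel * 2                 # numel * element_size() for bf16/fp16

    full_elems = byte_count // 2           # no quantizer: 1_000_000 bf16 elements, i.e. the full 2 MB footprint
    int8_elems = byte_count // 4           # factor 4: half of that, roughly matching 8-bit weights
    int4_elems = byte_count // 8           # factor 8: a quarter, roughly matching 4-bit weights

    print(full_elems, int8_elems, int4_elems)  # 1000000 500000 250000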
@@ -1532,10 +1532,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # tensors using their expected shape and not performing any initialization of the memory (empty data).
         # When the actual device allocations happen, the allocator already has a pool of unused device memory
         # that it can re-use for faster loading of the model.
-        # TODO: add support for warmup with hf_quantizer
-        if device_map is not None and hf_quantizer is None:
+        if device_map is not None:
             expanded_device_map = _expand_device_map(device_map, expected_keys)
-            _caching_allocator_warmup(model, expanded_device_map, dtype)
+            _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)
         offload_index = {} if device_map is not None and "disk" in device_map.values() else None
         state_dict_folder, state_dict_index = None, None
...
@@ -209,6 +209,17 @@ class DiffusersQuantizer(ABC):
         return model

+    def get_cuda_warm_up_factor(self):
+        """
+        The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up cuda.
+        A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
+        we allocate half the memory of the weights residing in the empty model, etc...
+        """
+        # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+        # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
+        # weight loading)
+        return 4
+
     def _dequantize(self, model):
         raise NotImplementedError(
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
...
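
As a hedged illustration of how a quantization backend could override this default (the class and factor below are hypothetical, not something added by this commit; the import path assumes the base class lives in diffusers.quantizers.base), a 4-bit backend whose weights occupy roughly a quarter of the fp16 footprint might report:

    from diffusers.quantizers.base import DiffusersQuantizer

    class HypotheticalFourBitQuantizer(DiffusersQuantizer):
        # Partial sketch: the other abstract methods of DiffusersQuantizer are omitted here.
        def get_cuda_warm_up_factor(self):
            # fp16 byte count // 8 approximates the size of 4-bit weights.
            return 8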
@@ -19,6 +19,7 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
 import importlib
 import types
+from fnmatch import fnmatch
 from typing import TYPE_CHECKING, Any, Dict, List, Union

 from packaging import version
@@ -278,6 +279,31 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
             module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
             quantize_(module, self.quantization_config.get_apply_tensor_subclass())

+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+        - A factor of 2 means we pre-allocate the full memory footprint of the model.
+        - A factor of 4 means we pre-allocate half of that, and so on.
+        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+        the correct size for quantized weights (like int4 or int8). That's because TorchAO internally represents
+        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+        torch_dtype, not the actual bit-width of the quantized data.
+        To correct for this:
+        - Use a division factor of 8 for int4 weights
+        - Use a division factor of 4 for int8 weights
+        """
+        # Original mapping for non-AOBaseConfig types.
+        # For the uint types, this is a best guess. Once these types become more used,
+        # we can look into their nuances.
+        map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
+        quant_type = self.quantization_config.quant_type
+        for pattern, target_dtype in map_to_target_dtype.items():
+            if fnmatch(quant_type, pattern):
+                return target_dtype
+        raise ValueError(f"Unsupported quant_type: {quant_type!r}")
+
     def _process_model_before_weight_loading(
         self,
         model: "ModelMixin",
...
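
To make the pattern matching in `get_cuda_warm_up_factor` concrete, here is a small standalone restatement of the mapping (the quant-type strings are illustrative TorchAO names, not an exhaustive list):

    from fnmatch import fnmatch

    map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}

    def warmup_factor_for(quant_type: str) -> int:
        # Mirrors the lookup above: the first matching glob pattern wins.
        for pattern, factor in map_to_target_dtype.items():
            if fnmatch(quant_type, pattern):
                return factor
        raise ValueError(f"Unsupported quant_type: {quant_type!r}")

    print(warmup_factor_for("int4_weight_only"))    # 8 -> reserve a quarter of the fp16 byte count
    print(warmup_factor_for("int8_weight_only"))    # 4 -> reserve half
    print(warmup_factor_for("float8_weight_only"))  # 4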