Unverified Commit f5929e03 authored by Marc Sun, committed by GitHub

[FEAT] Model loading refactor (#10604)



* first draft model loading refactor

* revert name change

* fix bnb

* revert name

* fix dduf

* fix hunyuan

* style

* Update src/diffusers/models/model_loading_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* suggestions from reviews

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove safetensors check

* fix default value

* more fix from suggestions

* revert logic for single file

* style

* typing + fix couple of issues

* improve speed

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Aryan <aryan@huggingface.co>

* fp8 dtype

* add tests

* rename resolved_archive_file to resolved_model_file

* format

* map_location default cpu

* add utility function

* switch to smaller model + test inference

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* rm comment

* add log

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* add decorator

* cosine sim instead

* fix use_keep_in_fp32_modules

* comm

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
parent 6fe05b9b
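For orientation, here is a minimal sketch of the loading path the refactor enables and the updated tests exercise: passing `device_map` alongside a quantization config to `from_pretrained`. It assumes a CUDA device with `bitsandbytes` and `accelerate` installed; the tiny checkpoint name is the one used in the new tests.

```python
# Minimal sketch of the loading path exercised by the new bnb device_map tests
# (assumes a CUDA device, bitsandbytes, and accelerate are available).
import torch

from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)
model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    device_map="auto",  # resolved by accelerate into a per-module placement
    torch_dtype=torch.bfloat16,
)
print(model.hf_device_map if hasattr(model, "hf_device_map") else "single-device load")
```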
@@ -52,7 +52,7 @@ logger = logging.get_logger(__name__)
 if is_accelerate_available():
-    from accelerate import init_empty_weights
+    from accelerate import dispatch_model, init_empty_weights
     from ..models.modeling_utils import load_model_dict_into_meta
@@ -366,19 +366,23 @@ class FromOriginalModelMixin:
             keep_in_fp32_modules=keep_in_fp32_modules,
         )
+        device_map = None
         if is_accelerate_available():
             param_device = torch.device(device) if device else torch.device("cpu")
-            named_buffers = model.named_buffers()
-            unexpected_keys = load_model_dict_into_meta(
+            empty_state_dict = model.state_dict()
+            unexpected_keys = [
+                param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
+            ]
+            device_map = {"": param_device}
+            load_model_dict_into_meta(
                 model,
                 diffusers_format_checkpoint,
                 dtype=torch_dtype,
-                device=param_device,
+                device_map=device_map,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
-                named_buffers=named_buffers,
+                unexpected_keys=unexpected_keys,
             )
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -400,4 +404,8 @@ class FromOriginalModelMixin:
         model.eval()
+        if device_map is not None:
+            device_map_kwargs = {"device_map": device_map}
+            dispatch_model(model, **device_map_kwargs)
         return model
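The hunk above replaces the old `device=` argument with a flat `device_map` and a final `dispatch_model` call. A standalone sketch of that "load onto meta, then dispatch" pattern, using only `accelerate` APIs; the toy module and random state dict are illustrative, not the diffusers code path itself:

```python
# Standalone sketch of the "load on meta, then dispatch" pattern used above
# (toy module and state dict are made up; the accelerate calls are real).
import torch
from accelerate import dispatch_model, init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

state_dict = {k: torch.randn_like(v, device="cpu") for k, v in model.state_dict().items()}
param_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Materialize each meta parameter directly on its target device ...
for name, tensor in state_dict.items():
    set_module_tensor_to_device(model, name, param_device, value=tensor)

# ... then let accelerate apply the placement implied by the flat device_map.
dispatch_model(model, device_map={"": param_device})
```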
@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
     if is_accelerate_available():
-        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
-        _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
-    if model._keys_to_ignore_on_load_unexpected is not None:
-        for pat in model._keys_to_ignore_on_load_unexpected:
-            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-    if len(unexpected_keys) > 0:
-        logger.warning(
-            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-        )
+        model.load_state_dict(diffusers_format_checkpoint, strict=False)
     if torch_dtype is not None:
         model.to(torch_dtype)
@@ -2061,16 +2052,7 @@ def create_diffusers_t5_model_from_checkpoint(
     diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
     if is_accelerate_available():
-        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        if model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-            )
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
         model.load_state_dict(diffusers_format_checkpoint)
@@ -20,13 +20,15 @@ import os
 from array import array
 from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, Iterator, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
+from zipfile import is_zipfile
 import safetensors
 import torch
 from huggingface_hub import DDUFEntry
 from huggingface_hub.utils import EntryNotFoundError
+from ..quantizers import DiffusersQuantizer
 from ..utils import (
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
@@ -55,7 +57,7 @@ _CLASS_REMAPPING_DICT = {
 if is_accelerate_available():
     from accelerate import infer_auto_device_map
-    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
+    from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
 # Adapted from `transformers` (see modeling_utils.py)
@@ -132,17 +134,46 @@ def _fetch_remapped_cls_from_config(config, old_class):
     return old_class
+def _check_archive_and_maybe_raise_error(checkpoint_file, format_list):
+    """
+    Check format of the archive
+    """
+    with safetensors.safe_open(checkpoint_file, framework="pt") as f:
+        metadata = f.metadata()
+    if metadata is not None and metadata.get("format") not in format_list:
+        raise OSError(
+            f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+            "you save your model with the `save_pretrained` method."
+        )
+def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
+    """
+    Find the device of param_name from the device_map.
+    """
+    if device_map is None:
+        return "cpu"
+    else:
+        module_name = param_name
+        # find next higher level module that is defined in device_map:
+        # bert.lm_head.weight -> bert.lm_head -> bert -> ''
+        while len(module_name) > 0 and module_name not in device_map:
+            module_name = ".".join(module_name.split(".")[:-1])
+        if module_name == "" and "" not in device_map:
+            raise ValueError(f"{param_name} doesn't have any device set.")
+        return device_map[module_name]
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
-    variant: Optional[str] = None,
     dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
     disable_mmap: bool = False,
+    map_location: Union[str, torch.device] = "cpu",
 ):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
-    # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
-    # when refactoring the _merge_sharded_checkpoints() method later.
+    # TODO: maybe refactor a bit this part where we pass a dict here
     if isinstance(checkpoint_file, dict):
        return checkpoint_file
     try:
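The new `_determine_param_device` helper walks a parameter name up through its parent modules until it finds an entry in the `device_map`. A toy illustration of that prefix-matching lookup, re-implemented inline so the snippet runs on its own (the sample device_map is made up):

```python
# Toy illustration of the prefix-matching lookup performed by _determine_param_device.
from typing import Dict, Optional, Union


def resolve_device(param_name: str, device_map: Optional[Dict[str, Union[int, str]]]) -> Union[int, str]:
    if device_map is None:
        return "cpu"
    module_name = param_name
    # bert.lm_head.weight -> bert.lm_head -> bert -> ''
    while len(module_name) > 0 and module_name not in device_map:
        module_name = ".".join(module_name.split(".")[:-1])
    if module_name == "" and "" not in device_map:
        raise ValueError(f"{param_name} doesn't have any device set.")
    return device_map[module_name]


device_map = {"": "cpu", "transformer_blocks.0": 0, "proj_out": "disk"}
print(resolve_device("transformer_blocks.0.attn.to_q.weight", device_map))  # -> 0
print(resolve_device("proj_out.bias", device_map))                          # -> "disk"
print(resolve_device("time_embed.linear_1.weight", device_map))             # -> "cpu"
```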
@@ -152,19 +183,26 @@ def load_state_dict(
                 # tensors are loaded on cpu
                 with dduf_entries[checkpoint_file].as_mmap() as mm:
                     return safetensors.torch.load(mm)
+            _check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"])
             if disable_mmap:
                 return safetensors.torch.load(open(checkpoint_file, "rb").read())
             else:
-                return safetensors.torch.load_file(checkpoint_file, device="cpu")
+                return safetensors.torch.load_file(checkpoint_file, device=map_location)
         elif file_extension == GGUF_FILE_EXTENSION:
             return load_gguf_checkpoint(checkpoint_file)
         else:
+            extra_args = {}
             weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
-            return torch.load(
-                checkpoint_file,
-                map_location="cpu",
-                **weights_only_kwarg,
-            )
+            # mmap can only be used with files serialized with zipfile-based format.
+            if (
+                isinstance(checkpoint_file, str)
+                and map_location != "meta"
+                and is_torch_version(">=", "2.1.0")
+                and is_zipfile(checkpoint_file)
+                and not disable_mmap
+            ):
+                extra_args = {"mmap": True}
+            return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
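The hunk above adds an mmap fast path: zipfile-serialized `.bin` checkpoints can be memory-mapped on torch >= 2.1 instead of being read fully into RAM. A hedged, simplified sketch of that check (the checkpoint path is hypothetical and the version test is deliberately naive):

```python
# Sketch of the mmap fast path: memory-map zipfile-serialized checkpoints on torch>=2.1.
import os
from zipfile import is_zipfile

import torch

checkpoint_file = "diffusion_pytorch_model.bin"  # hypothetical local file
map_location = "cpu"

if os.path.isfile(checkpoint_file):
    extra_args = {}
    # simplified version check; the real code uses a proper version comparison helper
    if is_zipfile(checkpoint_file) and torch.__version__ >= "2.1":
        extra_args = {"mmap": True}
    state_dict = torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)
```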
@@ -188,23 +226,24 @@ def load_state_dict(
 def load_model_dict_into_meta(
     model,
     state_dict: OrderedDict,
-    device: Optional[Union[str, torch.device]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
     model_name_or_path: Optional[str] = None,
-    hf_quantizer=None,
-    keep_in_fp32_modules=None,
-    named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
+    hf_quantizer: Optional[DiffusersQuantizer] = None,
+    keep_in_fp32_modules: Optional[List] = None,
+    device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
+    unexpected_keys: Optional[List[str]] = None,
+    offload_folder: Optional[Union[str, os.PathLike]] = None,
+    offload_index: Optional[Dict] = None,
+    state_dict_index: Optional[Dict] = None,
+    state_dict_folder: Optional[Union[str, os.PathLike]] = None,
 ) -> List[str]:
-    if device is not None and not isinstance(device, (str, torch.device)):
-        raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
-    if hf_quantizer is None:
-        device = device or torch.device("cpu")
-        dtype = dtype or torch.float32
-    is_quantized = hf_quantizer is not None
-    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+    """
+    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
+    params on a `meta` device. It replaces the model params with the data from the `state_dict`
+    """
+    is_quantized = hf_quantizer is not None
     empty_state_dict = model.state_dict()
-    unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
     for param_name, param in state_dict.items():
         if param_name not in empty_state_dict:
@@ -214,21 +253,35 @@ def load_model_dict_into_meta(
         # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
         # in int/uint/bool and not cast them.
         # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
-        if torch.is_floating_point(param):
-            if (
-                keep_in_fp32_modules is not None
-                and any(
-                    module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
-                )
-                and dtype == torch.float16
+        if dtype is not None and torch.is_floating_point(param):
+            if keep_in_fp32_modules is not None and any(
+                module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
             ):
                 param = param.to(torch.float32)
-                if accepts_dtype:
-                    set_module_kwargs["dtype"] = torch.float32
+                set_module_kwargs["dtype"] = torch.float32
             else:
                 param = param.to(dtype)
-                if accepts_dtype:
-                    set_module_kwargs["dtype"] = dtype
+                set_module_kwargs["dtype"] = dtype
+        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
+        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
+        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
+        old_param = model
+        splits = param_name.split(".")
+        for split in splits:
+            old_param = getattr(old_param, split)
+        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+            old_param = None
+        if old_param is not None:
+            if dtype is None:
+                param = param.to(old_param.dtype)
+            if old_param.is_contiguous():
+                param = param.contiguous()
+        param_device = _determine_param_device(param_name, device_map)
         # bnb params are flattened.
         # gguf quants have a different shape based on the type of quantization applied
@@ -236,7 +289,9 @@ load_model_dict_into_meta(
         if (
             is_quantized
             and hf_quantizer.pre_quantized
-            and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+            and hf_quantizer.check_if_quantized_param(
+                model, param, param_name, state_dict, param_device=param_device
+            )
         ):
             hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
         else:
@@ -244,35 +299,23 @@ load_model_dict_into_meta(
                 raise ValueError(
                     f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                 )
-        if is_quantized and (
-            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        if param_device == "disk":
+            offload_index = offload_weight(param, param_name, offload_folder, offload_index)
+        elif param_device == "cpu" and state_dict_index is not None:
+            state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
+        elif is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
         ):
-            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+            hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
         else:
-            if accepts_dtype:
-                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
-            else:
-                set_module_tensor_to_device(model, param_name, device, value=param)
-    if named_buffers is None:
-        return unexpected_keys
-    for param_name, param in named_buffers:
-        if is_quantized and (
-            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
-        ):
-            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
-        else:
-            if accepts_dtype:
-                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
-            else:
-                set_module_tensor_to_device(model, param_name, device, value=param)
-    return unexpected_keys
+            set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
+    return offload_index, state_dict_index
-def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
+def _load_state_dict_into_model(
+    model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
+) -> List[str]:
     # Convert old format to new format if needed from a PyTorch state_dict
     # copy state_dict so _load_from_state_dict can modify it
     state_dict = state_dict.copy()
@@ -280,15 +323,19 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
-    def load(module: torch.nn.Module, prefix: str = ""):
-        args = (state_dict, prefix, {}, True, [], [], error_msgs)
+    def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
+        local_metadata = {}
+        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+        if assign_to_params_buffers and not is_torch_version(">=", "2.1"):
+            logger.info("You need to have torch>=2.1 in order to load the model with assign_to_params_buffers=True")
+        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
         module._load_from_state_dict(*args)
         for name, child in module._modules.items():
             if child is not None:
-                load(child, prefix + name + ".")
+                load(child, prefix + name + ".", assign_to_params_buffers)
-    load(model_to_load)
+    load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)
     return error_msgs
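The `assign_to_params_buffers` metadata above mirrors PyTorch's `assign=True` loading mode (torch >= 2.1), which reuses the incoming state-dict tensors instead of copying them into existing parameters. A minimal standalone sketch of that behaviour, assuming torch >= 2.1 and a toy module:

```python
# Minimal sketch of assign-based loading (torch>=2.1); the module is a toy example.
import torch

model = torch.nn.Linear(4, 4)
state_dict = {k: v.clone() for k, v in model.state_dict().items()}

model.load_state_dict(state_dict)               # copies into the existing parameters
model.load_state_dict(state_dict, assign=True)  # assigns the incoming tensors directly

# After assign=True the parameter shares storage with the state-dict tensor.
print(model.weight.data_ptr() == state_dict["weight"].data_ptr())
```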
@@ -343,46 +390,6 @@ def _fetch_index_file(
     return index_file
-# Adapted from
-# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
-def _merge_sharded_checkpoints(
-    sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
-):
-    weight_map = sharded_metadata.get("weight_map", None)
-    if weight_map is None:
-        raise KeyError("'weight_map' key not found in the shard index file.")
-    # Collect all unique safetensors files from weight_map
-    files_to_load = set(weight_map.values())
-    is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
-    merged_state_dict = {}
-    # Load tensors from each unique file
-    for file_name in files_to_load:
-        part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
-        if dduf_entries:
-            if part_file_path not in dduf_entries:
-                raise FileNotFoundError(f"Part file {file_name} not found.")
-        else:
-            if not os.path.exists(part_file_path):
-                raise FileNotFoundError(f"Part file {file_name} not found.")
-        if is_safetensors:
-            if dduf_entries:
-                with dduf_entries[part_file_path].as_mmap() as mm:
-                    tensors = safetensors.torch.load(mm)
-                    merged_state_dict.update(tensors)
-            else:
-                with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
-                    for tensor_key in f.keys():
-                        if tensor_key in weight_map:
-                            merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
-        else:
-            merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
-    return merged_state_dict
 def _fetch_index_file_legacy(
     is_local,
     pretrained_model_name_or_path,
@@ -280,9 +280,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             act_fn="silu_fp32",
         )
-        self.text_embedding_padding = nn.Parameter(
-            torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
-        )
+        self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim))
         self.pos_embed = PatchEmbed(
             height=sample_size,
@@ -693,7 +693,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
         offload_folder = kwargs.pop("offload_folder", None)
-        offload_state_dict = kwargs.pop("offload_state_dict", False)
+        offload_state_dict = kwargs.pop("offload_state_dict", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         dduf_file = kwargs.pop("dduf_file", None)
@@ -235,18 +235,16 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
             torch_dtype = torch.float16
         return torch_dtype
-    # (sayakpaul): I think it could be better to disable custom `device_map`s
-    # for the first phase of the integration in the interest of simplicity.
-    # Commenting this for discussions on the PR.
-    # def update_device_map(self, device_map):
-    #     if device_map is None:
-    #         device_map = {"": torch.cuda.current_device()}
-    #         logger.info(
-    #             "The device_map was not initialized. "
-    #             "Setting device_map to {'':torch.cuda.current_device()}. "
-    #             "If you want to use the model for inference, please set device_map ='auto' "
-    #         )
-    #     return device_map
+    def update_device_map(self, device_map):
+        if device_map is None:
+            device_map = {"": f"cuda:{torch.cuda.current_device()}"}
+            logger.info(
+                "The device_map was not initialized. "
+                "Setting device_map to {"
+                ": f`cuda:{torch.cuda.current_device()}`}. "
+                "If you want to use the model for inference, please set device_map ='auto' "
+            )
+        return device_map
     def _process_model_before_weight_loading(
         self,
@@ -289,9 +287,9 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
             model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
         )
         model.config.quantization_config = self.quantization_config
+        model.is_loaded_in_4bit = True
     def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
-        model.is_loaded_in_4bit = True
         model.is_4bit_serializable = self.is_serializable
         return model
@@ -400,16 +398,17 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
             torch_dtype = torch.float16
         return torch_dtype
-    # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
-    # def update_device_map(self, device_map):
-    #     if device_map is None:
-    #         device_map = {"": torch.cuda.current_device()}
-    #         logger.info(
-    #             "The device_map was not initialized. "
-    #             "Setting device_map to {'':torch.cuda.current_device()}. "
-    #             "If you want to use the model for inference, please set device_map ='auto' "
-    #         )
-    #     return device_map
+    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
+    def update_device_map(self, device_map):
+        if device_map is None:
+            device_map = {"": f"cuda:{torch.cuda.current_device()}"}
+            logger.info(
+                "The device_map was not initialized. "
+                "Setting device_map to {"
+                ": f`cuda:{torch.cuda.current_device()}`}. "
+                "If you want to use the model for inference, please set device_map ='auto' "
+            )
+        return device_map
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
         if target_dtype != torch.int8:
@@ -493,11 +492,10 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
     # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit
     def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
-        model.is_loaded_in_8bit = True
         model.is_8bit_serializable = self.is_serializable
         return model
-    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading
+    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading with 4bit->8bit
     def _process_model_before_weight_loading(
         self,
         model: "ModelMixin",
@@ -539,6 +537,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
             model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
         )
         model.config.quantization_config = self.quantization_config
+        model.is_loaded_in_8bit = True
     @property
     # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable
@@ -338,22 +338,6 @@ def _get_model_file(
         ) from e
-# Adapted from
-# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976
-# Differences are in parallelization of shard downloads and checking if shards are present.
-def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames):
-    shards_path = os.path.join(local_dir, subfolder)
-    shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
-    for shard_file in shard_filenames:
-        if not os.path.exists(shard_file):
-            raise ValueError(
-                f"{shards_path} does not appear to have a file named {shard_file} which is "
-                "required according to the checkpoint index."
-            )
 def _get_checkpoint_shard_files(
     pretrained_model_name_or_path,
     index_filename,
@@ -396,13 +380,22 @@ def _get_checkpoint_shard_files(
     shards_path = os.path.join(pretrained_model_name_or_path, subfolder)
     # First, let's deal with local folder.
-    if os.path.isdir(pretrained_model_name_or_path):
-        _check_if_shards_exist_locally(
-            pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
-        )
-        return shards_path, sharded_metadata
-    elif dduf_entries:
-        return shards_path, sharded_metadata
+    if os.path.isdir(pretrained_model_name_or_path) or dduf_entries:
+        shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
+        for shard_file in shard_filenames:
+            if dduf_entries:
+                if shard_file not in dduf_entries:
+                    raise FileNotFoundError(
+                        f"{shards_path} does not appear to have a file named {shard_file} which is "
+                        "required according to the checkpoint index."
+                    )
+            else:
+                if not os.path.exists(shard_file):
+                    raise FileNotFoundError(
+                        f"{shards_path} does not appear to have a file named {shard_file} which is "
+                        "required according to the checkpoint index."
+                    )
+        return shard_filenames, sharded_metadata
     # At this stage pretrained_model_name_or_path is a model identifier on the Hub
     allow_patterns = original_shard_filenames
@@ -444,7 +437,9 @@ def _get_checkpoint_shard_files(
             " again after checking your internet connection."
         ) from e
-    return cached_folder, sharded_metadata
+    cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
+    return cached_filenames, sharded_metadata
 def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
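With `_get_checkpoint_shard_files` now returning the per-shard file paths instead of the containing folder, a caller can stream shards one at a time rather than merging everything up front. A hedged sketch of that consumption pattern (the index filename and folder are illustrative):

```python
# Sketch of consuming per-shard filenames: shards are loaded one by one from the index.
import json
import os

from safetensors.torch import load_file

folder = "./my-model/transformer"  # hypothetical local snapshot
index_file = os.path.join(folder, "diffusion_pytorch_model.safetensors.index.json")

with open(index_file) as f:
    weight_map = json.load(f)["weight_map"]

shard_filenames = sorted({os.path.join(folder, name) for name in weight_map.values()})

state_dict = {}
for shard_file in shard_filenames:
    # each shard is a complete safetensors file; loading shard-by-shard keeps peak memory low
    state_dict.update(load_file(shard_file, device="cpu"))
```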
@@ -37,7 +37,7 @@ from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
 from requests.exceptions import HTTPError
-from diffusers.models import UNet2DConditionModel
+from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
@@ -200,12 +200,12 @@ class ModelUtilsTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
-    def test_accelerate_loading_error_message(self):
-        with self.assertRaises(ValueError) as error_context:
+    def test_missing_key_loading_warning_message(self):
+        with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs:
             UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
         # make sure that error message states what keys are missing
-        assert "conv_out.bias" in str(error_context.exception)
+        assert "conv_out.bias" in " ".join(logs.output)
     @parameterized.expand(
         [
@@ -334,6 +334,58 @@ class ModelUtilsTest(unittest.TestCase):
         assert model.config.in_channels == 9
+    @require_torch_gpu
+    def test_keep_modules_in_fp32(self):
+        r"""
+        A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
+        Also ensures if inference works.
+        """
+        fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
+        for torch_dtype in [torch.bfloat16, torch.float16]:
+            SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
+            model = SD3Transformer2DModel.from_pretrained(
+                "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch_dtype
+            ).to(torch_device)
+            for name, module in model.named_modules():
+                if isinstance(module, torch.nn.Linear):
+                    if name in model._keep_in_fp32_modules:
+                        self.assertTrue(module.weight.dtype == torch.float32)
+                    else:
+                        self.assertTrue(module.weight.dtype == torch_dtype)
+            def get_dummy_inputs():
+                batch_size = 2
+                num_channels = 4
+                height = width = embedding_dim = 32
+                pooled_embedding_dim = embedding_dim * 2
+                sequence_length = 154
+                hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+                encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+                pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
+                timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+                return {
+                    "hidden_states": hidden_states,
+                    "encoder_hidden_states": encoder_hidden_states,
+                    "pooled_projections": pooled_prompt_embeds,
+                    "timestep": timestep,
+                }
+            # test if inference works.
+            with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch_dtype):
+                input_dict_for_transformer = get_dummy_inputs()
+                model_inputs = {
+                    k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+                }
+                model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+                _ = model(**model_inputs)
+        SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
 class UNetTesterMixin:
     def test_forward_with_norm_groups(self):
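The new test above checks the `_keep_in_fp32_modules` contract: submodules whose names match the list stay in float32 while everything else is cast to the requested dtype. A toy, self-contained illustration of that behaviour (module names are made up):

```python
# Toy illustration of the _keep_in_fp32_modules behaviour the new test checks.
import torch

model = torch.nn.ModuleDict({"proj_out": torch.nn.Linear(8, 8), "ff": torch.nn.Linear(8, 8)})
keep_in_fp32_modules = ["proj_out"]
torch_dtype = torch.float16

for name, module in model.named_children():
    # matching modules are pinned to float32, the rest get the requested dtype
    module.to(torch.float32 if name in keep_in_fp32_modules else torch_dtype)

assert model["proj_out"].weight.dtype == torch.float32
assert model["ff"].weight.dtype == torch.float16
```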
@@ -136,7 +136,7 @@ class BnB4BitBasicTests(Base4bitTests):
             bnb_4bit_compute_dtype=torch.float16,
         )
         self.model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
     def tearDown(self):
@@ -202,7 +202,7 @@ class BnB4BitBasicTests(Base4bitTests):
             bnb_4bit_compute_dtype=torch.float16,
         )
         model = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
         for name, module in model.named_modules():
@@ -327,7 +327,7 @@ class BnB4BitBasicTests(Base4bitTests):
         with tempfile.TemporaryDirectory() as tmpdirname:
             nf4_config = BitsAndBytesConfig(load_in_4bit=True)
             model_4bit = SD3Transformer2DModel.from_pretrained(
-                self.model_name, subfolder="transformer", quantization_config=nf4_config
+                self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
             )
             model_4bit.save_pretrained(tmpdirname)
             del model_4bit
@@ -362,7 +362,7 @@ class BnB4BitTrainingTests(Base4bitTests):
             bnb_4bit_compute_dtype=torch.float16,
         )
         self.model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
     def test_training(self):
@@ -410,7 +410,7 @@ class SlowBnb4BitTests(Base4bitTests):
             bnb_4bit_compute_dtype=torch.float16,
         )
         model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
         self.pipeline_4bit = DiffusionPipeline.from_pretrained(
             self.model_name, transformer=model_4bit, torch_dtype=torch.float16
@@ -472,7 +472,7 @@ class SlowBnb4BitTests(Base4bitTests):
             bnb_4bit_compute_dtype=torch.float16,
         )
         model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
         logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
@@ -502,6 +502,7 @@ class SlowBnb4BitTests(Base4bitTests):
             subfolder="transformer",
             quantization_config=transformer_nf4_config,
             torch_dtype=torch.float16,
+            device_map=torch_device,
         )
         text_encoder_3_nf4_config = BnbConfig(
             load_in_4bit=True,
@@ -513,6 +514,7 @@ class SlowBnb4BitTests(Base4bitTests):
             subfolder="text_encoder_3",
             quantization_config=text_encoder_3_nf4_config,
             torch_dtype=torch.float16,
+            device_map=torch_device,
         )
         # CUDA device placement works.
         pipeline_4bit = DiffusionPipeline.from_pretrained(
@@ -527,6 +529,94 @@ class SlowBnb4BitTests(Base4bitTests):
         del pipeline_4bit
+    def test_device_map(self):
+        """
+        Test if the quantized model is working properly with "auto".
+        cpu/disk offloading as well doesn't work with bnb.
+        """
+        def get_dummy_tensor_inputs(device=None, seed: int = 0):
+            batch_size = 1
+            num_latent_channels = 4
+            num_image_channels = 3
+            height = width = 4
+            sequence_length = 48
+            embedding_dim = 32
+            torch.manual_seed(seed)
+            hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
+                device, dtype=torch.bfloat16
+            )
+            torch.manual_seed(seed)
+            encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
+                device, dtype=torch.bfloat16
+            )
+            torch.manual_seed(seed)
+            pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+            torch.manual_seed(seed)
+            text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+            torch.manual_seed(seed)
+            image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+            timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
+            return {
+                "hidden_states": hidden_states,
+                "encoder_hidden_states": encoder_hidden_states,
+                "pooled_projections": pooled_prompt_embeds,
+                "txt_ids": text_ids,
+                "img_ids": image_ids,
+                "timestep": timestep,
+            }
+        inputs = get_dummy_tensor_inputs(torch_device)
+        expected_slice = np.array(
+            [0.47070312, 0.00390625, -0.03662109, -0.19628906, -0.53125, 0.5234375, -0.17089844, -0.59375, 0.578125]
+        )
+        # non sharded
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
+        )
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-pipe",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+        self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
+        output = quantized_model(**inputs)[0]
+        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+        # sharded
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
+        )
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-sharded",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+        self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
+        output = quantized_model(**inputs)[0]
+        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitFluxTests(Base4bitTests):
@@ -610,7 +700,10 @@ class BaseBnb4BitSerializationTests(Base4bitTests):
             bnb_4bit_compute_dtype=torch.bfloat16,
         )
         model_0 = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=self.quantization_config
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=self.quantization_config,
+            device_map=torch_device,
         )
         self.assertTrue("_pre_quantization_dtype" in model_0.config)
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -138,7 +138,7 @@ class BnB8bitBasicTests(Base8bitTests):
         )
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         self.model_8bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
         )
     def tearDown(self):
@@ -200,7 +200,7 @@ class BnB8bitBasicTests(Base8bitTests):
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         model = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
         )
         for name, module in model.named_modules():
@@ -242,7 +242,7 @@ class BnB8bitBasicTests(Base8bitTests):
         """
         config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"])
         model_8bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=config
+            self.model_name, subfolder="transformer", quantization_config=config, device_map=torch_device
         )
         linear = get_some_linear_layer(model_8bit)
         self.assertTrue(linear.weight.dtype == torch.int8)
@@ -319,6 +319,7 @@ class Bnb8bitDeviceTests(Base8bitTests):
             "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
             subfolder="transformer",
             quantization_config=mixed_int8_config,
+            device_map=torch_device,
         )
     def tearDown(self):
@@ -343,7 +344,7 @@ class BnB8bitTrainingTests(Base8bitTests):
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         self.model_8bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
         )
     def test_training(self):
@@ -387,7 +388,7 @@ class SlowBnb8bitTests(Base8bitTests):
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         model_8bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
         )
         self.pipeline_8bit = DiffusionPipeline.from_pretrained(
             self.model_name, transformer=model_8bit, torch_dtype=torch.float16
@@ -415,7 +416,10 @@ class SlowBnb8bitTests(Base8bitTests):
     def test_model_cpu_offload_raises_warning(self):
         model_8bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
+            device_map=torch_device,
         )
         pipeline_8bit = DiffusionPipeline.from_pretrained(
             self.model_name, transformer=model_8bit, torch_dtype=torch.float16
@@ -430,7 +434,10 @@ class SlowBnb8bitTests(Base8bitTests):
     def test_moving_to_cpu_throws_warning(self):
         model_8bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
+            device_map=torch_device,
         )
         logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
         logger.setLevel(30)
@@ -483,6 +490,7 @@ class SlowBnb8bitTests(Base8bitTests):
             subfolder="transformer",
             quantization_config=transformer_8bit_config,
             torch_dtype=torch.float16,
+            device_map=torch_device,
         )
         text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
         text_encoder_3_8bit = T5EncoderModel.from_pretrained(
@@ -490,6 +498,7 @@ class SlowBnb8bitTests(Base8bitTests):
             subfolder="text_encoder_3",
             quantization_config=text_encoder_3_8bit_config,
             torch_dtype=torch.float16,
+            device_map=torch_device,
         )
         # CUDA device placement works.
         pipeline_8bit = DiffusionPipeline.from_pretrained(
@@ -504,6 +513,99 @@ class SlowBnb8bitTests(Base8bitTests):
         del pipeline_8bit
+    def test_device_map(self):
+        """
+        Test if the quantized model is working properly with "auto"
+        pu/disk offloading doesn't work with bnb.
+        """
+        def get_dummy_tensor_inputs(device=None, seed: int = 0):
+            batch_size = 1
+            num_latent_channels = 4
+            num_image_channels = 3
+            height = width = 4
+            sequence_length = 48
+            embedding_dim = 32
+            torch.manual_seed(seed)
+            hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
+                device, dtype=torch.bfloat16
+            )
+            torch.manual_seed(seed)
+            encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
+                device, dtype=torch.bfloat16
+            )
+            torch.manual_seed(seed)
+            pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+            torch.manual_seed(seed)
+            text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+            torch.manual_seed(seed)
+            image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+            timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
+            return {
+                "hidden_states": hidden_states,
+                "encoder_hidden_states": encoder_hidden_states,
+                "pooled_projections": pooled_prompt_embeds,
+                "txt_ids": text_ids,
+                "img_ids": image_ids,
+                "timestep": timestep,
+            }
+        inputs = get_dummy_tensor_inputs(torch_device)
+        expected_slice = np.array(
+            [
+                0.33789062,
+                -0.04736328,
+                -0.00256348,
+                -0.23144531,
+                -0.49804688,
+                0.4375,
+                -0.15429688,
+                -0.65234375,
+                0.44335938,
+            ]
+        )
+        # non sharded
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-pipe",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+        self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
+        output = quantized_model(**inputs)[0]
+        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+        # sharded
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-sharded",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+        self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
+        output = quantized_model(**inputs)[0]
+        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
 @require_transformers_version_greater("4.44.0")
 class SlowBnb8bitFluxTests(Base8bitTests):
@@ -579,7 +681,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
             load_in_8bit=True,
         )
         self.model_0 = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=quantization_config
+            self.model_name, subfolder="transformer", quantization_config=quantization_config, device_map=torch_device
         )
     def tearDown(self):
@@ -34,6 +34,7 @@ from diffusers.utils.testing_utils import (
     is_torch_available,
     is_torchao_available,
     nightly,
+    numpy_cosine_similarity_distance,
     require_torch,
     require_torch_gpu,
     require_torchao_version_greater_or_equal,
@@ -282,9 +283,6 @@ class TorchAoTest(unittest.TestCase):
         self.assertEqual(weight.quant_max, 15)
     def test_device_map(self):
-        # Note: We were not checking if the weight tensor's were AffineQuantizedTensor's before. If we did
-        # it would have errored out. Now, we do. So, device_map basically never worked with or without
-        # sharded checkpoints. This will need to be supported in the future (TODO(aryan))
         """
         Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
         The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
@@ -301,54 +299,73 @@ class TorchAoTest(unittest.TestCase):
         }
         device_maps = ["auto", custom_device_map_dict]
-        # inputs = self.get_dummy_tensor_inputs(torch_device)
-        # expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
+        inputs = self.get_dummy_tensor_inputs(torch_device)
+        # requires with different expected slices since models are different due to offload (we don't quantize modules offloaded to cpu/disk)
+        expected_slice_auto = np.array(
+            [
+                0.34179688,
+                -0.03613281,
+                0.01428223,
+                -0.22949219,
+                -0.49609375,
+                0.4375,
+                -0.1640625,
+                -0.66015625,
+                0.43164062,
+            ]
+        )
+        expected_slice_offload = np.array(
+            [0.34375, -0.03515625, 0.0123291, -0.22753906, -0.49414062, 0.4375, -0.16308594, -0.66015625, 0.43554688]
+        )
         for device_map in device_maps:
-            # device_map_to_compare = {"": 0} if device_map == "auto" else device_map
-            # Test non-sharded model - should work
-            with self.assertRaises(NotImplementedError):
-                with tempfile.TemporaryDirectory() as offload_folder:
-                    quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-                    _ = FluxTransformer2DModel.from_pretrained(
-                        "hf-internal-testing/tiny-flux-pipe",
-                        subfolder="transformer",
-                        quantization_config=quantization_config,
-                        device_map=device_map,
-                        torch_dtype=torch.bfloat16,
-                        offload_folder=offload_folder,
-                    )
-            # weight = quantized_model.transformer_blocks[0].ff.net[2].weight
-            # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-            # self.assertTrue(isinstance(weight, AffineQuantizedTensor))
-            # output = quantized_model(**inputs)[0]
-            # output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-            # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
-            # Test sharded model - should not work
-            with self.assertRaises(NotImplementedError):
-                with tempfile.TemporaryDirectory() as offload_folder:
-                    quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-                    _ = FluxTransformer2DModel.from_pretrained(
-                        "hf-internal-testing/tiny-flux-sharded",
-                        subfolder="transformer",
-                        quantization_config=quantization_config,
-                        device_map=device_map,
-                        torch_dtype=torch.bfloat16,
-                        offload_folder=offload_folder,
-                    )
-            # weight = quantized_model.transformer_blocks[0].ff.net[2].weight
-            # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-            # self.assertTrue(isinstance(weight, AffineQuantizedTensor))
-            # output = quantized_model(**inputs)[0]
-            # output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-            # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+            if device_map == "auto":
+                expected_slice = expected_slice_auto
+            else:
+                expected_slice = expected_slice_offload
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-pipe",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+                weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+                # Note that when performing cpu/disk offload, the offloaded weights are not quantized, only the weights on the gpu.
+                # This is not the case when the model are already quantized
+                if "transformer_blocks.0" in device_map:
+                    self.assertTrue(isinstance(weight, nn.Parameter))
+                else:
+                    self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+                self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-sharded",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+                weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+                if "transformer_blocks.0" in device_map:
+                    self.assertTrue(isinstance(weight, nn.Parameter))
+                else:
+                    self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+                self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
@@ -544,7 +561,7 @@ class TorchAoSerializationTest(unittest.TestCase):
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
         weight = quantized_model.transformer_blocks[0].ff.net[2].weight
         self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
-        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
     def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
         quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
@@ -564,7 +581,7 @@ class TorchAoSerializationTest(unittest.TestCase):
             loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)
             )
         )
-        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
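The torchao tests above switch from `np.allclose` to a cosine-similarity distance check. A standalone equivalent of that comparison, with the distance computed inline (the sample `output_slice` values are illustrative; the expected slice and the 1e-3 threshold are taken from the tests):

```python
# Standalone equivalent of the cosine-similarity distance check used in the updated tests.
import numpy as np


def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return 1.0 - float(similarity)


output_slice = np.array([0.345, -0.036, 0.012, -0.227, -0.494, 0.437, -0.163, -0.660, 0.435])  # illustrative
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])

assert cosine_similarity_distance(output_slice, expected_slice) < 1e-3
```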