Unverified Commit f5929e03 authored by Marc Sun, committed by GitHub

[FEAT] Model loading refactor (#10604)



* first draft model loading refactor

* revert name change

* fix bnb

* revert name

* fix dduf

* fix hunyuan

* style

* Update src/diffusers/models/model_loading_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* suggestions from reviews

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove safetensors check

* fix default value

* more fix from suggestions

* revert logic for single file

* style

* typing + fix couple of issues

* improve speed

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Aryan <aryan@huggingface.co>

* fp8 dtype

* add tests

* rename resolved_archive_file to resolved_model_file

* format

* map_location default cpu

* add utility function

* switch to smaller model + test inference

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* rm comment

* add log

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* add decorator

* cosine sim instead

* fix use_keep_in_fp32_modules

* comm

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
parent 6fe05b9b
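
For orientation (not part of the diff): a minimal usage sketch of what this refactor enables, mirroring the new tests added below. It assumes a diffusers install that includes this change, bitsandbytes, a CUDA device, and access to the tiny testing checkpoint used in those tests.

import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# 4-bit NF4 quantization config, as used in the new bnb tests below
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)

# device_map="auto" lets accelerate place modules on the available devices;
# sharded and non-sharded checkpoints go through the same refactored loading path.
model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(model.hf_device_map)  # e.g. {"": 0} on a single-GPU machine
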
......@@ -52,7 +52,7 @@ logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate import dispatch_model, init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
......@@ -366,19 +366,23 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules,
)
device_map = None
if is_accelerate_available():
param_device = torch.device(device) if device else torch.device("cpu")
named_buffers = model.named_buffers()
unexpected_keys = load_model_dict_into_meta(
empty_state_dict = model.state_dict()
unexpected_keys = [
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
]
device_map = {"": param_device}
load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
device=param_device,
device_map=device_map,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
named_buffers=named_buffers,
unexpected_keys=unexpected_keys,
)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
......@@ -400,4 +404,8 @@ class FromOriginalModelMixin:
model.eval()
if device_map is not None:
device_map_kwargs = {"device_map": device_map}
dispatch_model(model, **device_map_kwargs)
return model
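
Illustration (not part of the diff) of the dispatch step the hunk above adds at the end of single-file loading: build a trivial device_map and hand the model to accelerate's dispatch_model. A toy module stands in for the real model.

import torch
from accelerate import dispatch_model

# Toy stand-in for the loaded model
toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

# Single-file loads use a trivial map: every module on one device
device_map = {"": "cuda:0" if torch.cuda.is_available() else "cpu"}

# With a single target device, dispatch_model simply moves the model there;
# a multi-device map would additionally install the accelerate hooks.
toy = dispatch_model(toy, device_map=device_map)
print(next(toy.parameters()).device)
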
......@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
model.load_state_dict(diffusers_format_checkpoint, strict=False)
if torch_dtype is not None:
model.to(torch_dtype)
......@@ -2061,16 +2052,7 @@ def create_diffusers_t5_model_from_checkpoint(
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
model.load_state_dict(diffusers_format_checkpoint)
......
......@@ -20,13 +20,15 @@ import os
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
import safetensors
import torch
from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError
from ..quantizers import DiffusersQuantizer
from ..utils import (
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
......@@ -55,7 +57,7 @@ _CLASS_REMAPPING_DICT = {
if is_accelerate_available():
from accelerate import infer_auto_device_map
from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
# Adapted from `transformers` (see modeling_utils.py)
......@@ -132,17 +134,46 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class
def _check_archive_and_maybe_raise_error(checkpoint_file, format_list):
"""
Check format of the archive
"""
with safetensors.safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") not in format_list:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)
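
Usage sketch (not part of the diff) of the helper above, assuming it remains importable from diffusers.models.model_loading_utils; a small safetensors file is written locally so the metadata check can be exercised, and the file names are arbitrary.

import torch
from safetensors.torch import save_file

from diffusers.models.model_loading_utils import _check_archive_and_maybe_raise_error

# A file saved with format metadata "pt" passes the check
save_file({"weight": torch.zeros(2, 2)}, "demo.safetensors", metadata={"format": "pt"})
_check_archive_and_maybe_raise_error("demo.safetensors", format_list=["pt", "flax"])

# A file with an unexpected format would raise OSError
save_file({"weight": torch.zeros(2, 2)}, "bad.safetensors", metadata={"format": "np"})
# _check_archive_and_maybe_raise_error("bad.safetensors", format_list=["pt", "flax"])  # raises OSError
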
def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
"""
Find the device of param_name from the device_map.
"""
if device_map is None:
return "cpu"
else:
module_name = param_name
# find next higher level module that is defined in device_map:
# bert.lm_head.weight -> bert.lm_head -> bert -> ''
while len(module_name) > 0 and module_name not in device_map:
module_name = ".".join(module_name.split(".")[:-1])
if module_name == "" and "" not in device_map:
raise ValueError(f"{param_name} doesn't have any device set.")
return device_map[module_name]
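
Quick illustration (not part of the diff) of the resolution order implemented above; the module names in the example device_map are made up.

from diffusers.models.model_loading_utils import _determine_param_device

device_map = {"": "cpu", "transformer_blocks": 0, "transformer_blocks.0.ff": "disk"}

# The longest matching module prefix wins, falling back to the "" entry,
# and to "cpu" when no device_map is given at all.
print(_determine_param_device("transformer_blocks.0.ff.net.0.proj.weight", device_map))  # disk
print(_determine_param_device("transformer_blocks.1.attn.to_q.weight", device_map))      # 0
print(_determine_param_device("proj_out.weight", device_map))                            # cpu
print(_determine_param_device("proj_out.weight", None))                                  # cpu
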
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
variant: Optional[str] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
disable_mmap: bool = False,
map_location: Union[str, torch.device] = "cpu",
):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
# TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
# when refactoring the _merge_sharded_checkpoints() method later.
# TODO: maybe refactor a bit this part where we pass a dict here
if isinstance(checkpoint_file, dict):
return checkpoint_file
try:
......@@ -152,19 +183,26 @@ def load_state_dict(
# tensors are loaded on cpu
with dduf_entries[checkpoint_file].as_mmap() as mm:
return safetensors.torch.load(mm)
_check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"])
if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
return safetensors.torch.load_file(checkpoint_file, device=map_location)
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
extra_args = {}
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
checkpoint_file,
map_location="cpu",
**weights_only_kwarg,
)
# mmap can only be used with files serialized with zipfile-based format.
if (
isinstance(checkpoint_file, str)
and map_location != "meta"
and is_torch_version(">=", "2.1.0")
and is_zipfile(checkpoint_file)
and not disable_mmap
):
extra_args = {"mmap": True}
return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
except Exception as e:
try:
with open(checkpoint_file) as f:
......@@ -188,23 +226,24 @@ def load_state_dict(
def load_model_dict_into_meta(
model,
state_dict: OrderedDict,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
hf_quantizer=None,
keep_in_fp32_modules=None,
named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
hf_quantizer: Optional[DiffusersQuantizer] = None,
keep_in_fp32_modules: Optional[List] = None,
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
unexpected_keys: Optional[List[str]] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
offload_index: Optional[Dict] = None,
state_dict_index: Optional[Dict] = None,
state_dict_folder: Optional[Union[str, os.PathLike]] = None,
) -> List[str]:
if device is not None and not isinstance(device, (str, torch.device)):
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
if hf_quantizer is None:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
params on a `meta` device. It replaces the model params with the data from the `state_dict`
"""
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
is_quantized = hf_quantizer is not None
empty_state_dict = model.state_dict()
unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
for param_name, param in state_dict.items():
if param_name not in empty_state_dict:
......@@ -214,29 +253,45 @@ def load_model_dict_into_meta(
# We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
# TODO: revisit cases when param.dtype == torch.float8_e4m3fn
if torch.is_floating_point(param):
if (
keep_in_fp32_modules is not None
and any(
if dtype is not None and torch.is_floating_point(param):
if keep_in_fp32_modules is not None and any(
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
)
and dtype == torch.float16
):
param = param.to(torch.float32)
if accepts_dtype:
set_module_kwargs["dtype"] = torch.float32
else:
param = param.to(dtype)
if accepts_dtype:
set_module_kwargs["dtype"] = dtype
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
if old_param.is_contiguous():
param = param.contiguous()
param_device = _determine_param_device(param_name, device_map)
# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
if empty_state_dict[param_name].shape != param.shape:
if (
is_quantized
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
and hf_quantizer.check_if_quantized_param(
model, param, param_name, state_dict, param_device=param_device
)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
else:
......@@ -244,35 +299,23 @@ def load_model_dict_into_meta(
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
if param_device == "disk":
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
if named_buffers is None:
return unexpected_keys
for param_name, param in named_buffers:
if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
return unexpected_keys
return offload_index, state_dict_index
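
Self-contained sketch (not part of the diff) of the new calling convention: a meta-initialized toy module plus an explicit device_map, using the keyword names from the signature above. The toy class is hypothetical.

import torch
from accelerate import init_empty_weights

from diffusers.models.model_loading_utils import load_model_dict_into_meta


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(4, 8)
        self.proj_out = torch.nn.Linear(8, 4)


with init_empty_weights():
    model = Toy()  # parameters are allocated on the "meta" device, no memory is used

state_dict = Toy().state_dict()  # stands in for a checkpoint read via load_state_dict(...)
device_map = {"": "cpu"}  # could also be e.g. {"proj_in": 0, "proj_out": "cpu"}

# Materializes the meta parameters on the devices named in device_map
load_model_dict_into_meta(
    model,
    state_dict,
    dtype=torch.float32,
    device_map=device_map,
)
print(model.proj_out.weight.device)  # cpu
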
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()
......@@ -280,15 +323,19 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: torch.nn.Module, prefix: str = ""):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
local_metadata = {}
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
if assign_to_params_buffers and not is_torch_version(">=", "2.1"):
logger.info("You need to have torch>=2.1 in order to load the model with assign_to_params_buffers=True")
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(child, prefix + name + ".", assign_to_params_buffers)
load(model_to_load)
load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)
return error_msgs
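
Tiny illustration (not part of the diff) of what the new assign_to_params_buffers flag buys on PyTorch >= 2.1: checkpoint tensors are assigned rather than copied. The public load_state_dict(assign=True) API shown here relies on the same metadata flag internally.

import torch

m = torch.nn.Linear(2, 2)
sd = {"weight": torch.ones(2, 2), "bias": torch.zeros(2)}

# assign=True reuses the state-dict tensors as the module parameters (no extra copy)
m.load_state_dict(sd, assign=True)
print(m.weight.data_ptr() == sd["weight"].data_ptr())  # True: same underlying storage
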
......@@ -343,46 +390,6 @@ def _fetch_index_file(
return index_file
# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
weight_map = sharded_metadata.get("weight_map", None)
if weight_map is None:
raise KeyError("'weight_map' key not found in the shard index file.")
# Collect all unique safetensors files from weight_map
files_to_load = set(weight_map.values())
is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
merged_state_dict = {}
# Load tensors from each unique file
for file_name in files_to_load:
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
if dduf_entries:
if part_file_path not in dduf_entries:
raise FileNotFoundError(f"Part file {file_name} not found.")
else:
if not os.path.exists(part_file_path):
raise FileNotFoundError(f"Part file {file_name} not found.")
if is_safetensors:
if dduf_entries:
with dduf_entries[part_file_path].as_mmap() as mm:
tensors = safetensors.torch.load(mm)
merged_state_dict.update(tensors)
else:
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
else:
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
return merged_state_dict
def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,
......
......@@ -280,9 +280,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
act_fn="silu_fp32",
)
self.text_embedding_padding = nn.Parameter(
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
)
self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim))
self.pos_embed = PatchEmbed(
height=sample_size,
......
......@@ -693,7 +693,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
offload_state_dict = kwargs.pop("offload_state_dict", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
dduf_file = kwargs.pop("dduf_file", None)
......
......@@ -235,18 +235,16 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
torch_dtype = torch.float16
return torch_dtype
# (sayakpaul): I think it could be better to disable custom `device_map`s
# for the first phase of the integration in the interest of simplicity.
# Commenting this for discussions on the PR.
# def update_device_map(self, device_map):
# if device_map is None:
# device_map = {"": torch.cuda.current_device()}
# logger.info(
# "The device_map was not initialized. "
# "Setting device_map to {'':torch.cuda.current_device()}. "
# "If you want to use the model for inference, please set device_map ='auto' "
# )
# return device_map
def update_device_map(self, device_map):
if device_map is None:
device_map = {"": f"cuda:{torch.cuda.current_device()}"}
logger.info(
"The device_map was not initialized. "
f"Setting device_map to {device_map}. "
"If you want to use the model for inference, please set device_map='auto'."
)
return device_map
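
Behavior sketch (not part of the diff) of the now-enabled default above. It assumes bitsandbytes and a CUDA device are available and that the quantizer class is importable from diffusers.quantizers.bitsandbytes.

import torch
from diffusers import BitsAndBytesConfig
from diffusers.quantizers.bitsandbytes import BnB4BitDiffusersQuantizer

quantizer = BnB4BitDiffusersQuantizer(BitsAndBytesConfig(load_in_4bit=True))

# With no user-provided map, everything is pinned to the current CUDA device
print(quantizer.update_device_map(None))         # {'': 'cuda:0'} on a single-GPU machine
# A user-provided map is passed through unchanged
print(quantizer.update_device_map({"": "cpu"}))  # {'': 'cpu'}
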
def _process_model_before_weight_loading(
self,
......@@ -289,9 +287,9 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
model.config.quantization_config = self.quantization_config
model.is_loaded_in_4bit = True
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
model.is_loaded_in_4bit = True
model.is_4bit_serializable = self.is_serializable
return model
......@@ -400,16 +398,17 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
torch_dtype = torch.float16
return torch_dtype
# # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
# def update_device_map(self, device_map):
# if device_map is None:
# device_map = {"": torch.cuda.current_device()}
# logger.info(
# "The device_map was not initialized. "
# "Setting device_map to {'':torch.cuda.current_device()}. "
# "If you want to use the model for inference, please set device_map ='auto' "
# )
# return device_map
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
def update_device_map(self, device_map):
if device_map is None:
device_map = {"": f"cuda:{torch.cuda.current_device()}"}
logger.info(
"The device_map was not initialized. "
f"Setting device_map to {device_map}. "
"If you want to use the model for inference, please set device_map='auto'."
)
return device_map
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
if target_dtype != torch.int8:
......@@ -493,11 +492,10 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
model.is_loaded_in_8bit = True
model.is_8bit_serializable = self.is_serializable
return model
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading with 4bit->8bit
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
......@@ -539,6 +537,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
model.config.quantization_config = self.quantization_config
model.is_loaded_in_8bit = True
@property
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable
......
......@@ -338,22 +338,6 @@ def _get_model_file(
) from e
# Adapted from
# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976
# Differences are in parallelization of shard downloads and checking if shards are present.
def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames):
shards_path = os.path.join(local_dir, subfolder)
shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
for shard_file in shard_filenames:
if not os.path.exists(shard_file):
raise ValueError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
def _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_filename,
......@@ -396,13 +380,22 @@ def _get_checkpoint_shard_files(
shards_path = os.path.join(pretrained_model_name_or_path, subfolder)
# First, let's deal with local folder.
if os.path.isdir(pretrained_model_name_or_path):
_check_if_shards_exist_locally(
pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
if os.path.isdir(pretrained_model_name_or_path) or dduf_entries:
shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
for shard_file in shard_filenames:
if dduf_entries:
if shard_file not in dduf_entries:
raise FileNotFoundError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
return shards_path, sharded_metadata
elif dduf_entries:
return shards_path, sharded_metadata
else:
if not os.path.exists(shard_file):
raise FileNotFoundError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
return shard_filenames, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns = original_shard_filenames
......@@ -444,7 +437,9 @@ def _get_checkpoint_shard_files(
" again after checking your internet connection."
) from e
return cached_folder, sharded_metadata
cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
return cached_filenames, sharded_metadata
def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
......
......@@ -37,7 +37,7 @@ from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
from requests.exceptions import HTTPError
from diffusers.models import UNet2DConditionModel
from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
......@@ -200,12 +200,12 @@ class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
def test_accelerate_loading_error_message(self):
with self.assertRaises(ValueError) as error_context:
def test_missing_key_loading_warning_message(self):
with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
# make sure that the warning message states which keys are missing
assert "conv_out.bias" in str(error_context.exception)
assert "conv_out.bias" in " ".join(logs.output)
@parameterized.expand(
[
......@@ -334,6 +334,58 @@ class ModelUtilsTest(unittest.TestCase):
assert model.config.in_channels == 9
@require_torch_gpu
def test_keep_modules_in_fp32(self):
r"""
A simple test to check that the modules listed in `_keep_in_fp32_modules` are kept in fp32 when the model is loaded in fp16/bf16.
Also ensures that inference works.
"""
fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
for torch_dtype in [torch.bfloat16, torch.float16]:
SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
model = SD3Transformer2DModel.from_pretrained(
"hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch_dtype
).to(torch_device)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if name in model._keep_in_fp32_modules:
self.assertTrue(module.weight.dtype == torch.float32)
else:
self.assertTrue(module.weight.dtype == torch_dtype)
def get_dummy_inputs():
batch_size = 2
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}
# test if inference works.
with torch.no_grad(), torch.amp.autocast(torch_device, dtype=torch_dtype):
input_dict_for_transformer = get_dummy_inputs()
model_inputs = {
k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
}
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
_ = model(**model_inputs)
SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
class UNetTesterMixin:
def test_forward_with_norm_groups(self):
......
......@@ -136,7 +136,7 @@ class BnB4BitBasicTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
self.model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
def tearDown(self):
......@@ -202,7 +202,7 @@ class BnB4BitBasicTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
model = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
for name, module in model.named_modules():
......@@ -327,7 +327,7 @@ class BnB4BitBasicTests(Base4bitTests):
with tempfile.TemporaryDirectory() as tmpdirname:
nf4_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
model_4bit.save_pretrained(tmpdirname)
del model_4bit
......@@ -362,7 +362,7 @@ class BnB4BitTrainingTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
self.model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
def test_training(self):
......@@ -410,7 +410,7 @@ class SlowBnb4BitTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
self.pipeline_4bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_4bit, torch_dtype=torch.float16
......@@ -472,7 +472,7 @@ class SlowBnb4BitTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
......@@ -502,6 +502,7 @@ class SlowBnb4BitTests(Base4bitTests):
subfolder="transformer",
quantization_config=transformer_nf4_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
text_encoder_3_nf4_config = BnbConfig(
load_in_4bit=True,
......@@ -513,6 +514,7 @@ class SlowBnb4BitTests(Base4bitTests):
subfolder="text_encoder_3",
quantization_config=text_encoder_3_nf4_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
# CUDA device placement works.
pipeline_4bit = DiffusionPipeline.from_pretrained(
......@@ -527,6 +529,94 @@ class SlowBnb4BitTests(Base4bitTests):
del pipeline_4bit
def test_device_map(self):
"""
Test that the quantized model works properly with device_map="auto".
Note that cpu/disk offloading does not work with bnb.
"""
def get_dummy_tensor_inputs(device=None, seed: int = 0):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"txt_ids": text_ids,
"img_ids": image_ids,
"timestep": timestep,
}
inputs = get_dummy_tensor_inputs(torch_device)
expected_slice = np.array(
[0.47070312, 0.00390625, -0.03662109, -0.19628906, -0.53125, 0.5234375, -0.17089844, -0.59375, 0.578125]
)
# non sharded
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
# sharded
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):
......@@ -610,7 +700,10 @@ class BaseBnb4BitSerializationTests(Base4bitTests):
bnb_4bit_compute_dtype=torch.bfloat16,
)
model_0 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=self.quantization_config
self.model_name,
subfolder="transformer",
quantization_config=self.quantization_config,
device_map=torch_device,
)
self.assertTrue("_pre_quantization_dtype" in model_0.config)
with tempfile.TemporaryDirectory() as tmpdirname:
......
......@@ -138,7 +138,7 @@ class BnB8bitBasicTests(Base8bitTests):
)
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
def tearDown(self):
......@@ -200,7 +200,7 @@ class BnB8bitBasicTests(Base8bitTests):
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
for name, module in model.named_modules():
......@@ -242,7 +242,7 @@ class BnB8bitBasicTests(Base8bitTests):
"""
config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"])
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=config
self.model_name, subfolder="transformer", quantization_config=config, device_map=torch_device
)
linear = get_some_linear_layer(model_8bit)
self.assertTrue(linear.weight.dtype == torch.int8)
......@@ -319,6 +319,7 @@ class Bnb8bitDeviceTests(Base8bitTests):
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
subfolder="transformer",
quantization_config=mixed_int8_config,
device_map=torch_device,
)
def tearDown(self):
......@@ -343,7 +344,7 @@ class BnB8bitTrainingTests(Base8bitTests):
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
def test_training(self):
......@@ -387,7 +388,7 @@ class SlowBnb8bitTests(Base8bitTests):
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
self.pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_8bit, torch_dtype=torch.float16
......@@ -415,7 +416,10 @@ class SlowBnb8bitTests(Base8bitTests):
def test_model_cpu_offload_raises_warning(self):
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
self.model_name,
subfolder="transformer",
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map=torch_device,
)
pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_8bit, torch_dtype=torch.float16
......@@ -430,7 +434,10 @@ class SlowBnb8bitTests(Base8bitTests):
def test_moving_to_cpu_throws_warning(self):
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
self.model_name,
subfolder="transformer",
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map=torch_device,
)
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(30)
......@@ -483,6 +490,7 @@ class SlowBnb8bitTests(Base8bitTests):
subfolder="transformer",
quantization_config=transformer_8bit_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
text_encoder_3_8bit = T5EncoderModel.from_pretrained(
......@@ -490,6 +498,7 @@ class SlowBnb8bitTests(Base8bitTests):
subfolder="text_encoder_3",
quantization_config=text_encoder_3_8bit_config,
torch_dtype=torch.float16,
device_map=torch_device,
)
# CUDA device placement works.
pipeline_8bit = DiffusionPipeline.from_pretrained(
......@@ -504,6 +513,99 @@ class SlowBnb8bitTests(Base8bitTests):
del pipeline_8bit
def test_device_map(self):
"""
Test that the quantized model works properly with device_map="auto".
Note that cpu/disk offloading does not work with bnb.
"""
def get_dummy_tensor_inputs(device=None, seed: int = 0):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"txt_ids": text_ids,
"img_ids": image_ids,
"timestep": timestep,
}
inputs = get_dummy_tensor_inputs(torch_device)
expected_slice = np.array(
[
0.33789062,
-0.04736328,
-0.00256348,
-0.23144531,
-0.49804688,
0.4375,
-0.15429688,
-0.65234375,
0.44335938,
]
)
# non sharded
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
# sharded
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):
......@@ -579,7 +681,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
load_in_8bit=True,
)
self.model_0 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=quantization_config
self.model_name, subfolder="transformer", quantization_config=quantization_config, device_map=torch_device
)
def tearDown(self):
......
......@@ -34,6 +34,7 @@ from diffusers.utils.testing_utils import (
is_torch_available,
is_torchao_available,
nightly,
numpy_cosine_similarity_distance,
require_torch,
require_torch_gpu,
require_torchao_version_greater_or_equal,
......@@ -282,9 +283,6 @@ class TorchAoTest(unittest.TestCase):
self.assertEqual(weight.quant_max, 15)
def test_device_map(self):
# Note: We were not checking if the weight tensor's were AffineQuantizedTensor's before. If we did
# it would have errored out. Now, we do. So, device_map basically never worked with or without
# sharded checkpoints. This will need to be supported in the future (TODO(aryan))
"""
Test that the int4 weight-only quantized model works properly with "auto" and custom device maps.
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
......@@ -301,17 +299,32 @@ class TorchAoTest(unittest.TestCase):
}
device_maps = ["auto", custom_device_map_dict]
# inputs = self.get_dummy_tensor_inputs(torch_device)
# expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
inputs = self.get_dummy_tensor_inputs(torch_device)
# Different expected slices are required since the models differ under offload (modules offloaded to cpu/disk are not quantized)
expected_slice_auto = np.array(
[
0.34179688,
-0.03613281,
0.01428223,
-0.22949219,
-0.49609375,
0.4375,
-0.1640625,
-0.66015625,
0.43164062,
]
)
expected_slice_offload = np.array(
[0.34375, -0.03515625, 0.0123291, -0.22753906, -0.49414062, 0.4375, -0.16308594, -0.66015625, 0.43554688]
)
for device_map in device_maps:
# device_map_to_compare = {"": 0} if device_map == "auto" else device_map
# Test non-sharded model - should work
with self.assertRaises(NotImplementedError):
if device_map == "auto":
expected_slice = expected_slice_auto
else:
expected_slice = expected_slice_offload
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
_ = FluxTransformer2DModel.from_pretrained(
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
......@@ -320,19 +333,22 @@ class TorchAoTest(unittest.TestCase):
offload_folder=offload_folder,
)
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# output = quantized_model(**inputs)[0]
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
# Note that when performing cpu/disk offload, the offloaded weights are not quantized, only the weights on the gpu.
# This is not the case when the model is already quantized.
if "transformer_blocks.0" in device_map:
self.assertTrue(isinstance(weight, nn.Parameter))
else:
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
# Test sharded model - should not work
with self.assertRaises(NotImplementedError):
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
_ = FluxTransformer2DModel.from_pretrained(
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
......@@ -341,14 +357,15 @@ class TorchAoTest(unittest.TestCase):
offload_folder=offload_folder,
)
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
# output = quantized_model(**inputs)[0]
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
if "transformer_blocks.0" in device_map:
self.assertTrue(isinstance(weight, nn.Parameter))
else:
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
......@@ -544,7 +561,7 @@ class TorchAoSerializationTest(unittest.TestCase):
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
......@@ -564,7 +581,7 @@ class TorchAoSerializationTest(unittest.TestCase):
loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)
)
)
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
def test_int_a8w8_cuda(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
......