Unverified Commit 7d887118 authored by Sayak Paul, committed by GitHub

[Core] support saving and loading of sharded checkpoints (#7830)



* feat: support saving a model in sharded checkpoints.

* feat: make loading of sharded checkpoints work.

* add tests

* cleanse the loading logic a bit more.

* more resilience while loading from the Hub.

* parallelize shard downloads by using snapshot_download().

* default to a shard size.

* more fix

* Empty-Commit

* debug

* fix

* quality

* more debugging

* fix more

* initial comments from Benjamin

* move certain methods to loading_utils

* add test to check if the correct number of shards are present.

* add a test to check if loading of sharded checkpoints from the Hub is okay

* clarify the unit when passed as an int.

* use hf_hub for sharding.

* remove unnecessary code

* remove unnecessary function

* Lucain's comments.

* fixes

* address high-level comments.

* fix test

* subfolder shenanigans.

* Update src/diffusers/utils/hub_utils.py
Co-authored-by: Lucain <lucainp@gmail.com>

* Apply suggestions from code review
Co-authored-by: Lucain <lucainp@gmail.com>

* remove _huggingface_hub_version as not needed.

* address more feedback.

* add a test for local_files_only=True.

* need hf hub to be at least 0.23.2

* style

* final comment.

* clean up subfolder.

* deal with suffixes in code.

* _add_variant default.

* use weights_name_pattern

* remove add_suffix_keyword

* clean up downloading of sharded ckpts.

* don't return something special when using index.json

* fix more

* don't use bare except

* remove comments and catch the errors better

* fix a couple of things when using is_file()

* empty

---------
Co-authored-by: Lucain <lucainp@gmail.com>
parent b63c9568
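At a glance, the feature round-trips like this. A minimal sketch, assuming a diffusers build that includes this PR plus `huggingface_hub>=0.23.2`; the checkpoint name and output directory are illustrative:

```python
# Minimal sketch of the sharded save/load round trip added by this PR.
# Assumes a diffusers build with this change and huggingface_hub>=0.23.2;
# the checkpoint name and output directory are illustrative.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Every shard stays below 2GB; an index file
# (diffusion_pytorch_model.safetensors.index.json) is written next to them.
unet.save_pretrained("./sharded-unet", max_shard_size="2GB")

# Loading detects the index file and fetches/loads every shard it lists.
reloaded = UNet2DConditionModel.from_pretrained("./sharded-unet")
```

On the loading side, `from_pretrained` falls back to the single-file path whenever no `*.index.json` is found, so existing checkpoints keep working unchanged.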
......@@ -101,7 +101,7 @@ _deps = [
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.20.2",
"huggingface-hub>=0.23.2",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
......
......@@ -9,7 +9,7 @@ deps = {
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.20.2",
"huggingface-hub": "huggingface-hub>=0.23.2",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
......
......@@ -18,13 +18,19 @@ import importlib
import inspect
import os
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union
import safetensors
import torch
from huggingface_hub.utils import EntryNotFoundError
from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
WEIGHTS_INDEX_NAME,
_add_variant,
_get_model_file,
is_accelerate_available,
is_torch_version,
logging,
......@@ -175,3 +181,52 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
load(model_to_load)
return error_msgs
def _fetch_index_file(
is_local,
pretrained_model_name_or_path,
subfolder,
use_safetensors,
cache_dir,
variant,
force_download,
resume_download,
proxies,
local_files_only,
token,
revision,
user_agent,
commit_hash,
):
if is_local:
index_file = Path(
pretrained_model_name_or_path,
subfolder or "",
_add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
)
else:
index_file_in_repo = Path(
subfolder or "",
_add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
).as_posix()
try:
index_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=index_file_in_repo,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
)
index_file = Path(index_file)
except (EntryNotFoundError, EnvironmentError):
index_file = None
return index_file
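Worth noting: a missing index is how "not sharded" is detected; both `EntryNotFoundError` and `EnvironmentError` collapse to `index_file = None`. The probed filename combines the index constant with the variant. A hypothetical sketch of that naming (the helper below only approximates `_add_variant`; it is not the library function):

```python
# Hypothetical approximation of how the probed index filename is derived;
# the real logic lives in diffusers' _add_variant helper.
from typing import Optional

SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"

def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    if variant is None:
        return weights_name
    base, *rest = weights_name.split(".")
    return ".".join([base, variant, *rest])

print(add_variant(SAFE_WEIGHTS_INDEX_NAME))          # diffusion_pytorch_model.safetensors.index.json
print(add_variant(SAFE_WEIGHTS_INDEX_NAME, "fp16"))  # diffusion_pytorch_model.fp16.safetensors.index.json
```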
......@@ -16,6 +16,7 @@
import inspect
import itertools
import json
import os
import re
from collections import OrderedDict
......@@ -25,7 +26,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import safetensors
import torch
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn
......@@ -33,9 +34,12 @@ from .. import __version__
from ..utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
_add_variant,
_get_checkpoint_shard_files,
_get_model_file,
deprecate,
is_accelerate_available,
......@@ -49,6 +53,7 @@ from ..utils.hub_utils import (
)
from .model_loading_utils import (
_determine_device_map,
_fetch_index_file,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
......@@ -57,6 +62,8 @@ from .model_loading_utils import (
logger = logging.get_logger(__name__)
_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
if is_torch_version(">=", "1.9.0"):
_LOW_CPU_MEM_USAGE_DEFAULT = True
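`_REGEX_SHARD` is used with `fullmatch()` in `save_pretrained` below to recognize shard files left over from a previous save so they can be deleted. A quick illustration of what it accepts:

```python
import re

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")

# save_pretrained uses fullmatch(), so only exact shard stems qualify.
print(bool(_REGEX_SHARD.fullmatch("diffusion_pytorch_model-00001-of-00002")))  # True
print(bool(_REGEX_SHARD.fullmatch("diffusion_pytorch_model")))                 # False
print(bool(_REGEX_SHARD.fullmatch("diffusion_pytorch_model-1-of-2")))          # False: five digits required
```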
......@@ -263,6 +270,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
save_function: Optional[Callable] = None,
safe_serialization: bool = True,
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "5GB",
push_to_hub: bool = False,
**kwargs,
):
......@@ -285,6 +293,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
max_shard_size (`int` or `str`, defaults to `"5GB"`):
    The maximum size for a checkpoint before it is sharded. Each checkpoint shard will then be smaller
    than this size. If expressed as a string, it needs to be digits followed by a unit (like `"5GB"`).
    If expressed as an integer, the unit is bytes.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
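Concretely, the two `max_shard_size` forms documented above can be used like this. A sketch, with `unet` standing in for any `ModelMixin` instance; the SI reading of the string form follows `huggingface_hub`'s size parsing:

```python
# Equivalent ways to cap shards at roughly 200MB (sketch; `unet` is any
# ModelMixin instance). huggingface_hub parses the string form with SI
# units, so "200MB" corresponds to 200 * 10**6 bytes.
unet.save_pretrained("./out-str", max_shard_size="200MB")
unet.save_pretrained("./out-int", max_shard_size=200 * 10**6)
```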
......@@ -296,6 +308,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weight_name_split = weights_name.split(".")
if len(weight_name_split) in [2, 3]:
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
else:
raise ValueError(f"Invalid {weights_name} provided.")
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
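Worked example of the `weights_name_pattern` construction above: for an `fp16` variant, the `{suffix}` placeholder lands between the base name and the extensions, which is where `split_torch_state_dict_into_shards` later injects the `-00001-of-00002`-style counter (or nothing, when a single file suffices):

```python
# Worked example of the weights_name_pattern construction for variant="fp16".
weights_name = "diffusion_pytorch_model.fp16.safetensors"
weight_name_split = weights_name.split(".")  # three parts: base, variant, extension
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
print(weights_name_pattern)
# diffusion_pytorch_model{suffix}.fp16.safetensors
print(weights_name_pattern.format(suffix="-00001-of-00002"))
# diffusion_pytorch_model-00001-of-00002.fp16.safetensors
```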
......@@ -317,18 +337,58 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Save the model
state_dict = model_to_save.state_dict()
-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)

        # Save the model
-        if safe_serialization:
-            safetensors.torch.save_file(
-                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
-            )
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+        )
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that the file to be deleted matches the format of a sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor] for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
    f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
    f"split into {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each "
    f"parameter has been saved in the index located at {save_index_file}."
)
        else:
-            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
-
-        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
if push_to_hub:
# Create a new empty model card and eventually tag it
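For reference, the index written above has roughly this shape (a hypothetical two-shard example rendered as a Python dict; real tensor names and sizes depend on the model):

```python
# Hypothetical shape of diffusion_pytorch_model.safetensors.index.json for a
# model split into two shards; tensor names and sizes are illustrative.
index = {
    "metadata": {"total_size": 4_980_736},
    "weight_map": {
        "conv_in.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "conv_in.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "conv_out.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
        "conv_out.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
    },
}
```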
......@@ -566,6 +626,32 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
**kwargs,
)
# Determine if we're loading from a directory of sharded checkpoints.
is_sharded = False
index_file = None
is_local = os.path.isdir(pretrained_model_name_or_path)
index_file = _fetch_index_file(
is_local=is_local,
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder or "",
use_safetensors=use_safetensors,
cache_dir=cache_dir,
variant=variant,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
user_agent=user_agent,
commit_hash=commit_hash,
)
if index_file is not None and index_file.is_file():
is_sharded = True
if is_sharded and from_flax:
raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
# load model
model_file = None
if from_flax:
......@@ -590,7 +676,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
else:
-            if use_safetensors:
+            if is_sharded:
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_file,
cache_dir=cache_dir,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder or "",
)
elif use_safetensors and not is_sharded:
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
......@@ -606,11 +706,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
user_agent=user_agent,
commit_hash=commit_hash,
)
                except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                    if not allow_pickle:
-                        raise e
-                    pass
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )

-            if model_file is None:
+            if model_file is None and not is_sharded:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(WEIGHTS_NAME, variant),
......@@ -632,7 +737,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
model = cls.from_config(config, **unused_kwargs)
# if device_map is None, load the state dict and move the params from meta device to the cpu
-        if device_map is None:
+        if device_map is None and not is_sharded:
param_device = "cpu"
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
......@@ -670,7 +775,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
try:
accelerate.load_checkpoint_and_dispatch(
model,
-                    model_file,
+                    model_file if not is_sharded else sharded_ckpt_cached_folder,
device_map,
max_memory=max_memory,
offload_folder=offload_folder,
......
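accelerate's `load_checkpoint_and_dispatch` accepts either a single weights file or a folder containing shards plus an index, which is why the sharded branch above passes `sharded_ckpt_cached_folder`. A standalone sketch of that pattern; `MyModel` and `config` are hypothetical placeholders:

```python
# Sketch of the accelerate-side loading path used for sharded checkpoints;
# MyModel and config are hypothetical placeholders.
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

with init_empty_weights():
    model = MyModel(**config)  # parameters live on the meta device, no RAM used

model = load_checkpoint_and_dispatch(
    model,
    "./sharded-unet",  # folder containing the shards and the *.index.json
    device_map="auto",
)
```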
......@@ -28,9 +28,11 @@ from .constants import (
MIN_PEFT_VERSION,
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
SAFETENSORS_WEIGHTS_NAME,
USE_PEFT_BACKEND,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
......@@ -40,6 +42,7 @@ from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to
from .hub_utils import (
PushToHubMixin,
_add_variant,
_get_checkpoint_shard_files,
_get_model_file,
extract_commit_hash,
http_user_agent,
......
......@@ -28,9 +28,11 @@ _CHECK_PEFT = os.environ.get("_CHECK_PEFT", "1") in ENV_VARS_TRUE_VALUES
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.bin.index.json"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
SAFETENSORS_FILE_EXTENSION = "safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
......
......@@ -14,6 +14,7 @@
# limitations under the License.
import json
import os
import re
import sys
......@@ -29,6 +30,8 @@ from huggingface_hub import (
ModelCardData,
create_repo,
hf_hub_download,
model_info,
snapshot_download,
upload_folder,
)
from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
......@@ -393,6 +396,109 @@ def _get_model_file(
)
# Adapted from
# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976
# Differences are in parallelization of shard downloads and checking if shards are present.
def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames):
shards_path = os.path.join(local_dir, subfolder)
shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
for shard_file in shard_filenames:
if not os.path.exists(shard_file):
raise ValueError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
def _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_filename,
cache_dir=None,
proxies=None,
resume_download=False,
local_files_only=False,
token=None,
user_agent=None,
revision=None,
subfolder="",
):
"""
For a given model:

- downloads and caches all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on
  the Hub
- returns the path to the cached folder containing the shards, as well as some metadata

For the description of each argument, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path
to the index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
"""
if not os.path.isfile(index_filename):
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
with open(index_filename, "r") as f:
index = json.loads(f.read())
original_shard_filenames = sorted(set(index["weight_map"].values()))
sharded_metadata = index["metadata"]
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
sharded_metadata["weight_map"] = index["weight_map"].copy()
shards_path = os.path.join(pretrained_model_name_or_path, subfolder)
# First, let's deal with local folder.
if os.path.isdir(pretrained_model_name_or_path):
_check_if_shards_exist_locally(
pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
)
return pretrained_model_name_or_path, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns = original_shard_filenames
ignore_patterns = ["*.json", "*.md"]
if not local_files_only:
# The `model_info` call must be guarded by the above condition.
model_files_info = model_info(pretrained_model_name_or_path)
for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if not shard_file_present:
raise EnvironmentError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
try:
# Load from URL
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except HTTPError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
if local_files_only:
_check_if_shards_exist_locally(
local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
)
return cached_folder, sharded_metadata
class PushToHubMixin:
"""
A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.
......
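The `snapshot_download` call above is where the "parallelize shard downloads" commit lands: limiting `allow_patterns` to the shard filenames from the index fetches exactly those files, using the Hub client's concurrent downloader. A roughly equivalent standalone sketch (the repo is the test dummy used later in this diff; shard names are illustrative):

```python
# Sketch: fetch only the shards an index references, nothing else.
# Repo is the test dummy used later in this diff; shard names are illustrative.
from huggingface_hub import snapshot_download

cached_folder = snapshot_download(
    "hf-internal-testing/unet2d-sharded-dummy",
    allow_patterns=[
        "diffusion_pytorch_model-00001-of-00002.safetensors",
        "diffusion_pytorch_model-00002-of-00002.safetensors",
    ],
    ignore_patterns=["*.json", "*.md"],
)
print(cached_folder)  # local snapshot directory containing the downloaded shards
```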
......@@ -131,7 +131,6 @@ try:
except importlib_metadata.PackageNotFoundError:
_unidecode_available = False
_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
......
......@@ -14,6 +14,8 @@
# limitations under the License.
import inspect
import json
import os
import tempfile
import traceback
import unittest
......@@ -37,7 +39,7 @@ from diffusers.models.attention_processor import (
XFormersAttnProcessor,
)
from diffusers.training_utils import EMAModel
-from diffusers.utils import is_torch_npu_available, is_xformers_available, logging
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
get_python_version,
......@@ -129,7 +131,9 @@ class ModelUtilsTest(unittest.TestCase):
)
download_requests = [r.method for r in m.request_history]
-        assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
+        assert (
+            download_requests.count("HEAD") == 3
+        ), "3 HEAD requests: one for the config, one for the model, and one for the shard index file."
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m:
......@@ -142,8 +146,8 @@ class ModelUtilsTest(unittest.TestCase):
cache_requests = [r.method for r in m.request_history]
        assert (
-            "HEAD" == cache_requests[0] and len(cache_requests) == 1
-        ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+            "HEAD" == cache_requests[0] and len(cache_requests) == 2
+        ), "We should call only `model_info` to check for the commit hash and whether the shard index is present."
def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
......@@ -866,6 +870,41 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@require_torch_gpu
def test_sharded_checkpoints(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
if model._no_split_modules is None:
return
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards)
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
......
......@@ -21,6 +21,7 @@ import unittest
from collections import OrderedDict
import torch
from huggingface_hub import snapshot_download
from parameterized import parameterized
from pytest import mark
......@@ -1034,6 +1035,25 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@require_torch_gpu
def test_load_sharded_checkpoint_from_hub(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy", device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu
def test_load_sharded_checkpoint_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_peft_backend
def test_lora(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
......@@ -29,6 +29,7 @@ import PIL.Image
import requests_mock
import safetensors.torch
import torch
import torch.nn as nn
from parameterized import parameterized
from PIL import Image
from requests.exceptions import HTTPError
......@@ -135,6 +136,7 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
class CustomEncoder(ModelMixin, ConfigMixin):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
class CustomPipeline(DiffusionPipeline):
......