"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8f21a9f0e29cc5e3d139095785d52ff474ebda0e"
Unverified commit b58868e6 authored by Junsong Chen, committed by GitHub

[Sana bug] bug fix for 2K model config (#10340)



* fix the Positional Embedding bug in the 2K model;

* Change the default model to the BF16 one for more stable training and output

* make style

* subtract buffer size

* add compute_module_persistent_sizes

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent da21d590
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
 ```python
 from diffusers import SanaTransformer2DModel

-transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
+transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
 ```

 ## SanaTransformer2DModel
......
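Not part of this commit: a minimal end-to-end sketch of the BF16 loading pattern the updated snippet recommends. It reuses the model ID above and the standard `from_pretrained` component-override kwarg for handing a preloaded transformer to the pipeline.

```python
# Hedged sketch (not in this diff), assuming the BF16 checkpoint layout above:
# load the transformer alone, then pass it into the full pipeline.
import torch
from diffusers import SanaPipeline, SanaTransformer2DModel

transformer = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    transformer=transformer,  # component override, a standard from_pretrained kwarg
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
```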
@@ -32,9 +32,9 @@ Available models:

 | Model | Recommended dtype |
 |:-----:|:-----------------:|
+| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
 | [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
 | [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
-| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
 | [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
 | [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
 | [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
......
@@ -88,13 +88,18 @@ def main(args):
     # y norm
     converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
+    # scheduler
     flow_shift = 3.0

+    # model config
     if args.model_type == "SanaMS_1600M_P1_D20":
         layer_num = 20
     elif args.model_type == "SanaMS_600M_P1_D28":
         layer_num = 28
     else:
         raise ValueError(f"{args.model_type} is not supported.")
+
+    # Positional embedding interpolation scale.
+    interpolation_scale = {512: None, 1024: None, 2048: 1.0}

     for depth in range(layer_num):
         # Transformer blocks.
@@ -176,6 +181,7 @@ def main(args):
         patch_size=1,
         norm_elementwise_affine=False,
         norm_eps=1e-6,
+        interpolation_scale=interpolation_scale[args.image_size],
     )

     if is_accelerate_available():
......
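For orientation, a hedged sketch (not in the script) of how the new mapping resolves: the 512/1024 entries stay `None` so the model's own default applies, and any other `--image_size` value would raise a `KeyError`, mirroring the `ValueError` for unknown model types. The `image_size // 32` latent arithmetic is an assumption based on Sana's 32x-compressing autoencoder.

```python
# Hedged sketch: how the converter's mapping resolves per checkpoint resolution.
# The fallback mirrors the new default in SanaTransformer2DModel.__init__;
# sample_size = image_size // 32 assumes Sana's 32x autoencoder compression.
interpolation_scale = {512: None, 1024: None, 2048: 1.0}

for image_size in (512, 1024, 2048):
    sample_size = image_size // 32
    scale = interpolation_scale[image_size]
    resolved = scale if scale is not None else max(sample_size // 64, 1)
    print(image_size, sample_size, resolved)
# 512 16 1
# 1024 32 1
# 2048 64 1.0
```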
@@ -242,6 +242,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         patch_size: int = 1,
         norm_elementwise_affine: bool = False,
         norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
     ) -> None:
         super().__init__()

@@ -249,14 +250,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         inner_dim = num_attention_heads * attention_head_dim

         # 1. Patch Embedding
+        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
         self.patch_embed = PatchEmbed(
             height=sample_size,
             width=sample_size,
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
-            interpolation_scale=None,
-            pos_embed_type=None,
+            interpolation_scale=interpolation_scale,
         )

         # 2. Additional condition embeddings
......
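Two things change here: `interpolation_scale` becomes a constructor argument with a sample-size-derived default, and the explicit `pos_embed_type=None` is dropped, so `PatchEmbed` falls back to its default sinusoidal (`"sincos"`) positional embedding, presumably the embedding the 2K checkpoint was mis-configured for. A hedged sketch of what the module now builds; `in_channels=32` and `embed_dim=2240` are assumptions approximating the Sana 1600M 2K config, not values from this diff.

```python
# Hedged sketch (not in this diff): with pos_embed_type no longer forced to None,
# PatchEmbed's default ("sincos") applies, sized by sample_size and scaled by
# interpolation_scale. Channel/width values below are illustrative assumptions.
import torch
from diffusers.models.embeddings import PatchEmbed

embed = PatchEmbed(
    height=64, width=64, patch_size=1, in_channels=32, embed_dim=2240,
    interpolation_scale=1.0,  # the explicit value the converter pins for 2K
)
latents = torch.randn(1, 32, 64, 64)
print(embed(latents).shape)  # torch.Size([1, 4096, 2240])
```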
@@ -59,13 +59,13 @@ EXAMPLE_DOC_STRING = """
         >>> from diffusers import SanaPAGPipeline

         >>> pipe = SanaPAGPipeline.from_pretrained(
-        ...     "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
+        ...     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
         ...     pag_applied_layers=["transformer_blocks.8"],
         ...     torch_dtype=torch.float32,
         ... )
         >>> pipe.to("cuda")
         >>> pipe.text_encoder.to(torch.bfloat16)
-        >>> pipe.transformer = pipe.transformer.to(torch.float16)
+        >>> pipe.transformer = pipe.transformer.to(torch.bfloat16)

         >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
         >>> image[0].save("output.png")
......
@@ -62,11 +62,11 @@ EXAMPLE_DOC_STRING = """
         >>> from diffusers import SanaPipeline

         >>> pipe = SanaPipeline.from_pretrained(
-        ...     "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32
+        ...     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
         ... )
         >>> pipe.to("cuda")
         >>> pipe.text_encoder.to(torch.bfloat16)
-        >>> pipe.transformer = pipe.transformer.to(torch.float16)
+        >>> pipe.transformer = pipe.transformer.to(torch.bfloat16)

         >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
         >>> image[0].save("output.png")
......
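In both pipeline docstrings the pattern is the same: the pipeline is loaded in `float32` and only the text encoder and transformer are cast down afterward, so the remaining components (presumably the autoencoder) stay in full precision. The cast target changes from `float16` to `bfloat16` to match the BF16 checkpoint that is now the recommended default.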
@@ -22,12 +22,14 @@ import traceback
 import unittest
 import unittest.mock as mock
 import uuid
-from typing import Dict, List, Tuple
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import requests_mock
 import torch
-from accelerate.utils import compute_module_sizes
+import torch.nn as nn
+from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
 from huggingface_hub import ModelCard, delete_repo, snapshot_download
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
@@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     out_queue.join()


+def named_persistent_module_tensors(
+    module: nn.Module,
+    recurse: bool = False,
+):
+    """
+    A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module we want the tensors on.
+        recurse (`bool`, *optional*, defaults to `False`):
+            Whether or not to go look in every submodule or just return the direct parameters and buffers.
+    """
+    yield from module.named_parameters(recurse=recurse)
+
+    for named_buffer in module.named_buffers(recurse=recurse):
+        name, _ = named_buffer
+        # Get the parent module by splitting on dots and traversing the model.
+        parent = module
+        if "." in name:
+            parent_name = name.rsplit(".", 1)[0]
+            for part in parent_name.split("."):
+                parent = getattr(parent, part)
+            name = name.split(".")[-1]
+        # Only yield buffers the parent does not mark as non-persistent.
+        if name not in parent._non_persistent_buffers_set:
+            yield named_buffer
+
+
+def compute_module_persistent_sizes(
+    model: nn.Module,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
+):
+    """
+    Compute the size of each submodule of a given model (parameters + persistent buffers).
+    """
+    if dtype is not None:
+        dtype = _get_proper_dtype(dtype)
+        dtype_size = dtype_byte_size(dtype)
+    if special_dtypes is not None:
+        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
+        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
+    module_sizes = defaultdict(int)
+    module_list = named_persistent_module_tensors(model, recurse=True)
+
+    for name, tensor in module_list:
+        if special_dtypes is not None and name in special_dtypes:
+            size = tensor.numel() * special_dtypes_size[name]
+        elif dtype is None:
+            size = tensor.numel() * dtype_byte_size(tensor.dtype)
+        elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+            # According to the code in set_module_tensor_to_device, these types won't be converted,
+            # so use their original size here.
+            size = tensor.numel() * dtype_byte_size(tensor.dtype)
+        else:
+            size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
+        # Accumulate the size into the tensor's module and every ancestor module.
+        name_parts = name.split(".")
+        for idx in range(len(name_parts) + 1):
+            module_sizes[".".join(name_parts[:idx])] += size
+
+    return module_sizes
+
+
 class ModelUtilsTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
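A hedged usage sketch of the new helper (not part of the test file, and it assumes `compute_module_persistent_sizes` above is in scope); `Toy` is a hypothetical module used only to show that non-persistent buffers are excluded, which is the behavioral difference from accelerate's `compute_module_sizes`.

```python
# Hedged usage sketch: `Toy` is a hypothetical module, not from the test suite.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)  # 16 weight + 4 bias parameters
        self.register_buffer("stats", torch.zeros(8))  # persistent: counted
        self.register_buffer("cache", torch.zeros(16), persistent=False)  # skipped

sizes = compute_module_persistent_sizes(Toy())
# (16 + 4 + 8) float32 elements * 4 bytes = 112; the 16-element non-persistent
# buffer is excluded, unlike with accelerate's compute_module_sizes.
assert sizes[""] == 112
assert sizes["linear"] == 80  # parameters only: (16 + 4) * 4 bytes
```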
@@ -1012,7 +1080,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1042,7 +1110,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]

         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
@@ -1076,7 +1144,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]

         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)
@@ -1104,7 +1172,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1132,7 +1200,7 @@ class ModelTesterMixin:
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1164,7 +1232,7 @@ class ModelTesterMixin:
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         variant = "fp16"
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1204,7 +1272,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1233,7 +1301,7 @@ class ModelTesterMixin:
         config, _ = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()

-        model_size = compute_module_sizes(model)[""]
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         variant = "fp16"
         with tempfile.TemporaryDirectory() as tmp_dir:
......
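The switch matters because `save_pretrained` serializes the module's `state_dict`, which omits non-persistent buffers, so basing size thresholds on accelerate's `compute_module_sizes` (which counts all buffers) could overestimate the checkpoint size; that is the "subtract buffer size" bullet in the commit message. A hedged check of the shard-cap arithmetic these tests use, with a hypothetical model size:

```python
# Hedged arithmetic check with a hypothetical 4 MB model (not from the tests):
model_size = 4_000_000                               # bytes, persistent tensors only
max_shard_size = int((model_size * 0.75) / (2**10))  # in KB, as computed in the tests
print(max_shard_size)                                # 2929 -> passed as "2929KB"
# 4_000_000 bytes ~= 3906 KB > 2929 KB, so save_pretrained must emit >= 2 shards.
```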