Unverified commit a1f9a712 authored by YiYi Xu, committed by GitHub

fix offload gpu tests etc (#10366)

* add

* style
parent ec37e209
@@ -82,6 +82,20 @@ class GLUMBConv(nn.Module):
         return hidden_states
 
 
+class SanaModulatedNorm(nn.Module):
+    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(
+        self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.norm(hidden_states)
+        shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
+        hidden_states = hidden_states * (1 + scale) + shift
+        return hidden_states
+
+
 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
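The new SanaModulatedNorm folds the final LayerNorm and the AdaLN-style shift/scale modulation into a single module, so the modulation logic travels together with norm_out when the model is sharded or offloaded. A minimal sketch of the numerics in isolation, with assumed shapes (batch=2, seq_len=16, dim=8) that are illustrative rather than taken from the diff:

import torch
import torch.nn as nn

dim = 8
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
scale_shift_table = torch.randn(2, dim) / dim**0.5  # initialized as on the parent model
hidden_states = torch.randn(2, 16, dim)             # (batch, seq_len, dim)
temb = torch.randn(2, dim)                          # embedded timestep, (batch, dim)

# (1, 2, dim) + (batch, 1, dim) broadcasts to (batch, 2, dim), then splits
# into a per-sample shift and scale along dim=1.
shift, scale = (scale_shift_table[None] + temb[:, None]).chunk(2, dim=1)
out = norm(hidden_states) * (1 + scale) + shift     # same modulation as forward() above
print(out.shape)                                    # torch.Size([2, 16, 8])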
@@ -221,7 +235,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
 
     @register_to_config
     def __init__(
@@ -288,8 +302,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         # 4. Output blocks
         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-
-        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
 
         self.gradient_checkpointing = False
@@ -462,13 +475,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         )
 
         # 3. Normalization
-        shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-        ).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
-
-        # 4. Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
+
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
@@ -29,7 +29,7 @@ import numpy as np
 import requests_mock
 import torch
 import torch.nn as nn
-from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
+from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
 from huggingface_hub import ModelCard, delete_repo, snapshot_download
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
@@ -1080,7 +1080,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
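compute_module_sizes comes from accelerate and returns a mapping of submodule names to their sizes in bytes; the entry for the empty string "" is the root module, i.e. the whole model. A quick sketch of how the tests derive per-GPU budgets from it (the toy model and percentages are made up):

import torch.nn as nn
from accelerate.utils.modeling import compute_module_sizes

toy = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
model_size = compute_module_sizes(toy)[""]   # total bytes: 2 * (64*64 + 64) * 4 for fp32
model_split_percents = [0.5, 0.7, 0.9]       # same shape as the test-class attribute
max_gpu_sizes = [int(p * model_size) for p in model_split_percents[1:]]
print(model_size, max_gpu_sizes)             # 33280 [23296, 29952]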
@@ -1110,7 +1110,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
@@ -1144,7 +1144,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)
@@ -1172,7 +1172,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1183,6 +1183,7 @@ class ModelTesterMixin:
 
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
            self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+            print(f" new_model.hf_device_map:{new_model.hf_device_map}")
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
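The assertion above verifies that dispatch really used both GPUs before checking that every parameter sits where hf_device_map says it should. In user code the same machinery is reached roughly like this (the checkpoint path and byte budgets are placeholders, not from the diff):

import torch
from diffusers import SanaTransformer2DModel

model = SanaTransformer2DModel.from_pretrained(
    "path/to/checkpoint",                  # hypothetical local checkpoint
    device_map="auto",
    max_memory={0: "2GiB", 1: "2GiB"},     # cap each GPU so layers spread across both
    torch_dtype=torch.float16,
)
print(model.hf_device_map)                 # module name -> device chosen by accelerate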
@@ -14,7 +14,6 @@
 import unittest
 
-import pytest
 import torch
 
 from diffusers import SanaTransformer2DModel
 
@@ -33,6 +32,7 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = SanaTransformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
+    model_split_percents = [0.7, 0.7, 0.9]
 
     @property
     def dummy_input(self):
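The override raises the split points for this model; as the hunks in the common tests show, every entry after the first becomes a per-GPU size cap. A worked toy computation with an invented model size:

model_size = 1_000_000                     # pretend compute_module_sizes(model)[""] returned ~1 MB
model_split_percents = [0.7, 0.7, 0.9]     # the class attribute added above
max_gpu_sizes = [int(p * model_size) for p in model_split_percents[1:]]
print(max_gpu_sizes)                       # [700000, 900000]: roomier caps so the
                                           # no-split SanaModulatedNorm still fits on one device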
@@ -81,27 +81,3 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"SanaTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_cpu_offload(self):
-        return super().test_cpu_offload()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_with_safetensors(self):
-        return super().test_disk_offload_with_safetensors()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_without_safetensors(self):
-        return super().test_disk_offload_without_safetensors()
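Deleting these markers is required once the underlying offload bug is fixed, not just tidy-up: with strict=True an xfail that unexpectedly passes (XPASS) is itself reported as a failure. A tiny illustration of that pytest behavior:

import pytest

@pytest.mark.xfail(reason="Test currently fails.", strict=True)
def test_now_fixed():
    assert True  # passes, so under strict=True pytest reports FAILED (unexpected pass)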