Unverified commit 71ad16b4, authored by Aryan, committed by GitHub
Browse files

Add `_no_split_modules` to some models (#10308)



* set supports gradient checkpointing to true where necessary; add missing no split modules

* fix cogvideox tests

* update

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent ee7e141d
......@@ -1214,7 +1214,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Adapted from `transformers` modeling_utils.py
def _get_no_split_modules(self, device_map: str):
"""
Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
get the underlying `_no_split_modules`.
Args:
......
......@@ -210,6 +210,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
@register_to_config
def __init__(
......
......@@ -221,6 +221,8 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
Scaling factor to apply in 3D positional embeddings across time dimension.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
......
......@@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
@register_to_config
def __init__(
......
......@@ -542,6 +542,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"""
_supports_gradient_checkpointing = True
_no_split_modules = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
@register_to_config
def __init__(
......
......@@ -71,7 +71,7 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
"num_layers": 1,
"num_layers": 2,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
......@@ -130,7 +130,7 @@ class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
"num_layers": 1,
"num_layers": 2,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
......
......@@ -71,7 +71,7 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 1,
"num_layers": 2,
"attention_head_dim": 4,
"num_attention_heads": 2,
"out_channels": 4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment