Unverified commit f848feba, authored by Sayak Paul and committed by GitHub

feat: allow sharding for auraflow. (#8853)

parent b3825500
@@ -274,6 +274,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
     """

+    _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
     _supports_gradient_checkpointing = True

     @register_to_config
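For context, `_no_split_modules` tells accelerate which blocks must each stay on a single device when the model is dispatched across devices, which is what enables sharded loading with a device map. Below is a minimal sketch of how that could be exercised; it is not part of this commit, and the checkpoint id `fal/AuraFlow` is only illustrative.

```python
# Hedged sketch, not from this commit: load the transformer with
# device_map="auto" so accelerate shards it across available devices,
# keeping every module listed in _no_split_modules on one device.
import torch
from diffusers import AuraFlowTransformer2DModel

transformer = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow",          # illustrative checkpoint id
    subfolder="transformer",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Assumption: the resulting placement is typically exposed as hf_device_map
# when a device map was used.
print(getattr(transformer, "hf_device_map", None))
```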
@@ -29,6 +29,8 @@ enable_full_determinism()
 class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = AuraFlowTransformer2DModel
     main_input_name = "hidden_states"
+    # We override the items here because the transformer under consideration is small.
+    model_split_percents = [0.7, 0.6, 0.6]

     @property
     def dummy_input(self):
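`model_split_percents` feeds the shared `ModelTesterMixin` sharding tests, which cap per-device memory at a fraction of the model's total footprint so that even this small test model is forced to split across devices. The sketch below illustrates that idea only; the toy module and the exact `max_memory` construction are assumptions, not the test suite's actual wiring.

```python
# Hedged sketch of the idea behind model_split_percents: budget device 0 at a
# fraction of the model's size so accelerate must place the rest elsewhere.
# The toy module stands in for the small test model.
import torch
from accelerate.utils import compute_module_sizes

model_split_percents = [0.7, 0.6, 0.6]

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
model_size = compute_module_sizes(model)[""]  # total size in bytes

# Cap the first device under the loosest split percentage; give CPU headroom.
max_memory = {0: int(model_size * model_split_percents[0]), "cpu": model_size * 2}
print(max_memory)
```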