Unverified Commit 39b87b14 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

feat: allow flux transformer to be sharded during inference (#9159)

* feat: support sharding for flux.

* tests
parent 3e460432
......@@ -251,6 +251,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
"""
# Advertises that this model supports gradient checkpointing (consumed by
# ModelMixin's enable_gradient_checkpointing machinery — presumably; the
# enclosing class body is not fully visible here).
_supports_gradient_checkpointing = True
# Module class names that must never be split across devices when the model
# is sharded with a device_map (accelerate's big-model inference keeps each
# listed block whole on a single device). Added by this commit to allow the
# Flux transformer to be sharded during inference.
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
@register_to_config
def __init__(
......
......@@ -29,6 +29,8 @@ enable_full_determinism()
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
# Model class under test for the shared ModelTesterMixin test suite.
model_class = FluxTransformer2DModel
# Name of the primary input tensor expected by the model's forward pass.
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
# NOTE(review): model_split_percents is presumably consumed by the mixin's
# device_map/sharding tests to decide how much of the model goes on each
# device — confirm against ModelTesterMixin, which is not visible here.
model_split_percents = [0.7, 0.6, 0.6]
@property
def dummy_input(self):
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment