Unverified Commit 14f6464b authored by M Saqlain's avatar M Saqlain Committed by GitHub
Browse files

[Tests] Reduce the model size in the lumina test (#8985)

* Reduced model size for lumina-tests

* Handled failing tests
parent ba5af5ae
...@@ -34,19 +34,19 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM ...@@ -34,19 +34,19 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
transformer = LuminaNextDiT2DModel( transformer = LuminaNextDiT2DModel(
sample_size=16, sample_size=4,
patch_size=2, patch_size=2,
in_channels=4, in_channels=4,
hidden_size=24, hidden_size=4,
num_layers=2, num_layers=2,
num_attention_heads=3, num_attention_heads=1,
num_kv_heads=1, num_kv_heads=1,
multiple_of=16, multiple_of=16,
ffn_dim_multiplier=None, ffn_dim_multiplier=None,
norm_eps=1e-5, norm_eps=1e-5,
learn_sigma=True, learn_sigma=True,
qk_norm=True, qk_norm=True,
cross_attention_dim=32, cross_attention_dim=8,
scaling_factor=1.0, scaling_factor=1.0,
) )
torch.manual_seed(0) torch.manual_seed(0)
...@@ -57,8 +57,8 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM ...@@ -57,8 +57,8 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM
torch.manual_seed(0) torch.manual_seed(0)
config = GemmaConfig( config = GemmaConfig(
head_dim=4, head_dim=2,
hidden_size=32, hidden_size=8,
intermediate_size=37, intermediate_size=37,
num_attention_heads=4, num_attention_heads=4,
num_hidden_layers=2, num_hidden_layers=2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment