Unverified Commit 7139f0e8 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

fix: norm group test for UNet3D. (#2959)

parent 8c530fc2
......@@ -119,12 +119,11 @@ class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
== "XFormersAttnProcessor"
), "xformers is not enabled"
# Overriding because `block_out_channels` needs to be different for this model.
# Overriding to set `norm_num_groups` needs to be different for this model.
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 32
init_dict["block_out_channels"] = (32, 64, 64, 64)
model = self.model_class(**init_dict)
model.to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment