Unverified Commit cdadb023 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Make Video Tests faster (#5787)

* update test

* update
parent 51fd3dd2
...@@ -985,7 +985,7 @@ class TemporalConvLayer(nn.Module): ...@@ -985,7 +985,7 @@ class TemporalConvLayer(nn.Module):
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
""" """
def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0): def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, norm_num_groups: int = 32):
super().__init__() super().__init__()
out_dim = out_dim or in_dim out_dim = out_dim or in_dim
self.in_dim = in_dim self.in_dim = in_dim
...@@ -993,22 +993,22 @@ class TemporalConvLayer(nn.Module): ...@@ -993,22 +993,22 @@ class TemporalConvLayer(nn.Module):
# conv layers # conv layers
self.conv1 = nn.Sequential( self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)) nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
) )
self.conv2 = nn.Sequential( self.conv2 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.GroupNorm(norm_num_groups, out_dim),
nn.SiLU(), nn.SiLU(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
) )
self.conv3 = nn.Sequential( self.conv3 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.GroupNorm(norm_num_groups, out_dim),
nn.SiLU(), nn.SiLU(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
) )
self.conv4 = nn.Sequential( self.conv4 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.GroupNorm(norm_num_groups, out_dim),
nn.SiLU(), nn.SiLU(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
......
...@@ -269,6 +269,7 @@ class UNetMidBlock3DCrossAttn(nn.Module): ...@@ -269,6 +269,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
in_channels, in_channels,
in_channels, in_channels,
dropout=0.1, dropout=0.1,
norm_num_groups=resnet_groups,
) )
] ]
attentions = [] attentions = []
...@@ -316,6 +317,7 @@ class UNetMidBlock3DCrossAttn(nn.Module): ...@@ -316,6 +317,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
in_channels, in_channels,
in_channels, in_channels,
dropout=0.1, dropout=0.1,
norm_num_groups=resnet_groups,
) )
) )
...@@ -406,6 +408,7 @@ class CrossAttnDownBlock3D(nn.Module): ...@@ -406,6 +408,7 @@ class CrossAttnDownBlock3D(nn.Module):
out_channels, out_channels,
out_channels, out_channels,
dropout=0.1, dropout=0.1,
norm_num_groups=resnet_groups,
) )
) )
attentions.append( attentions.append(
...@@ -529,6 +532,7 @@ class DownBlock3D(nn.Module): ...@@ -529,6 +532,7 @@ class DownBlock3D(nn.Module):
out_channels, out_channels,
out_channels, out_channels,
dropout=0.1, dropout=0.1,
norm_num_groups=resnet_groups,
) )
) )
...@@ -622,6 +626,7 @@ class CrossAttnUpBlock3D(nn.Module): ...@@ -622,6 +626,7 @@ class CrossAttnUpBlock3D(nn.Module):
out_channels, out_channels,
out_channels, out_channels,
dropout=0.1, dropout=0.1,
norm_num_groups=resnet_groups,
) )
) )
attentions.append( attentions.append(
...@@ -764,6 +769,7 @@ class UpBlock3D(nn.Module): ...@@ -764,6 +769,7 @@ class UpBlock3D(nn.Module):
out_channels, out_channels,
out_channels, out_channels,
dropout=0.1, dropout=0.1,
norm_num_groups=resnet_groups,
) )
) )
......
...@@ -173,6 +173,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -173,6 +173,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
attention_head_dim=attention_head_dim, attention_head_dim=attention_head_dim,
in_channels=block_out_channels[0], in_channels=block_out_channels[0],
num_layers=1, num_layers=1,
norm_num_groups=norm_num_groups,
) )
# class embedding # class embedding
......
...@@ -62,8 +62,8 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -62,8 +62,8 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet3DConditionModel( unet = UNet3DConditionModel(
block_out_channels=(32, 32), block_out_channels=(4, 8),
layers_per_block=2, layers_per_block=1,
sample_size=32, sample_size=32,
in_channels=4, in_channels=4,
out_channels=4, out_channels=4,
...@@ -71,6 +71,7 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -71,6 +71,7 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"), up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
cross_attention_dim=4, cross_attention_dim=4,
attention_head_dim=4, attention_head_dim=4,
norm_num_groups=2,
) )
scheduler = DDIMScheduler( scheduler = DDIMScheduler(
beta_start=0.00085, beta_start=0.00085,
...@@ -81,13 +82,14 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -81,13 +82,14 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
) )
torch.manual_seed(0) torch.manual_seed(0)
vae = AutoencoderKL( vae = AutoencoderKL(
block_out_channels=(32,), block_out_channels=(8,),
in_channels=3, in_channels=3,
out_channels=3, out_channels=3,
down_block_types=["DownEncoderBlock2D"], down_block_types=["DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D"], up_block_types=["UpDecoderBlock2D"],
latent_channels=4, latent_channels=4,
sample_size=32, sample_size=32,
norm_num_groups=2,
) )
torch.manual_seed(0) torch.manual_seed(0)
text_encoder_config = CLIPTextConfig( text_encoder_config = CLIPTextConfig(
...@@ -142,10 +144,11 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -142,10 +144,11 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_slice = frames[0][-3:, -3:, -1] image_slice = frames[0][-3:, -3:, -1]
assert frames[0].shape == (32, 32, 3) assert frames[0].shape == (32, 32, 3)
expected_slice = np.array([91.0, 152.0, 66.0, 192.0, 94.0, 126.0, 101.0, 123.0, 152.0]) expected_slice = np.array([192.0, 44.0, 157.0, 140.0, 108.0, 104.0, 123.0, 144.0, 129.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@unittest.skipIf(torch_device != "cuda", reason="Feature isn't heavily used. Test in CUDA environment only.")
def test_attention_slicing_forward_pass(self): def test_attention_slicing_forward_pass(self):
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False, expected_max_diff=3e-3) self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False, expected_max_diff=3e-3)
......
...@@ -70,15 +70,16 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -70,15 +70,16 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet3DConditionModel( unet = UNet3DConditionModel(
block_out_channels=(32, 64, 64, 64), block_out_channels=(4, 8),
layers_per_block=2, layers_per_block=1,
sample_size=32, sample_size=32,
in_channels=4, in_channels=4,
out_channels=4, out_channels=4,
down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"), down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
cross_attention_dim=32, cross_attention_dim=32,
attention_head_dim=4, attention_head_dim=4,
norm_num_groups=2,
) )
scheduler = DDIMScheduler( scheduler = DDIMScheduler(
beta_start=0.00085, beta_start=0.00085,
...@@ -89,13 +90,18 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -89,13 +90,18 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
) )
torch.manual_seed(0) torch.manual_seed(0)
vae = AutoencoderKL( vae = AutoencoderKL(
block_out_channels=[32, 64], block_out_channels=[
8,
],
in_channels=3, in_channels=3,
out_channels=3, out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], down_block_types=[
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], "DownEncoderBlock2D",
],
up_block_types=["UpDecoderBlock2D"],
latent_channels=4, latent_channels=4,
sample_size=128, sample_size=32,
norm_num_groups=2,
) )
torch.manual_seed(0) torch.manual_seed(0)
text_encoder_config = CLIPTextConfig( text_encoder_config = CLIPTextConfig(
...@@ -154,7 +160,7 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -154,7 +160,7 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_slice = frames[0][-3:, -3:, -1] image_slice = frames[0][-3:, -3:, -1]
assert frames[0].shape == (32, 32, 3) assert frames[0].shape == (32, 32, 3)
expected_slice = np.array([106, 117, 113, 174, 137, 112, 148, 151, 131]) expected_slice = np.array([162.0, 136.0, 132.0, 140.0, 139.0, 137.0, 169.0, 134.0, 132.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment