Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
cdadb023
Unverified
Commit
cdadb023
authored
Nov 15, 2023
by
Dhruv Nair
Committed by
GitHub
Nov 15, 2023
Browse files
Make Video Tests faster (#5787)
* update test * update
parent
51fd3dd2
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
34 additions
and
18 deletions
+34
-18
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+5
-5
src/diffusers/models/unet_3d_blocks.py
src/diffusers/models/unet_3d_blocks.py
+6
-0
src/diffusers/models/unet_3d_condition.py
src/diffusers/models/unet_3d_condition.py
+1
-0
tests/pipelines/text_to_video_synthesis/test_text_to_video.py
...s/pipelines/text_to_video_synthesis/test_text_to_video.py
+7
-4
tests/pipelines/text_to_video_synthesis/test_video_to_video.py
.../pipelines/text_to_video_synthesis/test_video_to_video.py
+15
-9
No files found.
src/diffusers/models/resnet.py
View file @
cdadb023
...
@@ -985,7 +985,7 @@ class TemporalConvLayer(nn.Module):
...
@@ -985,7 +985,7 @@ class TemporalConvLayer(nn.Module):
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
"""
"""
def
__init__
(
self
,
in_dim
:
int
,
out_dim
:
Optional
[
int
]
=
None
,
dropout
:
float
=
0.0
):
def
__init__
(
self
,
in_dim
:
int
,
out_dim
:
Optional
[
int
]
=
None
,
dropout
:
float
=
0.0
,
norm_num_groups
:
int
=
32
):
super
().
__init__
()
super
().
__init__
()
out_dim
=
out_dim
or
in_dim
out_dim
=
out_dim
or
in_dim
self
.
in_dim
=
in_dim
self
.
in_dim
=
in_dim
...
@@ -993,22 +993,22 @@ class TemporalConvLayer(nn.Module):
...
@@ -993,22 +993,22 @@ class TemporalConvLayer(nn.Module):
# conv layers
# conv layers
self
.
conv1
=
nn
.
Sequential
(
self
.
conv1
=
nn
.
Sequential
(
nn
.
GroupNorm
(
32
,
in_dim
),
nn
.
SiLU
(),
nn
.
Conv3d
(
in_dim
,
out_dim
,
(
3
,
1
,
1
),
padding
=
(
1
,
0
,
0
))
nn
.
GroupNorm
(
norm_num_groups
,
in_dim
),
nn
.
SiLU
(),
nn
.
Conv3d
(
in_dim
,
out_dim
,
(
3
,
1
,
1
),
padding
=
(
1
,
0
,
0
))
)
)
self
.
conv2
=
nn
.
Sequential
(
self
.
conv2
=
nn
.
Sequential
(
nn
.
GroupNorm
(
32
,
out_dim
),
nn
.
GroupNorm
(
norm_num_groups
,
out_dim
),
nn
.
SiLU
(),
nn
.
SiLU
(),
nn
.
Dropout
(
dropout
),
nn
.
Dropout
(
dropout
),
nn
.
Conv3d
(
out_dim
,
in_dim
,
(
3
,
1
,
1
),
padding
=
(
1
,
0
,
0
)),
nn
.
Conv3d
(
out_dim
,
in_dim
,
(
3
,
1
,
1
),
padding
=
(
1
,
0
,
0
)),
)
)
self
.
conv3
=
nn
.
Sequential
(
self
.
conv3
=
nn
.
Sequential
(
nn
.
GroupNorm
(
32
,
out_dim
),
nn
.
GroupNorm
(
norm_num_groups
,
out_dim
),
nn
.
SiLU
(),
nn
.
SiLU
(),
nn
.
Dropout
(
dropout
),
nn
.
Dropout
(
dropout
),
nn
.
Conv3d
(
out_dim
,
in_dim
,
(
3
,
1
,
1
),
padding
=
(
1
,
0
,
0
)),
nn
.
Conv3d
(
out_dim
,
in_dim
,
(
3
,
1
,
1
),
padding
=
(
1
,
0
,
0
)),
)
)
self
.
conv4
=
nn
.
Sequential
(
self
.
conv4
=
nn
.
Sequential
(
nn
.
GroupNorm
(
32
,
out_dim
),
nn
.
GroupNorm
(
norm_num_groups
,
out_dim
),
nn
.
SiLU
(),
nn
.
SiLU
(),
nn
.
Dropout
(
dropout
),
nn
.
Dropout
(
dropout
),
nn
.
Conv3d
(
out_dim
,
in_dim
,
(
3
,
1
,
1
),
padding
=
(
1
,
0
,
0
)),
nn
.
Conv3d
(
out_dim
,
in_dim
,
(
3
,
1
,
1
),
padding
=
(
1
,
0
,
0
)),
...
...
src/diffusers/models/unet_3d_blocks.py
View file @
cdadb023
...
@@ -269,6 +269,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
...
@@ -269,6 +269,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
in_channels
,
in_channels
,
in_channels
,
in_channels
,
dropout
=
0.1
,
dropout
=
0.1
,
norm_num_groups
=
resnet_groups
,
)
)
]
]
attentions
=
[]
attentions
=
[]
...
@@ -316,6 +317,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
...
@@ -316,6 +317,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
in_channels
,
in_channels
,
in_channels
,
in_channels
,
dropout
=
0.1
,
dropout
=
0.1
,
norm_num_groups
=
resnet_groups
,
)
)
)
)
...
@@ -406,6 +408,7 @@ class CrossAttnDownBlock3D(nn.Module):
...
@@ -406,6 +408,7 @@ class CrossAttnDownBlock3D(nn.Module):
out_channels
,
out_channels
,
out_channels
,
out_channels
,
dropout
=
0.1
,
dropout
=
0.1
,
norm_num_groups
=
resnet_groups
,
)
)
)
)
attentions
.
append
(
attentions
.
append
(
...
@@ -529,6 +532,7 @@ class DownBlock3D(nn.Module):
...
@@ -529,6 +532,7 @@ class DownBlock3D(nn.Module):
out_channels
,
out_channels
,
out_channels
,
out_channels
,
dropout
=
0.1
,
dropout
=
0.1
,
norm_num_groups
=
resnet_groups
,
)
)
)
)
...
@@ -622,6 +626,7 @@ class CrossAttnUpBlock3D(nn.Module):
...
@@ -622,6 +626,7 @@ class CrossAttnUpBlock3D(nn.Module):
out_channels
,
out_channels
,
out_channels
,
out_channels
,
dropout
=
0.1
,
dropout
=
0.1
,
norm_num_groups
=
resnet_groups
,
)
)
)
)
attentions
.
append
(
attentions
.
append
(
...
@@ -764,6 +769,7 @@ class UpBlock3D(nn.Module):
...
@@ -764,6 +769,7 @@ class UpBlock3D(nn.Module):
out_channels
,
out_channels
,
out_channels
,
out_channels
,
dropout
=
0.1
,
dropout
=
0.1
,
norm_num_groups
=
resnet_groups
,
)
)
)
)
...
...
src/diffusers/models/unet_3d_condition.py
View file @
cdadb023
...
@@ -173,6 +173,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
...
@@ -173,6 +173,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
attention_head_dim
=
attention_head_dim
,
attention_head_dim
=
attention_head_dim
,
in_channels
=
block_out_channels
[
0
],
in_channels
=
block_out_channels
[
0
],
num_layers
=
1
,
num_layers
=
1
,
norm_num_groups
=
norm_num_groups
,
)
)
# class embedding
# class embedding
...
...
tests/pipelines/text_to_video_synthesis/test_text_to_video.py
View file @
cdadb023
...
@@ -62,8 +62,8 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -62,8 +62,8 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def
get_dummy_components
(
self
):
def
get_dummy_components
(
self
):
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
unet
=
UNet3DConditionModel
(
unet
=
UNet3DConditionModel
(
block_out_channels
=
(
32
,
32
),
block_out_channels
=
(
4
,
8
),
layers_per_block
=
2
,
layers_per_block
=
1
,
sample_size
=
32
,
sample_size
=
32
,
in_channels
=
4
,
in_channels
=
4
,
out_channels
=
4
,
out_channels
=
4
,
...
@@ -71,6 +71,7 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -71,6 +71,7 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
up_block_types
=
(
"UpBlock3D"
,
"CrossAttnUpBlock3D"
),
up_block_types
=
(
"UpBlock3D"
,
"CrossAttnUpBlock3D"
),
cross_attention_dim
=
4
,
cross_attention_dim
=
4
,
attention_head_dim
=
4
,
attention_head_dim
=
4
,
norm_num_groups
=
2
,
)
)
scheduler
=
DDIMScheduler
(
scheduler
=
DDIMScheduler
(
beta_start
=
0.00085
,
beta_start
=
0.00085
,
...
@@ -81,13 +82,14 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -81,13 +82,14 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
vae
=
AutoencoderKL
(
vae
=
AutoencoderKL
(
block_out_channels
=
(
32
,),
block_out_channels
=
(
8
,),
in_channels
=
3
,
in_channels
=
3
,
out_channels
=
3
,
out_channels
=
3
,
down_block_types
=
[
"DownEncoderBlock2D"
],
down_block_types
=
[
"DownEncoderBlock2D"
],
up_block_types
=
[
"UpDecoderBlock2D"
],
up_block_types
=
[
"UpDecoderBlock2D"
],
latent_channels
=
4
,
latent_channels
=
4
,
sample_size
=
32
,
sample_size
=
32
,
norm_num_groups
=
2
,
)
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
text_encoder_config
=
CLIPTextConfig
(
text_encoder_config
=
CLIPTextConfig
(
...
@@ -142,10 +144,11 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -142,10 +144,11 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_slice
=
frames
[
0
][
-
3
:,
-
3
:,
-
1
]
image_slice
=
frames
[
0
][
-
3
:,
-
3
:,
-
1
]
assert
frames
[
0
].
shape
==
(
32
,
32
,
3
)
assert
frames
[
0
].
shape
==
(
32
,
32
,
3
)
expected_slice
=
np
.
array
([
9
1.0
,
15
2.0
,
66
.0
,
1
92
.0
,
94
.0
,
1
26
.0
,
10
1
.0
,
123.0
,
1
52
.0
])
expected_slice
=
np
.
array
([
1
92.0
,
44
.0
,
1
57
.0
,
140
.0
,
1
08
.0
,
10
4
.0
,
123.0
,
1
44.0
,
129
.0
])
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
@
unittest
.
skipIf
(
torch_device
!=
"cuda"
,
reason
=
"Feature isn't heavily used. Test in CUDA environment only."
)
def
test_attention_slicing_forward_pass
(
self
):
def
test_attention_slicing_forward_pass
(
self
):
self
.
_test_attention_slicing_forward_pass
(
test_mean_pixel_difference
=
False
,
expected_max_diff
=
3e-3
)
self
.
_test_attention_slicing_forward_pass
(
test_mean_pixel_difference
=
False
,
expected_max_diff
=
3e-3
)
...
...
tests/pipelines/text_to_video_synthesis/test_video_to_video.py
View file @
cdadb023
...
@@ -70,15 +70,16 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -70,15 +70,16 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def
get_dummy_components
(
self
):
def
get_dummy_components
(
self
):
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
unet
=
UNet3DConditionModel
(
unet
=
UNet3DConditionModel
(
block_out_channels
=
(
32
,
64
,
64
,
64
),
block_out_channels
=
(
4
,
8
),
layers_per_block
=
2
,
layers_per_block
=
1
,
sample_size
=
32
,
sample_size
=
32
,
in_channels
=
4
,
in_channels
=
4
,
out_channels
=
4
,
out_channels
=
4
,
down_block_types
=
(
"CrossAttnDownBlock3D"
,
"CrossAttnDownBlock3D"
,
"CrossAttnDownBlock3D"
,
"DownBlock3D"
),
down_block_types
=
(
"CrossAttnDownBlock3D"
,
"DownBlock3D"
),
up_block_types
=
(
"UpBlock3D"
,
"CrossAttnUpBlock3D"
,
"CrossAttnUpBlock3D"
,
"CrossAttnUpBlock3D"
),
up_block_types
=
(
"UpBlock3D"
,
"CrossAttnUpBlock3D"
),
cross_attention_dim
=
32
,
cross_attention_dim
=
32
,
attention_head_dim
=
4
,
attention_head_dim
=
4
,
norm_num_groups
=
2
,
)
)
scheduler
=
DDIMScheduler
(
scheduler
=
DDIMScheduler
(
beta_start
=
0.00085
,
beta_start
=
0.00085
,
...
@@ -89,13 +90,18 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -89,13 +90,18 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
vae
=
AutoencoderKL
(
vae
=
AutoencoderKL
(
block_out_channels
=
[
32
,
64
],
block_out_channels
=
[
8
,
],
in_channels
=
3
,
in_channels
=
3
,
out_channels
=
3
,
out_channels
=
3
,
down_block_types
=
[
"DownEncoderBlock2D"
,
"DownEncoderBlock2D"
],
down_block_types
=
[
up_block_types
=
[
"UpDecoderBlock2D"
,
"UpDecoderBlock2D"
],
"DownEncoderBlock2D"
,
],
up_block_types
=
[
"UpDecoderBlock2D"
],
latent_channels
=
4
,
latent_channels
=
4
,
sample_size
=
128
,
sample_size
=
32
,
norm_num_groups
=
2
,
)
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
text_encoder_config
=
CLIPTextConfig
(
text_encoder_config
=
CLIPTextConfig
(
...
@@ -154,7 +160,7 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -154,7 +160,7 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_slice
=
frames
[
0
][
-
3
:,
-
3
:,
-
1
]
image_slice
=
frames
[
0
][
-
3
:,
-
3
:,
-
1
]
assert
frames
[
0
].
shape
==
(
32
,
32
,
3
)
assert
frames
[
0
].
shape
==
(
32
,
32
,
3
)
expected_slice
=
np
.
array
([
1
06
,
117
,
1
13
,
174
,
137
,
112
,
148
,
151
,
13
1
])
expected_slice
=
np
.
array
([
1
62.0
,
136.0
,
13
2.0
,
140.0
,
139.0
,
137.0
,
169.0
,
134.0
,
13
2.0
])
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment