renzhc / diffusers_dcu / Commits / a2090375

Unverified commit a2090375, authored Aug 06, 2022 by Suraj Patil, committed by GitHub on Aug 06, 2022

[VAE] fix the downsample block in Encoder. (#156)

* pass downsample_padding in encoder
* update tests
parent c4a3b09a

Showing 2 changed files with 9 additions and 19 deletions (+9 -19):

  src/diffusers/models/vae.py      +1 -0
  tests/test_modeling_utils.py     +8 -19
src/diffusers/models/vae.py

@@ -40,6 +40,7 @@ class Encoder(nn.Module):
                 out_channels=output_channel,
                 add_downsample=not is_final_block,
                 resnet_eps=1e-6,
+                downsample_padding=0,
                 resnet_act_fn=act_fn,
                 attn_num_head_channels=None,
                 temb_channels=None,
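For context on why this one keyword matters: the padding of a stride-2 convolution decides both the spatial size of the downsampled feature map and which input pixels each output position sees, so it changes the numerics of everything downstream (which is also why the expected test values are updated below). A minimal sketch in plain torch, not the repository's downsample code; the 3x3 kernel and the asymmetric-padding variant are assumptions about how a padding-0 downsampler is commonly implemented:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(1, 64, 32, 32)                  # dummy 32x32 feature map

    # symmetric padding=1: clean halving, zeros added on all four sides
    conv_pad1 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
    print(conv_pad1(x).shape)                       # torch.Size([1, 64, 16, 16])

    # padding=0 on the conv alone: floor((32 - 3) / 2) + 1 = 15
    conv_pad0 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0)
    print(conv_pad0(x).shape)                       # torch.Size([1, 64, 15, 15])

    # padding=0 plus an explicit right/bottom pad (a style reference VAE
    # encoders often use): back to 16x16, but with different values than
    # the symmetric case
    print(conv_pad0(F.pad(x, (0, 1, 0, 1))).shape)  # torch.Size([1, 64, 16, 16])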
tests/test_modeling_utils.py

@@ -555,11 +555,11 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": [64],
+            "block_out_channels": [32, 64],
             "in_channels": 3,
             "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D"],
+            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
             "latent_channels": 3,
         }
         inputs_dict = self.dummy_input
@@ -595,7 +595,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
-        expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
+        expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
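The VQModel test now builds a two-level encoder/decoder (two down blocks, two up blocks, 32 then 64 channels), so the deeper stack and the new downsample_padding path are both exercised. A rough sketch of what the updated config constructs, outside the ModelTesterMixin harness the real test goes through; the direct forward call and the dummy input shape are assumptions, not the test's actual code:

    import torch
    from diffusers import VQModel   # assumed import path at this commit

    init_dict = {
        "block_out_channels": [32, 64],
        "in_channels": 3,
        "out_channels": 3,
        "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
        "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
        "latent_channels": 3,
    }

    model = VQModel(**init_dict)
    model.eval()

    with torch.no_grad():
        sample = torch.randn(4, 3, 32, 32)   # stand-in for self.dummy_input
        output = model(sample)               # hypothetical call; the mixin drives this differently

    # same corner slice the test pins against its expected values
    output_slice = output[0, -1, -3:, -3:].flatten()
    print(output_slice.shape)                # torch.Size([9])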
@@ -623,22 +623,11 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "ch": 64,
-            "ch_mult": (1,),
-            "embed_dim": 4,
-            "in_channels": 3,
-            "attn_resolutions": [],
-            "num_res_blocks": 1,
-            "out_ch": 3,
-            "resolution": 32,
-            "z_channels": 4,
-        }
-        init_dict = {
-            "block_out_channels": [64],
+            "block_out_channels": [32, 64],
             "in_channels": 3,
             "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D"],
+            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
             "latent_channels": 4,
         }
         inputs_dict = self.dummy_input
@@ -674,7 +663,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
-        expected_output_slice = torch.tensor([-0.3900, -0.2800, 0.1281, -0.4449, -0.4890, -0.0207, 0.0784, -0.1258, -0.0409])
+        expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
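Both test classes pin their new reference values the same way: a 3x3 corner of the last output channel is flattened and compared with a 1% relative tolerance. Stripped of the model forward pass, the check reduces to the snippet below (the expected values are copied from the updated AutoencoderKL expectation above; the output_slice is stubbed rather than produced by a model):

    import torch

    expected_output_slice = torch.tensor(
        [-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01,
         1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03]
    )

    # stand-in for output[0, -1, -3:, -3:].flatten() from a real forward pass,
    # nudged by 0.5% so it still falls inside rtol=1e-2
    output_slice = expected_output_slice * 1.005

    assert torch.allclose(output_slice, expected_output_slice, rtol=1e-2)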