Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
9e1b1ca4
Unverified
Commit
9e1b1ca4
authored
Aug 29, 2022
by
Patrick von Platen
Committed by
GitHub
Aug 29, 2022
Browse files
[Tests] Make sure tests are on GPU (#269)
* [Tests] Make sure tests are on GPU * move more models * speed up tests
parent
16172c1c
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
47 additions
and
14 deletions
+47
-14
tests/test_config.py
tests/test_config.py
+0
-0
tests/test_models_unet.py
tests/test_models_unet.py
+7
-2
tests/test_models_vae.py
tests/test_models_vae.py
+6
-1
tests/test_models_vq.py
tests/test_models_vq.py
+6
-2
tests/test_pipelines.py
tests/test_pipelines.py
+28
-9
No files found.
tests/test_
modeling_utils
.py
→
tests/test_
config
.py
View file @
9e1b1ca4
File moved
tests/test_models_unet.py
View file @
9e1b1ca4
...
@@ -24,6 +24,9 @@ from diffusers.testing_utils import floats_tensor, torch_device
...
@@ -24,6 +24,9 @@ from diffusers.testing_utils import floats_tensor, torch_device
from
.test_modeling_common
import
ModelTesterMixin
from
.test_modeling_common
import
ModelTesterMixin
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
UnetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
UnetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNet2DModel
model_class
=
UNet2DModel
...
@@ -133,18 +136,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -133,18 +136,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
def
test_output_pretrained
(
self
):
def
test_output_pretrained
(
self
):
model
=
UNet2DModel
.
from_pretrained
(
"fusing/unet-ldm-dummy-update"
)
model
=
UNet2DModel
.
from_pretrained
(
"fusing/unet-ldm-dummy-update"
)
model
.
eval
()
model
.
eval
()
model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
sample_size
,
model
.
config
.
sample_size
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
sample_size
,
model
.
config
.
sample_size
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
noise
=
noise
.
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
]).
to
(
torch_device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)[
"sample"
]
output
=
model
(
noise
,
time_step
)[
"sample"
]
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
.
cpu
()
# fmt: off
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
13.3258
,
-
20.1100
,
-
15.9873
,
-
17.6617
,
-
23.0596
,
-
17.9419
,
-
13.3675
,
-
16.1889
,
-
12.3800
])
expected_output_slice
=
torch
.
tensor
([
-
13.3258
,
-
20.1100
,
-
15.9873
,
-
17.6617
,
-
23.0596
,
-
17.9419
,
-
13.3675
,
-
16.1889
,
-
12.3800
])
# fmt: on
# fmt: on
...
...
tests/test_models_vae.py
View file @
9e1b1ca4
...
@@ -23,6 +23,9 @@ from diffusers.testing_utils import floats_tensor, torch_device
...
@@ -23,6 +23,9 @@ from diffusers.testing_utils import floats_tensor, torch_device
from
.test_modeling_common
import
ModelTesterMixin
from
.test_modeling_common
import
ModelTesterMixin
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
AutoencoderKLTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
AutoencoderKLTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
AutoencoderKL
model_class
=
AutoencoderKL
...
@@ -74,6 +77,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
...
@@ -74,6 +77,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
def
test_output_pretrained
(
self
):
def
test_output_pretrained
(
self
):
model
=
AutoencoderKL
.
from_pretrained
(
"fusing/autoencoder-kl-dummy"
)
model
=
AutoencoderKL
.
from_pretrained
(
"fusing/autoencoder-kl-dummy"
)
model
=
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
@@ -81,10 +85,11 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
...
@@ -81,10 +85,11 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
image
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
sample_size
,
model
.
config
.
sample_size
)
image
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
sample_size
,
model
.
config
.
sample_size
)
image
=
image
.
to
(
torch_device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
model
(
image
,
sample_posterior
=
True
)
output
=
model
(
image
,
sample_posterior
=
True
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
.
cpu
()
# fmt: off
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
4.0078e-01
,
-
3.8304e-04
,
-
1.2681e-01
,
-
1.1462e-01
,
2.0095e-01
,
1.0893e-01
,
-
8.8248e-02
,
-
3.0361e-01
,
-
9.8646e-03
])
expected_output_slice
=
torch
.
tensor
([
-
4.0078e-01
,
-
3.8304e-04
,
-
1.2681e-01
,
-
1.1462e-01
,
2.0095e-01
,
1.0893e-01
,
-
8.8248e-02
,
-
3.0361e-01
,
-
9.8646e-03
])
# fmt: on
# fmt: on
...
...
tests/test_models_vq.py
View file @
9e1b1ca4
...
@@ -23,6 +23,9 @@ from diffusers.testing_utils import floats_tensor, torch_device
...
@@ -23,6 +23,9 @@ from diffusers.testing_utils import floats_tensor, torch_device
from
.test_modeling_common
import
ModelTesterMixin
from
.test_modeling_common
import
ModelTesterMixin
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
VQModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
VQModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
VQModel
model_class
=
VQModel
...
@@ -73,17 +76,18 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -73,17 +76,18 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
def
test_output_pretrained
(
self
):
def
test_output_pretrained
(
self
):
model
=
VQModel
.
from_pretrained
(
"fusing/vqgan-dummy"
)
model
=
VQModel
.
from_pretrained
(
"fusing/vqgan-dummy"
)
model
.
eval
()
model
.
to
(
torch_device
).
eval
()
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
image
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
sample_size
,
model
.
config
.
sample_size
)
image
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
sample_size
,
model
.
config
.
sample_size
)
image
=
image
.
to
(
torch_device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
model
(
image
)
output
=
model
(
image
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
.
cpu
()
# fmt: off
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.0153
,
-
0.4044
,
-
0.1880
,
-
0.5161
,
-
0.2418
,
-
0.4072
,
-
0.1612
,
-
0.0633
,
-
0.0143
])
expected_output_slice
=
torch
.
tensor
([
-
0.0153
,
-
0.4044
,
-
0.1880
,
-
0.5161
,
-
0.2418
,
-
0.4072
,
-
0.1612
,
-
0.0633
,
-
0.0143
])
# fmt: on
# fmt: on
...
...
tests/test_pipelines.py
View file @
9e1b1ca4
...
@@ -59,10 +59,12 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -59,10 +59,12 @@ class PipelineTesterMixin(unittest.TestCase):
schedular
=
DDPMScheduler
(
num_train_timesteps
=
10
)
schedular
=
DDPMScheduler
(
num_train_timesteps
=
10
)
ddpm
=
DDPMPipeline
(
model
,
schedular
)
ddpm
=
DDPMPipeline
(
model
,
schedular
)
ddpm
.
to
(
torch_device
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
ddpm
.
save_pretrained
(
tmpdirname
)
ddpm
.
save_pretrained
(
tmpdirname
)
new_ddpm
=
DDPMPipeline
.
from_pretrained
(
tmpdirname
)
new_ddpm
=
DDPMPipeline
.
from_pretrained
(
tmpdirname
)
new_ddpm
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
...
@@ -76,11 +78,12 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -76,11 +78,12 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model_path
=
"google/ddpm-cifar10-32"
model_path
=
"google/ddpm-cifar10-32"
ddpm
=
DDPMPipeline
.
from_pretrained
(
model_path
)
scheduler
=
DDPMScheduler
(
num_train_timesteps
=
10
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
)
ddpm
.
scheduler
.
num_timesteps
=
10
ddpm
=
DDPMPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
)
ddpm_from_hub
.
scheduler
.
num_timesteps
=
10
ddpm
.
to
(
torch_device
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
)
ddpm_from_hub
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
...
@@ -94,14 +97,15 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -94,14 +97,15 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_from_pretrained_hub_pass_model
(
self
):
def
test_from_pretrained_hub_pass_model
(
self
):
model_path
=
"google/ddpm-cifar10-32"
model_path
=
"google/ddpm-cifar10-32"
scheduler
=
DDPMScheduler
(
num_train_timesteps
=
10
)
# pass unet into DiffusionPipeline
# pass unet into DiffusionPipeline
unet
=
UNet2DModel
.
from_pretrained
(
model_path
)
unet
=
UNet2DModel
.
from_pretrained
(
model_path
)
ddpm_from_hub_custom_model
=
DiffusionPipeline
.
from_pretrained
(
model_path
,
unet
=
unet
)
ddpm_from_hub_custom_model
=
DiffusionPipeline
.
from_pretrained
(
model_path
,
unet
=
unet
,
scheduler
=
scheduler
)
ddpm_from_hub_custom_model
.
to
(
torch_device
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
)
ddpm_from_hub
_custom_model
.
scheduler
.
num_timesteps
=
10
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
)
ddpm_from_hub
.
scheduler
.
num_timesteps
=
10
ddpm_from_hub
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
...
@@ -116,6 +120,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -116,6 +120,7 @@ class PipelineTesterMixin(unittest.TestCase):
model_path
=
"google/ddpm-cifar10-32"
model_path
=
"google/ddpm-cifar10-32"
pipe
=
DDIMPipeline
.
from_pretrained
(
model_path
)
pipe
=
DDIMPipeline
.
from_pretrained
(
model_path
)
pipe
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
images
=
pipe
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
images
=
pipe
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
...
@@ -141,6 +146,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -141,6 +146,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler
=
scheduler
.
set_format
(
"pt"
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
image
=
ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
...
@@ -159,6 +165,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -159,6 +165,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler
=
DDIMScheduler
.
from_config
(
model_id
)
scheduler
=
DDIMScheduler
.
from_config
(
model_id
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
image
=
ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
...
@@ -177,6 +184,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -177,6 +184,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddim
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
,
output_type
=
"numpy"
)[
"sample"
]
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
,
output_type
=
"numpy"
)[
"sample"
]
...
@@ -195,6 +203,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -195,6 +203,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
pndm
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
pndm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
image
=
pndm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
...
@@ -207,6 +216,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -207,6 +216,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_ldm_text2img
(
self
):
def
test_ldm_text2img
(
self
):
ldm
=
LDMTextToImagePipeline
.
from_pretrained
(
"CompVis/ldm-text2im-large-256"
)
ldm
=
LDMTextToImagePipeline
.
from_pretrained
(
"CompVis/ldm-text2im-large-256"
)
ldm
.
to
(
torch_device
)
prompt
=
"A painting of a squirrel eating a burger"
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
...
@@ -223,6 +233,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -223,6 +233,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_ldm_text2img_fast
(
self
):
def
test_ldm_text2img_fast
(
self
):
ldm
=
LDMTextToImagePipeline
.
from_pretrained
(
"CompVis/ldm-text2im-large-256"
)
ldm
=
LDMTextToImagePipeline
.
from_pretrained
(
"CompVis/ldm-text2im-large-256"
)
ldm
.
to
(
torch_device
)
prompt
=
"A painting of a squirrel eating a burger"
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
...
@@ -290,6 +301,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -290,6 +301,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler
=
ScoreSdeVeScheduler
.
from_config
(
model_id
)
scheduler
=
ScoreSdeVeScheduler
.
from_config
(
model_id
)
sde_ve
=
ScoreSdeVePipeline
(
unet
=
model
,
scheduler
=
scheduler
)
sde_ve
=
ScoreSdeVePipeline
(
unet
=
model
,
scheduler
=
scheduler
)
sde_ve
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
image
=
sde_ve
(
num_inference_steps
=
300
,
output_type
=
"numpy"
)[
"sample"
]
image
=
sde_ve
(
num_inference_steps
=
300
,
output_type
=
"numpy"
)[
"sample"
]
...
@@ -304,6 +316,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -304,6 +316,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_ldm_uncond
(
self
):
def
test_ldm_uncond
(
self
):
ldm
=
LDMPipeline
.
from_pretrained
(
"CompVis/ldm-celebahq-256"
)
ldm
=
LDMPipeline
.
from_pretrained
(
"CompVis/ldm-celebahq-256"
)
ldm
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
,
output_type
=
"numpy"
)[
"sample"
]
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
,
output_type
=
"numpy"
)[
"sample"
]
...
@@ -323,7 +336,9 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -323,7 +336,9 @@ class PipelineTesterMixin(unittest.TestCase):
ddim_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
ddpm_scheduler
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
ddpm_scheduler
)
ddpm
.
to
(
torch_device
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
ddim_scheduler
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
ddim_scheduler
)
ddim
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
ddpm_image
=
ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
ddpm_image
=
ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
...
@@ -343,7 +358,10 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -343,7 +358,10 @@ class PipelineTesterMixin(unittest.TestCase):
ddim_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
ddpm_scheduler
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
ddpm_scheduler
)
ddpm
.
to
(
torch_device
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
ddim_scheduler
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
ddim_scheduler
)
ddim
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
ddpm_images
=
ddpm
(
batch_size
=
4
,
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
ddpm_images
=
ddpm
(
batch_size
=
4
,
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
...
@@ -363,6 +381,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -363,6 +381,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler
=
KarrasVeScheduler
(
tensor_format
=
"pt"
)
scheduler
=
KarrasVeScheduler
(
tensor_format
=
"pt"
)
pipe
=
KarrasVePipeline
(
unet
=
model
,
scheduler
=
scheduler
)
pipe
=
KarrasVePipeline
(
unet
=
model
,
scheduler
=
scheduler
)
pipe
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
pipe
(
num_inference_steps
=
20
,
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
image
=
pipe
(
num_inference_steps
=
20
,
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment