renzhc / diffusers_dcu · Commits

Commit c5c93996 authored Jul 21, 2022 by Patrick von Platen

correct paths for tests
parent 836f3f35

Showing 1 changed file with 20 additions and 28 deletions
tests/test_modeling_utils.py  +20 -28
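The fix replaces hard-coded local checkpoint directories (/home/patrick/google_checkpoints/...) with the corresponding Hugging Face Hub repository IDs (fusing/..., google/..., CompVis/...), so the tests no longer depend on files that exist only on the author's machine. A minimal sketch of the loading pattern the corrected tests rely on (the repo ID and keyword argument are taken from the diff below; the top-level import is assumed to match what tests/test_modeling_utils.py already uses):

from diffusers import UNet2DModel

# Loading by Hub repo ID downloads and caches the weights instead of reading
# them from a local path such as /home/patrick/google_checkpoints/...
model, loading_info = UNet2DModel.from_pretrained(
    "fusing/unet-ldm-dummy-update", output_loading_info=True
)
assert len(loading_info["missing_keys"]) == 0  # same check the test makes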
@@ -365,9 +365,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         return init_dict, inputs_dict

     def test_from_pretrained_hub(self):
-        model, loading_info = UNet2DModel.from_pretrained(
-            "/home/patrick/google_checkpoints/unet-ldm-dummy-update", output_loading_info=True
-        )
+        model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)

         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)

@@ -378,7 +376,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"

     def test_output_pretrained(self):
-        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/unet-ldm-dummy-update")
+        model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
         model.eval()

         torch.manual_seed(0)

@@ -472,9 +470,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         return init_dict, inputs_dict

     def test_from_pretrained_hub(self):
-        model, loading_info = UNet2DModel.from_pretrained(
-            "/home/patrick/google_checkpoints/ncsnpp-celebahq-256", output_loading_info=True
-        )
+        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)

         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)

@@ -487,7 +483,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"

     def test_output_pretrained_ve_mid(self):
-        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-celebahq-256")
+        model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
         model.to(torch_device)

         torch.manual_seed(0)

@@ -512,7 +508,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

     def test_output_pretrained_ve_large(self):
-        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-ffhq-ve-dummy-update")
+        model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
         model.to(torch_device)

         torch.manual_seed(0)

@@ -582,9 +578,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         pass

     def test_from_pretrained_hub(self):
-        model, loading_info = VQModel.from_pretrained(
-            "/home/patrick/google_checkpoints/vqgan-dummy", output_loading_info=True
-        )
+        model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)

         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)

@@ -594,7 +588,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"

     def test_output_pretrained(self):
-        model = VQModel.from_pretrained("/home/patrick/google_checkpoints/vqgan-dummy")
+        model = VQModel.from_pretrained("fusing/vqgan-dummy")
         model.eval()

         torch.manual_seed(0)

@@ -655,9 +649,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         pass

     def test_from_pretrained_hub(self):
-        model, loading_info = AutoencoderKL.from_pretrained(
-            "/home/patrick/google_checkpoints/autoencoder-kl-dummy", output_loading_info=True
-        )
+        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)

         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)

@@ -667,7 +659,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"

     def test_output_pretrained(self):
-        model = AutoencoderKL.from_pretrained("/home/patrick/google_checkpoints/autoencoder-kl-dummy")
+        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
         model.eval()

         torch.manual_seed(0)

@@ -715,7 +707,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_from_pretrained_hub(self):
-        model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
+        model_path = "google/ddpm-cifar10-32"

         ddpm = DDPMPipeline.from_pretrained(model_path)
         ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)

@@ -733,7 +725,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_output_format(self):
-        model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
+        model_path = "google/ddpm-cifar10-32"

         pipe = DDIMPipeline.from_pretrained(model_path)

@@ -754,7 +746,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ddpm_cifar10(self):
-        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
+        model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDPMScheduler.from_config(model_id)

@@ -773,7 +765,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ddim_lsun(self):
-        model_id = "/home/patrick/google_checkpoints/ddpm-ema-bedroom-256"
+        model_id = "google/ddpm-ema-bedroom-256"

         unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDIMScheduler.from_config(model_id)

@@ -791,7 +783,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ddim_cifar10(self):
-        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
+        model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDIMScheduler(tensor_format="pt")

@@ -809,7 +801,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_pndm_cifar10(self):
-        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
+        model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
         scheduler = PNDMScheduler(tensor_format="pt")

@@ -826,7 +818,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ldm_text2img(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
+        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)

@@ -842,7 +834,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ldm_text2img_fast(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
+        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)

@@ -856,13 +848,13 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_score_sde_ve_pipeline(self):
-        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-church-256")
+        model = UNet2DModel.from_pretrained("google/ncsnpp-church-256")

         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        scheduler = ScoreSdeVeScheduler.from_config("/home/patrick/google_checkpoints/ncsnpp-church-256")
+        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")

         sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)

@@ -877,7 +869,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ldm_uncond(self):
-        ldm = LDMPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")
+        ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")

         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
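For reference, the slow pipeline tests touched above all follow the same pattern once the Hub IDs are in place. A short sketch assembled from the last hunk (the top-level import is an assumption; the repo ID and calls are as they appear in the test):

import torch
from diffusers import LDMPipeline

# "CompVis/ldm-celebahq-256" is one of the Hub IDs substituted in this commit.
ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")

generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
assert image is not None, "Make sure output is not None"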