Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
d2d9764f
Unverified
Commit
d2d9764f
authored
Oct 28, 2022
by
Patrick von Platen
Committed by
GitHub
Oct 28, 2022
Browse files
[Tests] Speed up slow tests (#1040)
* [Tests] Speed up slow tests * Up * up
parent
a80480f0
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
58 additions
and
36 deletions
+58
-36
tests/pipelines/dance_diffusion/test_dance_diffusion.py
tests/pipelines/dance_diffusion/test_dance_diffusion.py
+4
-2
tests/pipelines/ddim/test_ddim.py
tests/pipelines/ddim/test_ddim.py
+2
-2
tests/pipelines/ddpm/test_ddpm.py
tests/pipelines/ddpm/test_ddpm.py
+1
-1
tests/pipelines/karras_ve/test_karras_ve.py
tests/pipelines/karras_ve/test_karras_ve.py
+1
-1
tests/pipelines/latent_diffusion/test_latent_diffusion.py
tests/pipelines/latent_diffusion/test_latent_diffusion.py
+2
-2
tests/pipelines/pndm/test_pndm.py
tests/pipelines/pndm/test_pndm.py
+1
-1
tests/pipelines/score_sde_ve/test_score_sde_ve.py
tests/pipelines/score_sde_ve/test_score_sde_ve.py
+1
-1
tests/pipelines/stable_diffusion/test_stable_diffusion.py
tests/pipelines/stable_diffusion/test_stable_diffusion.py
+15
-10
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
...pelines/stable_diffusion/test_stable_diffusion_img2img.py
+3
-1
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
...pelines/stable_diffusion/test_stable_diffusion_inpaint.py
+5
-1
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
.../stable_diffusion/test_stable_diffusion_inpaint_legacy.py
+3
-1
tests/test_pipelines.py
tests/test_pipelines.py
+20
-13
No files found.
tests/pipelines/dance_diffusion/test_dance_diffusion.py
View file @
d2d9764f
...
...
@@ -86,7 +86,7 @@ class PipelineIntegrationTests(unittest.TestCase):
def
test_dance_diffusion
(
self
):
device
=
torch_device
pipe
=
DanceDiffusionPipeline
.
from_pretrained
(
"harmonai/maestro-150k"
)
pipe
=
DanceDiffusionPipeline
.
from_pretrained
(
"harmonai/maestro-150k"
,
device_map
=
"auto"
)
pipe
=
pipe
.
to
(
device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -103,7 +103,9 @@ class PipelineIntegrationTests(unittest.TestCase):
def
test_dance_diffusion_fp16
(
self
):
device
=
torch_device
pipe
=
DanceDiffusionPipeline
.
from_pretrained
(
"harmonai/maestro-150k"
,
torch_dtype
=
torch
.
float16
)
pipe
=
DanceDiffusionPipeline
.
from_pretrained
(
"harmonai/maestro-150k"
,
torch_dtype
=
torch
.
float16
,
device_map
=
"auto"
)
pipe
=
pipe
.
to
(
device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
tests/pipelines/ddim/test_ddim.py
View file @
d2d9764f
...
...
@@ -78,7 +78,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
def
test_inference_ema_bedroom
(
self
):
model_id
=
"google/ddpm-ema-bedroom-256"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
unet
=
UNet2DModel
.
from_pretrained
(
model_id
,
device_map
=
"auto"
)
scheduler
=
DDIMScheduler
.
from_config
(
model_id
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
...
...
@@ -97,7 +97,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
def
test_inference_cifar10
(
self
):
model_id
=
"google/ddpm-cifar10-32"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
unet
=
UNet2DModel
.
from_pretrained
(
model_id
,
device_map
=
"auto"
)
scheduler
=
DDIMScheduler
()
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
...
...
tests/pipelines/ddpm/test_ddpm.py
View file @
d2d9764f
...
...
@@ -38,7 +38,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
def
test_inference_cifar10
(
self
):
model_id
=
"google/ddpm-cifar10-32"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
unet
=
UNet2DModel
.
from_pretrained
(
model_id
,
device_map
=
"auto"
)
scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
...
...
tests/pipelines/karras_ve/test_karras_ve.py
View file @
d2d9764f
...
...
@@ -70,7 +70,7 @@ class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class
KarrasVePipelineIntegrationTests
(
unittest
.
TestCase
):
def
test_inference
(
self
):
model_id
=
"google/ncsnpp-celebahq-256"
model
=
UNet2DModel
.
from_pretrained
(
model_id
)
model
=
UNet2DModel
.
from_pretrained
(
model_id
,
device_map
=
"auto"
)
scheduler
=
KarrasVeScheduler
()
pipe
=
KarrasVePipeline
(
unet
=
model
,
scheduler
=
scheduler
)
...
...
tests/pipelines/latent_diffusion/test_latent_diffusion.py
View file @
d2d9764f
...
...
@@ -121,7 +121,7 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@
require_torch
class
LDMTextToImagePipelineIntegrationTests
(
unittest
.
TestCase
):
def
test_inference_text2img
(
self
):
ldm
=
LDMTextToImagePipeline
.
from_pretrained
(
"CompVis/ldm-text2im-large-256"
)
ldm
=
LDMTextToImagePipeline
.
from_pretrained
(
"CompVis/ldm-text2im-large-256"
,
device_map
=
"auto"
)
ldm
.
to
(
torch_device
)
ldm
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -138,7 +138,7 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_inference_text2img_fast
(
self
):
ldm
=
LDMTextToImagePipeline
.
from_pretrained
(
"CompVis/ldm-text2im-large-256"
)
ldm
=
LDMTextToImagePipeline
.
from_pretrained
(
"CompVis/ldm-text2im-large-256"
,
device_map
=
"auto"
)
ldm
.
to
(
torch_device
)
ldm
.
set_progress_bar_config
(
disable
=
None
)
...
...
tests/pipelines/pndm/test_pndm.py
View file @
d2d9764f
...
...
@@ -71,7 +71,7 @@ class PNDMPipelineIntegrationTests(unittest.TestCase):
def
test_inference_cifar10
(
self
):
model_id
=
"google/ddpm-cifar10-32"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
unet
=
UNet2DModel
.
from_pretrained
(
model_id
,
device_map
=
"auto"
)
scheduler
=
PNDMScheduler
()
pndm
=
PNDMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
...
...
tests/pipelines/score_sde_ve/test_score_sde_ve.py
View file @
d2d9764f
...
...
@@ -72,7 +72,7 @@ class ScoreSdeVeipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class
ScoreSdeVePipelineIntegrationTests
(
unittest
.
TestCase
):
def
test_inference
(
self
):
model_id
=
"google/ncsnpp-church-256"
model
=
UNet2DModel
.
from_pretrained
(
model_id
)
model
=
UNet2DModel
.
from_pretrained
(
model_id
,
device_map
=
"auto"
)
scheduler
=
ScoreSdeVeScheduler
.
from_config
(
model_id
)
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion.py
View file @
d2d9764f
...
...
@@ -528,7 +528,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def
test_stable_diffusion
(
self
):
# make sure here that pndm scheduler skips prk
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1"
)
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1"
,
device_map
=
"auto"
)
sd_pipe
=
sd_pipe
.
to
(
torch_device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -548,7 +548,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_fast_ddim
(
self
):
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1"
)
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1"
,
device_map
=
"auto"
)
sd_pipe
=
sd_pipe
.
to
(
torch_device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -576,7 +576,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def
test_lms_stable_diffusion_pipeline
(
self
):
model_id
=
"CompVis/stable-diffusion-v1-1"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
).
to
(
torch_device
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
,
device_map
=
"auto"
).
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
scheduler
=
LMSDiscreteScheduler
.
from_config
(
model_id
,
subfolder
=
"scheduler"
)
pipe
.
scheduler
=
scheduler
...
...
@@ -595,9 +595,10 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def
test_stable_diffusion_memory_chunking
(
self
):
torch
.
cuda
.
reset_peak_memory_stats
()
model_id
=
"CompVis/stable-diffusion-v1-4"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
).
to
(
torch_device
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
,
device_map
=
"auto"
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
prompt
=
"a photograph of an astronaut riding a horse"
...
...
@@ -633,9 +634,10 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def
test_stable_diffusion_text2img_pipeline_fp16
(
self
):
torch
.
cuda
.
reset_peak_memory_stats
()
model_id
=
"CompVis/stable-diffusion-v1-4"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
).
to
(
torch_device
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
,
revision
=
"fp16"
,
device_map
=
"auto"
,
torch_dtype
=
torch
.
float16
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
prompt
=
"a photograph of an astronaut riding a horse"
...
...
@@ -670,6 +672,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
,
safety_checker
=
None
,
device_map
=
"auto"
,
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -711,7 +714,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
test_callback_fn
.
has_been_called
=
False
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
"CompVis/stable-diffusion-v1-4"
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
,
device_map
=
"auto"
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -737,7 +740,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
start_time
=
time
.
time
()
pipeline_normal_load
=
StableDiffusionPipeline
.
from_pretrained
(
pipeline_id
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
,
use_auth_token
=
True
pipeline_id
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
,
device_map
=
"auto"
)
pipeline_normal_load
.
to
(
torch_device
)
normal_load_time
=
time
.
time
()
-
start_time
...
...
@@ -758,7 +761,9 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
pipeline_id
=
"CompVis/stable-diffusion-v1-4"
prompt
=
"Andromeda galaxy in a bottle"
pipeline
=
StableDiffusionPipeline
.
from_pretrained
(
pipeline_id
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
)
pipeline
=
StableDiffusionPipeline
.
from_pretrained
(
pipeline_id
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
,
device_map
=
"auto"
)
pipeline
.
enable_attention_slicing
(
1
)
pipeline
.
enable_sequential_cpu_offload
()
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
View file @
d2d9764f
...
...
@@ -488,6 +488,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
model_id
,
safety_checker
=
None
,
device_map
=
"auto"
,
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -529,6 +530,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
model_id
,
scheduler
=
lms
,
safety_checker
=
None
,
device_map
=
"auto"
,
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -580,7 +582,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image
=
init_image
.
resize
((
768
,
512
))
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
"CompVis/stable-diffusion-v1-4"
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
,
device_map
=
"auto"
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
View file @
d2d9764f
...
...
@@ -288,6 +288,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
safety_checker
=
None
,
device_map
=
"auto"
,
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -329,6 +330,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
,
safety_checker
=
None
,
device_map
=
"auto"
,
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -366,7 +368,9 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
pndm
=
PNDMScheduler
(
beta_start
=
0.00085
,
beta_end
=
0.012
,
beta_schedule
=
"scaled_linear"
,
skip_prk_steps
=
True
)
model_id
=
"runwayml/stable-diffusion-inpainting"
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
safety_checker
=
None
,
scheduler
=
pndm
)
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
safety_checker
=
None
,
scheduler
=
pndm
,
device_map
=
"auto"
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
enable_attention_slicing
()
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
View file @
d2d9764f
...
...
@@ -368,6 +368,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
safety_checker
=
None
,
device_map
=
"auto"
,
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -413,6 +414,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
model_id
,
scheduler
=
lms
,
safety_checker
=
None
,
device_map
=
"auto"
,
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -469,7 +471,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
)
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
"CompVis/stable-diffusion-v1-4"
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
,
device_map
=
"auto"
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
tests/test_pipelines.py
View file @
d2d9764f
...
...
@@ -108,8 +108,8 @@ class CustomPipelineTests(unittest.TestCase):
def
test_load_pipeline_from_git
(
self
):
clip_model_id
=
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
feature_extractor
=
CLIPFeatureExtractor
.
from_pretrained
(
clip_model_id
)
clip_model
=
CLIPModel
.
from_pretrained
(
clip_model_id
,
torch_dtype
=
torch
.
float16
)
feature_extractor
=
CLIPFeatureExtractor
.
from_pretrained
(
clip_model_id
,
device_map
=
"auto"
)
clip_model
=
CLIPModel
.
from_pretrained
(
clip_model_id
,
torch_dtype
=
torch
.
float16
,
device_map
=
"auto"
)
pipeline
=
DiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
...
...
@@ -118,6 +118,7 @@ class CustomPipelineTests(unittest.TestCase):
feature_extractor
=
feature_extractor
,
torch_dtype
=
torch
.
float16
,
revision
=
"fp16"
,
device_map
=
"auto"
,
)
pipeline
.
enable_attention_slicing
()
pipeline
=
pipeline
.
to
(
torch_device
)
...
...
@@ -312,7 +313,9 @@ class PipelineSlowTests(unittest.TestCase):
def
test_smart_download
(
self
):
model_id
=
"hf-internal-testing/unet-pipeline-dummy"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
_
=
DiffusionPipeline
.
from_pretrained
(
model_id
,
cache_dir
=
tmpdirname
,
force_download
=
True
)
_
=
DiffusionPipeline
.
from_pretrained
(
model_id
,
cache_dir
=
tmpdirname
,
force_download
=
True
,
device_map
=
"auto"
)
local_repo_name
=
"--"
.
join
([
"models"
]
+
model_id
.
split
(
"/"
))
snapshot_dir
=
os
.
path
.
join
(
tmpdirname
,
local_repo_name
,
"snapshots"
)
snapshot_dir
=
os
.
path
.
join
(
snapshot_dir
,
os
.
listdir
(
snapshot_dir
)[
0
])
...
...
@@ -335,7 +338,9 @@ class PipelineSlowTests(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.pipeline_utils"
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
DiffusionPipeline
.
from_pretrained
(
model_id
,
not_used
=
True
,
cache_dir
=
tmpdirname
,
force_download
=
True
)
DiffusionPipeline
.
from_pretrained
(
model_id
,
not_used
=
True
,
cache_dir
=
tmpdirname
,
force_download
=
True
,
device_map
=
"auto"
)
assert
cap_logger
.
out
==
"Keyword arguments {'not_used': True} not recognized.
\n
"
...
...
@@ -358,7 +363,7 @@ class PipelineSlowTests(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
ddpm
.
save_pretrained
(
tmpdirname
)
new_ddpm
=
DDPMPipeline
.
from_pretrained
(
tmpdirname
)
new_ddpm
=
DDPMPipeline
.
from_pretrained
(
tmpdirname
,
device_map
=
"auto"
)
new_ddpm
.
to
(
torch_device
)
generator
=
torch
.
manual_seed
(
0
)
...
...
@@ -374,10 +379,10 @@ class PipelineSlowTests(unittest.TestCase):
scheduler
=
DDPMScheduler
(
num_train_timesteps
=
10
)
ddpm
=
DDPMPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
)
ddpm
=
DDPMPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
,
device_map
=
"auto"
)
ddpm
.
to
(
torch_device
)
ddpm
.
set_progress_bar_config
(
disable
=
None
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
,
device_map
=
"auto"
)
ddpm_from_hub
.
to
(
torch_device
)
ddpm_from_hub
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -395,12 +400,14 @@ class PipelineSlowTests(unittest.TestCase):
scheduler
=
DDPMScheduler
(
num_train_timesteps
=
10
)
# pass unet into DiffusionPipeline
unet
=
UNet2DModel
.
from_pretrained
(
model_path
)
ddpm_from_hub_custom_model
=
DiffusionPipeline
.
from_pretrained
(
model_path
,
unet
=
unet
,
scheduler
=
scheduler
)
unet
=
UNet2DModel
.
from_pretrained
(
model_path
,
device_map
=
"auto"
)
ddpm_from_hub_custom_model
=
DiffusionPipeline
.
from_pretrained
(
model_path
,
unet
=
unet
,
scheduler
=
scheduler
,
device_map
=
"auto"
)
ddpm_from_hub_custom_model
.
to
(
torch_device
)
ddpm_from_hub_custom_model
.
set_progress_bar_config
(
disable
=
None
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
,
device_map
=
"auto"
)
ddpm_from_hub
.
to
(
torch_device
)
ddpm_from_hub_custom_model
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -415,7 +422,7 @@ class PipelineSlowTests(unittest.TestCase):
def
test_output_format
(
self
):
model_path
=
"google/ddpm-cifar10-32"
pipe
=
DDIMPipeline
.
from_pretrained
(
model_path
)
pipe
=
DDIMPipeline
.
from_pretrained
(
model_path
,
device_map
=
"auto"
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -437,7 +444,7 @@ class PipelineSlowTests(unittest.TestCase):
def
test_ddpm_ddim_equality
(
self
):
model_id
=
"google/ddpm-cifar10-32"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
unet
=
UNet2DModel
.
from_pretrained
(
model_id
,
device_map
=
"auto"
)
ddpm_scheduler
=
DDPMScheduler
()
ddim_scheduler
=
DDIMScheduler
()
...
...
@@ -461,7 +468,7 @@ class PipelineSlowTests(unittest.TestCase):
def
test_ddpm_ddim_equality_batched
(
self
):
model_id
=
"google/ddpm-cifar10-32"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
unet
=
UNet2DModel
.
from_pretrained
(
model_id
,
device_map
=
"auto"
)
ddpm_scheduler
=
DDPMScheduler
()
ddim_scheduler
=
DDIMScheduler
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment