Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
554b374d
Commit
554b374d
authored
Nov 15, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
d5ab55e4
a0520193
Changes
76
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
339 additions
and
158 deletions
+339
-158
tests/pipelines/ddpm/test_ddpm.py
tests/pipelines/ddpm/test_ddpm.py
+1
-1
tests/pipelines/repaint/test_repaint.py
tests/pipelines/repaint/test_repaint.py
+1
-1
tests/pipelines/score_sde_ve/test_score_sde_ve.py
tests/pipelines/score_sde_ve/test_score_sde_ve.py
+1
-1
tests/pipelines/stable_diffusion/test_cycle_diffusion.py
tests/pipelines/stable_diffusion/test_cycle_diffusion.py
+2
-2
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
.../pipelines/stable_diffusion/test_onnx_stable_diffusion.py
+2
-2
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
...es/stable_diffusion/test_onnx_stable_diffusion_img2img.py
+1
-1
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
...es/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
+1
-1
tests/pipelines/stable_diffusion/test_stable_diffusion.py
tests/pipelines/stable_diffusion/test_stable_diffusion.py
+2
-2
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
...pelines/stable_diffusion/test_stable_diffusion_img2img.py
+3
-3
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
...pelines/stable_diffusion/test_stable_diffusion_inpaint.py
+2
-2
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
.../stable_diffusion/test_stable_diffusion_inpaint_legacy.py
+1
-1
tests/test_config.py
tests/test_config.py
+13
-118
tests/test_modeling_common.py
tests/test_modeling_common.py
+3
-3
tests/test_pipelines.py
tests/test_pipelines.py
+81
-1
tests/test_scheduler.py
tests/test_scheduler.py
+217
-11
tests/test_scheduler_flax.py
tests/test_scheduler_flax.py
+8
-8
No files found.
tests/pipelines/ddpm/test_ddpm.py
View file @
554b374d
...
@@ -106,7 +106,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
...
@@ -106,7 +106,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
model_id
=
"google/ddpm-cifar10-32"
model_id
=
"google/ddpm-cifar10-32"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
scheduler
=
DDPMScheduler
.
from_
config
(
model_id
)
scheduler
=
DDPMScheduler
.
from_
pretrained
(
model_id
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
.
to
(
torch_device
)
ddpm
.
to
(
torch_device
)
...
...
tests/pipelines/repaint/test_repaint.py
View file @
554b374d
...
@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
...
@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
model_id
=
"google/ddpm-ema-celebahq-256"
model_id
=
"google/ddpm-ema-celebahq-256"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
scheduler
=
RePaintScheduler
.
from_
config
(
model_id
)
scheduler
=
RePaintScheduler
.
from_
pretrained
(
model_id
)
repaint
=
RePaintPipeline
(
unet
=
unet
,
scheduler
=
scheduler
).
to
(
torch_device
)
repaint
=
RePaintPipeline
(
unet
=
unet
,
scheduler
=
scheduler
).
to
(
torch_device
)
...
...
tests/pipelines/score_sde_ve/test_score_sde_ve.py
View file @
554b374d
...
@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
...
@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
model_id
=
"google/ncsnpp-church-256"
model_id
=
"google/ncsnpp-church-256"
model
=
UNet2DModel
.
from_pretrained
(
model_id
)
model
=
UNet2DModel
.
from_pretrained
(
model_id
)
scheduler
=
ScoreSdeVeScheduler
.
from_
config
(
model_id
)
scheduler
=
ScoreSdeVeScheduler
.
from_
pretrained
(
model_id
)
sde_ve
=
ScoreSdeVePipeline
(
unet
=
model
,
scheduler
=
scheduler
)
sde_ve
=
ScoreSdeVePipeline
(
unet
=
model
,
scheduler
=
scheduler
)
sde_ve
.
to
(
torch_device
)
sde_ve
.
to
(
torch_device
)
...
...
tests/pipelines/stable_diffusion/test_cycle_diffusion.py
View file @
554b374d
...
@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image
=
init_image
.
resize
((
512
,
512
))
init_image
=
init_image
.
resize
((
512
,
512
))
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
scheduler
=
DDIMScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
model_id
,
scheduler
=
scheduler
,
safety_checker
=
None
,
torch_dtype
=
torch
.
float16
,
revision
=
"fp16"
model_id
,
scheduler
=
scheduler
,
safety_checker
=
None
,
torch_dtype
=
torch
.
float16
,
revision
=
"fp16"
)
)
...
@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image
=
init_image
.
resize
((
512
,
512
))
init_image
=
init_image
.
resize
((
512
,
512
))
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
scheduler
=
DDIMScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
model_id
,
scheduler
=
scheduler
,
safety_checker
=
None
)
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
model_id
,
scheduler
=
scheduler
,
safety_checker
=
None
)
pipe
.
to
(
torch_device
)
pipe
.
to
(
torch_device
)
...
...
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
View file @
554b374d
...
@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-3
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-3
def
test_inference_ddim
(
self
):
def
test_inference_ddim
(
self
):
ddim_scheduler
=
DDIMScheduler
.
from_
config
(
ddim_scheduler
=
DDIMScheduler
.
from_
pretrained
(
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
)
)
sd_pipe
=
OnnxStableDiffusionPipeline
.
from_pretrained
(
sd_pipe
=
OnnxStableDiffusionPipeline
.
from_pretrained
(
...
@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-3
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-3
def
test_inference_k_lms
(
self
):
def
test_inference_k_lms
(
self
):
lms_scheduler
=
LMSDiscreteScheduler
.
from_
config
(
lms_scheduler
=
LMSDiscreteScheduler
.
from_
pretrained
(
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
)
)
sd_pipe
=
OnnxStableDiffusionPipeline
.
from_pretrained
(
sd_pipe
=
OnnxStableDiffusionPipeline
.
from_pretrained
(
...
...
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
View file @
554b374d
...
@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
...
@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
"/img2img/sketch-mountains-input.jpg"
"/img2img/sketch-mountains-input.jpg"
)
)
init_image
=
init_image
.
resize
((
768
,
512
))
init_image
=
init_image
.
resize
((
768
,
512
))
lms_scheduler
=
LMSDiscreteScheduler
.
from_
config
(
lms_scheduler
=
LMSDiscreteScheduler
.
from_
pretrained
(
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
)
)
pipe
=
OnnxStableDiffusionImg2ImgPipeline
.
from_pretrained
(
pipe
=
OnnxStableDiffusionImg2ImgPipeline
.
from_pretrained
(
...
...
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
View file @
554b374d
...
@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
...
@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
)
lms_scheduler
=
LMSDiscreteScheduler
.
from_
config
(
lms_scheduler
=
LMSDiscreteScheduler
.
from_
pretrained
(
"runwayml/stable-diffusion-inpainting"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
"runwayml/stable-diffusion-inpainting"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
)
)
pipe
=
OnnxStableDiffusionInpaintPipeline
.
from_pretrained
(
pipe
=
OnnxStableDiffusionInpaintPipeline
.
from_pretrained
(
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion.py
View file @
554b374d
...
@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_fast_ddim
(
self
):
def
test_stable_diffusion_fast_ddim
(
self
):
scheduler
=
DDIMScheduler
.
from_
config
(
"CompVis/stable-diffusion-v1-1"
,
subfolder
=
"scheduler"
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
"CompVis/stable-diffusion-v1-1"
,
subfolder
=
"scheduler"
)
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1"
,
scheduler
=
scheduler
)
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1"
,
scheduler
=
scheduler
)
sd_pipe
=
sd_pipe
.
to
(
torch_device
)
sd_pipe
=
sd_pipe
.
to
(
torch_device
)
...
@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
model_id
=
"CompVis/stable-diffusion-v1-1"
model_id
=
"CompVis/stable-diffusion-v1-1"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
).
to
(
torch_device
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
).
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
scheduler
=
LMSDiscreteScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
scheduler
=
LMSDiscreteScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
.
scheduler
=
scheduler
pipe
.
scheduler
=
scheduler
prompt
=
"a photograph of an astronaut riding a horse"
prompt
=
"a photograph of an astronaut riding a horse"
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
View file @
554b374d
...
@@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
...
@@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)
)
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
lms
=
LMSDiscreteScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
lms
=
LMSDiscreteScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
model_id
,
model_id
,
scheduler
=
lms
,
scheduler
=
lms
,
...
@@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
...
@@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)
)
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
ddim
=
DDIMScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
ddim
=
DDIMScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
model_id
,
model_id
,
scheduler
=
ddim
,
scheduler
=
ddim
,
...
@@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
...
@@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image
=
init_image
.
resize
((
768
,
512
))
init_image
=
init_image
.
resize
((
768
,
512
))
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
lms
=
LMSDiscreteScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
lms
=
LMSDiscreteScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
model_id
,
scheduler
=
lms
,
safety_checker
=
None
,
device_map
=
"auto"
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
model_id
,
scheduler
=
lms
,
safety_checker
=
None
,
device_map
=
"auto"
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
)
)
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
View file @
554b374d
...
@@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
...
@@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)
)
model_id
=
"runwayml/stable-diffusion-inpainting"
model_id
=
"runwayml/stable-diffusion-inpainting"
pndm
=
PNDMScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
pndm
=
PNDMScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
safety_checker
=
None
,
scheduler
=
pndm
)
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
safety_checker
=
None
,
scheduler
=
pndm
)
pipe
.
to
(
torch_device
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
...
@@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)
)
model_id
=
"runwayml/stable-diffusion-inpainting"
model_id
=
"runwayml/stable-diffusion-inpainting"
pndm
=
PNDMScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
pndm
=
PNDMScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
model_id
,
safety_checker
=
None
,
safety_checker
=
None
,
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
View file @
554b374d
...
@@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
...
@@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
)
)
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
lms
=
LMSDiscreteScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
lms
=
LMSDiscreteScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
model_id
,
scheduler
=
lms
,
scheduler
=
lms
,
...
...
tests/test_config.py
View file @
554b374d
...
@@ -13,12 +13,9 @@
...
@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
json
import
os
import
tempfile
import
tempfile
import
unittest
import
unittest
import
diffusers
from
diffusers
import
(
from
diffusers
import
(
DDIMScheduler
,
DDIMScheduler
,
DDPMScheduler
,
DDPMScheduler
,
...
@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin):
...
@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin):
class
ConfigTester
(
unittest
.
TestCase
):
class
ConfigTester
(
unittest
.
TestCase
):
def
test_load_not_from_mixin
(
self
):
def
test_load_not_from_mixin
(
self
):
with
self
.
assertRaises
(
ValueError
):
with
self
.
assertRaises
(
ValueError
):
ConfigMixin
.
from
_config
(
"dummy_path"
)
ConfigMixin
.
load
_config
(
"dummy_path"
)
def
test_register_to_config
(
self
):
def
test_register_to_config
(
self
):
obj
=
SampleObject
()
obj
=
SampleObject
()
...
@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
obj
.
save_config
(
tmpdirname
)
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
new_obj
=
SampleObject
.
from_config
(
SampleObject
.
load_config
(
tmpdirname
)
)
new_config
=
new_obj
.
config
new_config
=
new_obj
.
config
# unfreeze configs
# unfreeze configs
...
@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase):
...
@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase):
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
config
==
new_config
assert
config
==
new_config
def
test_save_load_from_different_config
(
self
):
obj
=
SampleObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SampleObject"
,
SampleObject
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_1
:
new_obj_1
=
SampleObject2
.
from_config
(
tmpdirname
)
# now save a config parameter that is not expected
with
open
(
os
.
path
.
join
(
tmpdirname
,
SampleObject
.
config_name
),
"r"
)
as
f
:
data
=
json
.
load
(
f
)
data
[
"unexpected"
]
=
True
with
open
(
os
.
path
.
join
(
tmpdirname
,
SampleObject
.
config_name
),
"w"
)
as
f
:
json
.
dump
(
data
,
f
)
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
new_obj_2
=
SampleObject
.
from_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_3
:
new_obj_3
=
SampleObject2
.
from_config
(
tmpdirname
)
assert
new_obj_1
.
__class__
==
SampleObject2
assert
new_obj_2
.
__class__
==
SampleObject
assert
new_obj_3
.
__class__
==
SampleObject2
assert
cap_logger_1
.
out
==
""
assert
(
cap_logger_2
.
out
==
"The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.
\n
"
)
assert
cap_logger_2
.
out
.
replace
(
"SampleObject"
,
"SampleObject2"
)
==
cap_logger_3
.
out
def
test_save_load_compatible_schedulers
(
self
):
SampleObject2
.
_compatible_classes
=
[
"SampleObject"
]
SampleObject
.
_compatible_classes
=
[
"SampleObject2"
]
obj
=
SampleObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SampleObject"
,
SampleObject
)
setattr
(
diffusers
,
"SampleObject2"
,
SampleObject2
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
# now save a config parameter that is expected by another class, but not origin class
with
open
(
os
.
path
.
join
(
tmpdirname
,
SampleObject
.
config_name
),
"r"
)
as
f
:
data
=
json
.
load
(
f
)
data
[
"f"
]
=
[
0
,
0
]
data
[
"unexpected"
]
=
True
with
open
(
os
.
path
.
join
(
tmpdirname
,
SampleObject
.
config_name
),
"w"
)
as
f
:
json
.
dump
(
data
,
f
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
assert
new_obj
.
__class__
==
SampleObject
assert
(
cap_logger
.
out
==
"The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.
\n
"
)
def
test_save_load_from_different_config_comp_schedulers
(
self
):
SampleObject3
.
_compatible_classes
=
[
"SampleObject"
,
"SampleObject2"
]
SampleObject2
.
_compatible_classes
=
[
"SampleObject"
,
"SampleObject3"
]
SampleObject
.
_compatible_classes
=
[
"SampleObject2"
,
"SampleObject3"
]
obj
=
SampleObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SampleObject"
,
SampleObject
)
setattr
(
diffusers
,
"SampleObject2"
,
SampleObject2
)
setattr
(
diffusers
,
"SampleObject3"
,
SampleObject3
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
.
setLevel
(
diffusers
.
logging
.
INFO
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_1
:
new_obj_1
=
SampleObject
.
from_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
new_obj_2
=
SampleObject2
.
from_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_3
:
new_obj_3
=
SampleObject3
.
from_config
(
tmpdirname
)
assert
new_obj_1
.
__class__
==
SampleObject
assert
new_obj_2
.
__class__
==
SampleObject2
assert
new_obj_3
.
__class__
==
SampleObject3
assert
cap_logger_1
.
out
==
""
assert
cap_logger_2
.
out
==
"{'f'} was not found in config. Values will be initialized to default values.
\n
"
assert
cap_logger_3
.
out
==
"{'f'} was not found in config. Values will be initialized to default values.
\n
"
def
test_load_ddim_from_pndm
(
self
):
def
test_load_ddim_from_pndm
(
self
):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
ddim
=
DDIMScheduler
.
from_config
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
ddim
=
DDIMScheduler
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
assert
ddim
.
__class__
==
DDIMScheduler
assert
ddim
.
__class__
==
DDIMScheduler
# no warning should be thrown
# no warning should be thrown
...
@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
euler
=
EulerDiscreteScheduler
.
from_
config
(
euler
=
EulerDiscreteScheduler
.
from_
pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
)
...
@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
euler
=
EulerAncestralDiscreteScheduler
.
from_
config
(
euler
=
EulerAncestralDiscreteScheduler
.
from_
pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
)
...
@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase):
...
@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
pndm
=
PNDMScheduler
.
from_config
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
pndm
=
PNDMScheduler
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
assert
pndm
.
__class__
==
PNDMScheduler
assert
pndm
.
__class__
==
PNDMScheduler
# no warning should be thrown
# no warning should be thrown
...
@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
ddpm
=
DDPMScheduler
.
from_
config
(
ddpm
=
DDPMScheduler
.
from_
pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
,
subfolder
=
"scheduler"
,
predict_epsilon
=
False
,
predict_epsilon
=
False
,
...
@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase):
)
)
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
ddpm_2
=
DDPMScheduler
.
from_
config
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
ddpm_2
=
DDPMScheduler
.
from_
pretrained
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
assert
ddpm
.
__class__
==
DDPMScheduler
assert
ddpm
.
__class__
==
DDPMScheduler
assert
ddpm
.
config
.
predict_epsilon
is
False
assert
ddpm
.
config
.
predict_epsilon
is
False
...
@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
dpm
=
DPMSolverMultistepScheduler
.
from_
config
(
dpm
=
DPMSolverMultistepScheduler
.
from_
pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
)
...
...
tests/test_modeling_common.py
View file @
554b374d
...
@@ -130,7 +130,7 @@ class ModelTesterMixin:
...
@@ -130,7 +130,7 @@ class ModelTesterMixin:
expected_arg_names
=
[
"sample"
,
"timestep"
]
expected_arg_names
=
[
"sample"
,
"timestep"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
def
test_model_from_
config
(
self
):
def
test_model_from_
pretrained
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
=
self
.
model_class
(
**
init_dict
)
...
@@ -140,8 +140,8 @@ class ModelTesterMixin:
...
@@ -140,8 +140,8 @@ class ModelTesterMixin:
# test if the model can be loaded from the config
# test if the model can be loaded from the config
# and has all the expected shape
# and has all the expected shape
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_
config
(
tmpdirname
)
model
.
save_
pretrained
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_
config
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_
pretrained
(
tmpdirname
)
new_model
.
to
(
torch_device
)
new_model
.
to
(
torch_device
)
new_model
.
eval
()
new_model
.
eval
()
...
...
tests/test_pipelines.py
View file @
554b374d
...
@@ -29,6 +29,10 @@ from diffusers import (
...
@@ -29,6 +29,10 @@ from diffusers import (
DDIMScheduler
,
DDIMScheduler
,
DDPMPipeline
,
DDPMPipeline
,
DDPMScheduler
,
DDPMScheduler
,
DPMSolverMultistepScheduler
,
EulerAncestralDiscreteScheduler
,
EulerDiscreteScheduler
,
LMSDiscreteScheduler
,
PNDMScheduler
,
PNDMScheduler
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionInpaintPipelineLegacy
,
StableDiffusionInpaintPipelineLegacy
,
...
@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
...
@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
assert
image_img2img
.
shape
==
(
1
,
32
,
32
,
3
)
assert
image_img2img
.
shape
==
(
1
,
32
,
32
,
3
)
assert
image_text2img
.
shape
==
(
1
,
128
,
128
,
3
)
assert
image_text2img
.
shape
==
(
1
,
128
,
128
,
3
)
def
test_set_scheduler
(
self
):
unet
=
self
.
dummy_cond_unet
scheduler
=
PNDMScheduler
(
skip_prk_steps
=
True
)
vae
=
self
.
dummy_vae
bert
=
self
.
dummy_text_encoder
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-clip"
)
sd
=
StableDiffusionPipeline
(
unet
=
unet
,
scheduler
=
scheduler
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
None
,
feature_extractor
=
self
.
dummy_extractor
,
)
sd
.
scheduler
=
DDIMScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
DDIMScheduler
)
sd
.
scheduler
=
DDPMScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
DDPMScheduler
)
sd
.
scheduler
=
PNDMScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
PNDMScheduler
)
sd
.
scheduler
=
LMSDiscreteScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
LMSDiscreteScheduler
)
sd
.
scheduler
=
EulerDiscreteScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
EulerDiscreteScheduler
)
sd
.
scheduler
=
EulerAncestralDiscreteScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
EulerAncestralDiscreteScheduler
)
sd
.
scheduler
=
DPMSolverMultistepScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
DPMSolverMultistepScheduler
)
def
test_set_scheduler_consistency
(
self
):
unet
=
self
.
dummy_cond_unet
pndm
=
PNDMScheduler
.
from_config
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
ddim
=
DDIMScheduler
.
from_config
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
vae
=
self
.
dummy_vae
bert
=
self
.
dummy_text_encoder
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-clip"
)
sd
=
StableDiffusionPipeline
(
unet
=
unet
,
scheduler
=
pndm
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
None
,
feature_extractor
=
self
.
dummy_extractor
,
)
pndm_config
=
sd
.
scheduler
.
config
sd
.
scheduler
=
DDPMScheduler
.
from_config
(
pndm_config
)
sd
.
scheduler
=
PNDMScheduler
.
from_config
(
sd
.
scheduler
.
config
)
pndm_config_2
=
sd
.
scheduler
.
config
pndm_config_2
=
{
k
:
v
for
k
,
v
in
pndm_config_2
.
items
()
if
k
in
pndm_config
}
assert
dict
(
pndm_config
)
==
dict
(
pndm_config_2
)
sd
=
StableDiffusionPipeline
(
unet
=
unet
,
scheduler
=
ddim
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
None
,
feature_extractor
=
self
.
dummy_extractor
,
)
ddim_config
=
sd
.
scheduler
.
config
sd
.
scheduler
=
LMSDiscreteScheduler
.
from_config
(
ddim_config
)
sd
.
scheduler
=
DDIMScheduler
.
from_config
(
sd
.
scheduler
.
config
)
ddim_config_2
=
sd
.
scheduler
.
config
ddim_config_2
=
{
k
:
v
for
k
,
v
in
ddim_config_2
.
items
()
if
k
in
ddim_config
}
assert
dict
(
ddim_config
)
==
dict
(
ddim_config_2
)
@
slow
@
slow
class
PipelineSlowTests
(
unittest
.
TestCase
):
class
PipelineSlowTests
(
unittest
.
TestCase
):
...
@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
...
@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
def
test_output_format
(
self
):
def
test_output_format
(
self
):
model_path
=
"google/ddpm-cifar10-32"
model_path
=
"google/ddpm-cifar10-32"
scheduler
=
DDIMScheduler
.
from_
config
(
model_path
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
model_path
)
pipe
=
DDIMPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
)
pipe
=
DDIMPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
)
pipe
.
to
(
torch_device
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
tests/test_scheduler.py
View file @
554b374d
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
inspect
import
inspect
import
json
import
os
import
tempfile
import
tempfile
import
unittest
import
unittest
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
...
@@ -21,6 +23,7 @@ import numpy as np
...
@@ -21,6 +23,7 @@ import numpy as np
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
diffusers
from
diffusers
import
(
from
diffusers
import
(
DDIMScheduler
,
DDIMScheduler
,
DDPMScheduler
,
DDPMScheduler
,
...
@@ -32,13 +35,180 @@ from diffusers import (
...
@@ -32,13 +35,180 @@ from diffusers import (
PNDMScheduler
,
PNDMScheduler
,
ScoreSdeVeScheduler
,
ScoreSdeVeScheduler
,
VQDiffusionScheduler
,
VQDiffusionScheduler
,
logging
,
)
)
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.schedulers.scheduling_utils
import
SchedulerMixin
from
diffusers.utils
import
deprecate
,
torch_device
from
diffusers.utils
import
deprecate
,
torch_device
from
diffusers.utils.testing_utils
import
CaptureLogger
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
SchedulerObject
(
SchedulerMixin
,
ConfigMixin
):
config_name
=
"config.json"
@
register_to_config
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
):
pass
class
SchedulerObject2
(
SchedulerMixin
,
ConfigMixin
):
config_name
=
"config.json"
@
register_to_config
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
f
=
[
1
,
3
],
):
pass
class
SchedulerObject3
(
SchedulerMixin
,
ConfigMixin
):
config_name
=
"config.json"
@
register_to_config
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
f
=
[
1
,
3
],
):
pass
class
SchedulerBaseTests
(
unittest
.
TestCase
):
def
test_save_load_from_different_config
(
self
):
obj
=
SchedulerObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SchedulerObject"
,
SchedulerObject
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_1
:
config
=
SchedulerObject2
.
load_config
(
tmpdirname
)
new_obj_1
=
SchedulerObject2
.
from_config
(
config
)
# now save a config parameter that is not expected
with
open
(
os
.
path
.
join
(
tmpdirname
,
SchedulerObject
.
config_name
),
"r"
)
as
f
:
data
=
json
.
load
(
f
)
data
[
"unexpected"
]
=
True
with
open
(
os
.
path
.
join
(
tmpdirname
,
SchedulerObject
.
config_name
),
"w"
)
as
f
:
json
.
dump
(
data
,
f
)
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
config
=
SchedulerObject
.
load_config
(
tmpdirname
)
new_obj_2
=
SchedulerObject
.
from_config
(
config
)
with
CaptureLogger
(
logger
)
as
cap_logger_3
:
config
=
SchedulerObject2
.
load_config
(
tmpdirname
)
new_obj_3
=
SchedulerObject2
.
from_config
(
config
)
assert
new_obj_1
.
__class__
==
SchedulerObject2
assert
new_obj_2
.
__class__
==
SchedulerObject
assert
new_obj_3
.
__class__
==
SchedulerObject2
assert
cap_logger_1
.
out
==
""
assert
(
cap_logger_2
.
out
==
"The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
" will"
" be ignored. Please verify your config.json configuration file.
\n
"
)
assert
cap_logger_2
.
out
.
replace
(
"SchedulerObject"
,
"SchedulerObject2"
)
==
cap_logger_3
.
out
def
test_save_load_compatible_schedulers
(
self
):
SchedulerObject2
.
_compatibles
=
[
"SchedulerObject"
]
SchedulerObject
.
_compatibles
=
[
"SchedulerObject2"
]
obj
=
SchedulerObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SchedulerObject"
,
SchedulerObject
)
setattr
(
diffusers
,
"SchedulerObject2"
,
SchedulerObject2
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
# now save a config parameter that is expected by another class, but not origin class
with
open
(
os
.
path
.
join
(
tmpdirname
,
SchedulerObject
.
config_name
),
"r"
)
as
f
:
data
=
json
.
load
(
f
)
data
[
"f"
]
=
[
0
,
0
]
data
[
"unexpected"
]
=
True
with
open
(
os
.
path
.
join
(
tmpdirname
,
SchedulerObject
.
config_name
),
"w"
)
as
f
:
json
.
dump
(
data
,
f
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
config
=
SchedulerObject
.
load_config
(
tmpdirname
)
new_obj
=
SchedulerObject
.
from_config
(
config
)
assert
new_obj
.
__class__
==
SchedulerObject
assert
(
cap_logger
.
out
==
"The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
" will"
" be ignored. Please verify your config.json configuration file.
\n
"
)
def
test_save_load_from_different_config_comp_schedulers
(
self
):
SchedulerObject3
.
_compatibles
=
[
"SchedulerObject"
,
"SchedulerObject2"
]
SchedulerObject2
.
_compatibles
=
[
"SchedulerObject"
,
"SchedulerObject3"
]
SchedulerObject
.
_compatibles
=
[
"SchedulerObject2"
,
"SchedulerObject3"
]
obj
=
SchedulerObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SchedulerObject"
,
SchedulerObject
)
setattr
(
diffusers
,
"SchedulerObject2"
,
SchedulerObject2
)
setattr
(
diffusers
,
"SchedulerObject3"
,
SchedulerObject3
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
.
setLevel
(
diffusers
.
logging
.
INFO
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_1
:
config
=
SchedulerObject
.
load_config
(
tmpdirname
)
new_obj_1
=
SchedulerObject
.
from_config
(
config
)
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
config
=
SchedulerObject2
.
load_config
(
tmpdirname
)
new_obj_2
=
SchedulerObject2
.
from_config
(
config
)
with
CaptureLogger
(
logger
)
as
cap_logger_3
:
config
=
SchedulerObject3
.
load_config
(
tmpdirname
)
new_obj_3
=
SchedulerObject3
.
from_config
(
config
)
assert
new_obj_1
.
__class__
==
SchedulerObject
assert
new_obj_2
.
__class__
==
SchedulerObject2
assert
new_obj_3
.
__class__
==
SchedulerObject3
assert
cap_logger_1
.
out
==
""
assert
cap_logger_2
.
out
==
"{'f'} was not found in config. Values will be initialized to default values.
\n
"
assert
cap_logger_3
.
out
==
"{'f'} was not found in config. Values will be initialized to default values.
\n
"
class
SchedulerCommonTest
(
unittest
.
TestCase
):
class
SchedulerCommonTest
(
unittest
.
TestCase
):
scheduler_classes
=
()
scheduler_classes
=
()
forward_default_kwargs
=
()
forward_default_kwargs
=
()
...
@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
scheduler
.
set_timesteps
(
num_inference_steps
)
...
@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
scheduler
.
set_timesteps
(
num_inference_steps
)
...
@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
scheduler
.
set_timesteps
(
num_inference_steps
)
...
@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_compatibles
(
self
):
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
assert
all
(
c
is
not
None
for
c
in
scheduler
.
compatibles
)
for
comp_scheduler_cls
in
scheduler
.
compatibles
:
comp_scheduler
=
comp_scheduler_cls
.
from_config
(
scheduler
.
config
)
assert
comp_scheduler
is
not
None
new_scheduler
=
scheduler_class
.
from_config
(
comp_scheduler
.
config
)
new_scheduler_config
=
{
k
:
v
for
k
,
v
in
new_scheduler
.
config
.
items
()
if
k
in
scheduler
.
config
}
scheduler_diff
=
{
k
:
v
for
k
,
v
in
new_scheduler
.
config
.
items
()
if
k
not
in
scheduler
.
config
}
# make sure that configs are essentially identical
assert
new_scheduler_config
==
dict
(
scheduler
.
config
)
# make sure that only differences are for configs that are not in init
init_keys
=
inspect
.
signature
(
scheduler_class
.
__init__
).
parameters
.
keys
()
assert
set
(
scheduler_diff
.
keys
()).
intersection
(
set
(
init_keys
))
==
set
()
def
test_from_pretrained
(
self
):
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_pretrained
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_pretrained
(
tmpdirname
)
assert
scheduler
.
config
==
new_scheduler
.
config
def
test_step_shape
(
self
):
def
test_step_shape
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
...
@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
...
@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
model_outputs
=
dummy_past_residuals
[:
new_scheduler
.
config
.
solver_order
]
new_scheduler
.
model_outputs
=
dummy_past_residuals
[:
new_scheduler
.
config
.
solver_order
]
...
@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
...
@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
...
@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
...
@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
...
@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
...
@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
output
=
scheduler
.
step_pred
(
output
=
scheduler
.
step_pred
(
residual
,
time_step
,
sample
,
generator
=
torch
.
manual_seed
(
0
),
**
kwargs
residual
,
time_step
,
sample
,
generator
=
torch
.
manual_seed
(
0
),
**
kwargs
...
@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
...
@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
output
=
scheduler
.
step_pred
(
output
=
scheduler
.
step_pred
(
residual
,
time_step
,
sample
,
generator
=
torch
.
manual_seed
(
0
),
**
kwargs
residual
,
time_step
,
sample
,
generator
=
torch
.
manual_seed
(
0
),
**
kwargs
...
@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
...
@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
...
@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
...
@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
...
...
tests/test_scheduler_flax.py
View file @
554b374d
...
@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
...
@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
...
@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
...
@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
new_state
=
new_scheduler
.
set_timesteps
(
new_state
,
num_inference_steps
,
shape
=
sample
.
shape
)
new_state
=
new_scheduler
.
set_timesteps
(
new_state
,
num_inference_steps
,
shape
=
sample
.
shape
)
# copy over dummy past residuals
# copy over dummy past residuals
new_state
=
new_state
.
replace
(
ets
=
dummy_past_residuals
[:])
new_state
=
new_state
.
replace
(
ets
=
dummy_past_residuals
[:])
...
@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_state
=
new_scheduler
.
set_timesteps
(
new_state
,
num_inference_steps
,
shape
=
sample
.
shape
)
new_state
=
new_scheduler
.
set_timesteps
(
new_state
,
num_inference_steps
,
shape
=
sample
.
shape
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment