Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
554b374d
Commit
554b374d
authored
Nov 15, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
d5ab55e4
a0520193
Changes
76
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
339 additions
and
158 deletions
+339
-158
tests/pipelines/ddpm/test_ddpm.py
tests/pipelines/ddpm/test_ddpm.py
+1
-1
tests/pipelines/repaint/test_repaint.py
tests/pipelines/repaint/test_repaint.py
+1
-1
tests/pipelines/score_sde_ve/test_score_sde_ve.py
tests/pipelines/score_sde_ve/test_score_sde_ve.py
+1
-1
tests/pipelines/stable_diffusion/test_cycle_diffusion.py
tests/pipelines/stable_diffusion/test_cycle_diffusion.py
+2
-2
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
.../pipelines/stable_diffusion/test_onnx_stable_diffusion.py
+2
-2
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
...es/stable_diffusion/test_onnx_stable_diffusion_img2img.py
+1
-1
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
...es/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
+1
-1
tests/pipelines/stable_diffusion/test_stable_diffusion.py
tests/pipelines/stable_diffusion/test_stable_diffusion.py
+2
-2
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
...pelines/stable_diffusion/test_stable_diffusion_img2img.py
+3
-3
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
...pelines/stable_diffusion/test_stable_diffusion_inpaint.py
+2
-2
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
.../stable_diffusion/test_stable_diffusion_inpaint_legacy.py
+1
-1
tests/test_config.py
tests/test_config.py
+13
-118
tests/test_modeling_common.py
tests/test_modeling_common.py
+3
-3
tests/test_pipelines.py
tests/test_pipelines.py
+81
-1
tests/test_scheduler.py
tests/test_scheduler.py
+217
-11
tests/test_scheduler_flax.py
tests/test_scheduler_flax.py
+8
-8
No files found.
tests/pipelines/ddpm/test_ddpm.py
View file @
554b374d
...
@@ -106,7 +106,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
...
@@ -106,7 +106,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
model_id
=
"google/ddpm-cifar10-32"
model_id
=
"google/ddpm-cifar10-32"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
scheduler
=
DDPMScheduler
.
from_
config
(
model_id
)
scheduler
=
DDPMScheduler
.
from_
pretrained
(
model_id
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
.
to
(
torch_device
)
ddpm
.
to
(
torch_device
)
...
...
tests/pipelines/repaint/test_repaint.py
View file @
554b374d
...
@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
...
@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
model_id
=
"google/ddpm-ema-celebahq-256"
model_id
=
"google/ddpm-ema-celebahq-256"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
scheduler
=
RePaintScheduler
.
from_
config
(
model_id
)
scheduler
=
RePaintScheduler
.
from_
pretrained
(
model_id
)
repaint
=
RePaintPipeline
(
unet
=
unet
,
scheduler
=
scheduler
).
to
(
torch_device
)
repaint
=
RePaintPipeline
(
unet
=
unet
,
scheduler
=
scheduler
).
to
(
torch_device
)
...
...
tests/pipelines/score_sde_ve/test_score_sde_ve.py
View file @
554b374d
...
@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
...
@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
model_id
=
"google/ncsnpp-church-256"
model_id
=
"google/ncsnpp-church-256"
model
=
UNet2DModel
.
from_pretrained
(
model_id
)
model
=
UNet2DModel
.
from_pretrained
(
model_id
)
scheduler
=
ScoreSdeVeScheduler
.
from_
config
(
model_id
)
scheduler
=
ScoreSdeVeScheduler
.
from_
pretrained
(
model_id
)
sde_ve
=
ScoreSdeVePipeline
(
unet
=
model
,
scheduler
=
scheduler
)
sde_ve
=
ScoreSdeVePipeline
(
unet
=
model
,
scheduler
=
scheduler
)
sde_ve
.
to
(
torch_device
)
sde_ve
.
to
(
torch_device
)
...
...
tests/pipelines/stable_diffusion/test_cycle_diffusion.py
View file @
554b374d
...
@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image
=
init_image
.
resize
((
512
,
512
))
init_image
=
init_image
.
resize
((
512
,
512
))
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
scheduler
=
DDIMScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
model_id
,
scheduler
=
scheduler
,
safety_checker
=
None
,
torch_dtype
=
torch
.
float16
,
revision
=
"fp16"
model_id
,
scheduler
=
scheduler
,
safety_checker
=
None
,
torch_dtype
=
torch
.
float16
,
revision
=
"fp16"
)
)
...
@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image
=
init_image
.
resize
((
512
,
512
))
init_image
=
init_image
.
resize
((
512
,
512
))
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
scheduler
=
DDIMScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
model_id
,
scheduler
=
scheduler
,
safety_checker
=
None
)
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
model_id
,
scheduler
=
scheduler
,
safety_checker
=
None
)
pipe
.
to
(
torch_device
)
pipe
.
to
(
torch_device
)
...
...
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
View file @
554b374d
...
@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-3
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-3
def
test_inference_ddim
(
self
):
def
test_inference_ddim
(
self
):
ddim_scheduler
=
DDIMScheduler
.
from_
config
(
ddim_scheduler
=
DDIMScheduler
.
from_
pretrained
(
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
)
)
sd_pipe
=
OnnxStableDiffusionPipeline
.
from_pretrained
(
sd_pipe
=
OnnxStableDiffusionPipeline
.
from_pretrained
(
...
@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-3
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-3
def
test_inference_k_lms
(
self
):
def
test_inference_k_lms
(
self
):
lms_scheduler
=
LMSDiscreteScheduler
.
from_
config
(
lms_scheduler
=
LMSDiscreteScheduler
.
from_
pretrained
(
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
)
)
sd_pipe
=
OnnxStableDiffusionPipeline
.
from_pretrained
(
sd_pipe
=
OnnxStableDiffusionPipeline
.
from_pretrained
(
...
...
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
View file @
554b374d
...
@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
...
@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
"/img2img/sketch-mountains-input.jpg"
"/img2img/sketch-mountains-input.jpg"
)
)
init_image
=
init_image
.
resize
((
768
,
512
))
init_image
=
init_image
.
resize
((
768
,
512
))
lms_scheduler
=
LMSDiscreteScheduler
.
from_
config
(
lms_scheduler
=
LMSDiscreteScheduler
.
from_
pretrained
(
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
"runwayml/stable-diffusion-v1-5"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
)
)
pipe
=
OnnxStableDiffusionImg2ImgPipeline
.
from_pretrained
(
pipe
=
OnnxStableDiffusionImg2ImgPipeline
.
from_pretrained
(
...
...
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
View file @
554b374d
...
@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
...
@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
)
lms_scheduler
=
LMSDiscreteScheduler
.
from_
config
(
lms_scheduler
=
LMSDiscreteScheduler
.
from_
pretrained
(
"runwayml/stable-diffusion-inpainting"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
"runwayml/stable-diffusion-inpainting"
,
subfolder
=
"scheduler"
,
revision
=
"onnx"
)
)
pipe
=
OnnxStableDiffusionInpaintPipeline
.
from_pretrained
(
pipe
=
OnnxStableDiffusionInpaintPipeline
.
from_pretrained
(
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion.py
View file @
554b374d
...
@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_fast_ddim
(
self
):
def
test_stable_diffusion_fast_ddim
(
self
):
scheduler
=
DDIMScheduler
.
from_
config
(
"CompVis/stable-diffusion-v1-1"
,
subfolder
=
"scheduler"
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
"CompVis/stable-diffusion-v1-1"
,
subfolder
=
"scheduler"
)
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1"
,
scheduler
=
scheduler
)
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1"
,
scheduler
=
scheduler
)
sd_pipe
=
sd_pipe
.
to
(
torch_device
)
sd_pipe
=
sd_pipe
.
to
(
torch_device
)
...
@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
...
@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
model_id
=
"CompVis/stable-diffusion-v1-1"
model_id
=
"CompVis/stable-diffusion-v1-1"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
).
to
(
torch_device
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
model_id
).
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
scheduler
=
LMSDiscreteScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
scheduler
=
LMSDiscreteScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
.
scheduler
=
scheduler
pipe
.
scheduler
=
scheduler
prompt
=
"a photograph of an astronaut riding a horse"
prompt
=
"a photograph of an astronaut riding a horse"
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
View file @
554b374d
...
@@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
...
@@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)
)
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
lms
=
LMSDiscreteScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
lms
=
LMSDiscreteScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
model_id
,
model_id
,
scheduler
=
lms
,
scheduler
=
lms
,
...
@@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
...
@@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)
)
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
ddim
=
DDIMScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
ddim
=
DDIMScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
model_id
,
model_id
,
scheduler
=
ddim
,
scheduler
=
ddim
,
...
@@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
...
@@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image
=
init_image
.
resize
((
768
,
512
))
init_image
=
init_image
.
resize
((
768
,
512
))
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
lms
=
LMSDiscreteScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
lms
=
LMSDiscreteScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
model_id
,
scheduler
=
lms
,
safety_checker
=
None
,
device_map
=
"auto"
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
model_id
,
scheduler
=
lms
,
safety_checker
=
None
,
device_map
=
"auto"
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
)
)
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
View file @
554b374d
...
@@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
...
@@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)
)
model_id
=
"runwayml/stable-diffusion-inpainting"
model_id
=
"runwayml/stable-diffusion-inpainting"
pndm
=
PNDMScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
pndm
=
PNDMScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
safety_checker
=
None
,
scheduler
=
pndm
)
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
safety_checker
=
None
,
scheduler
=
pndm
)
pipe
.
to
(
torch_device
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
...
@@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)
)
model_id
=
"runwayml/stable-diffusion-inpainting"
model_id
=
"runwayml/stable-diffusion-inpainting"
pndm
=
PNDMScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
pndm
=
PNDMScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
model_id
,
safety_checker
=
None
,
safety_checker
=
None
,
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
View file @
554b374d
...
@@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
...
@@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
)
)
model_id
=
"CompVis/stable-diffusion-v1-4"
model_id
=
"CompVis/stable-diffusion-v1-4"
lms
=
LMSDiscreteScheduler
.
from_
config
(
model_id
,
subfolder
=
"scheduler"
)
lms
=
LMSDiscreteScheduler
.
from_
pretrained
(
model_id
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
pipe
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
model_id
,
model_id
,
scheduler
=
lms
,
scheduler
=
lms
,
...
...
tests/test_config.py
View file @
554b374d
...
@@ -13,12 +13,9 @@
...
@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
json
import
os
import
tempfile
import
tempfile
import
unittest
import
unittest
import
diffusers
from
diffusers
import
(
from
diffusers
import
(
DDIMScheduler
,
DDIMScheduler
,
DDPMScheduler
,
DDPMScheduler
,
...
@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin):
...
@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin):
class
ConfigTester
(
unittest
.
TestCase
):
class
ConfigTester
(
unittest
.
TestCase
):
def
test_load_not_from_mixin
(
self
):
def
test_load_not_from_mixin
(
self
):
with
self
.
assertRaises
(
ValueError
):
with
self
.
assertRaises
(
ValueError
):
ConfigMixin
.
from
_config
(
"dummy_path"
)
ConfigMixin
.
load
_config
(
"dummy_path"
)
def
test_register_to_config
(
self
):
def
test_register_to_config
(
self
):
obj
=
SampleObject
()
obj
=
SampleObject
()
...
@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
obj
.
save_config
(
tmpdirname
)
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
new_obj
=
SampleObject
.
from_config
(
SampleObject
.
load_config
(
tmpdirname
)
)
new_config
=
new_obj
.
config
new_config
=
new_obj
.
config
# unfreeze configs
# unfreeze configs
...
@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase):
...
@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase):
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
config
==
new_config
assert
config
==
new_config
def
test_save_load_from_different_config
(
self
):
obj
=
SampleObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SampleObject"
,
SampleObject
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_1
:
new_obj_1
=
SampleObject2
.
from_config
(
tmpdirname
)
# now save a config parameter that is not expected
with
open
(
os
.
path
.
join
(
tmpdirname
,
SampleObject
.
config_name
),
"r"
)
as
f
:
data
=
json
.
load
(
f
)
data
[
"unexpected"
]
=
True
with
open
(
os
.
path
.
join
(
tmpdirname
,
SampleObject
.
config_name
),
"w"
)
as
f
:
json
.
dump
(
data
,
f
)
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
new_obj_2
=
SampleObject
.
from_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_3
:
new_obj_3
=
SampleObject2
.
from_config
(
tmpdirname
)
assert
new_obj_1
.
__class__
==
SampleObject2
assert
new_obj_2
.
__class__
==
SampleObject
assert
new_obj_3
.
__class__
==
SampleObject2
assert
cap_logger_1
.
out
==
""
assert
(
cap_logger_2
.
out
==
"The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.
\n
"
)
assert
cap_logger_2
.
out
.
replace
(
"SampleObject"
,
"SampleObject2"
)
==
cap_logger_3
.
out
def
test_save_load_compatible_schedulers
(
self
):
SampleObject2
.
_compatible_classes
=
[
"SampleObject"
]
SampleObject
.
_compatible_classes
=
[
"SampleObject2"
]
obj
=
SampleObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SampleObject"
,
SampleObject
)
setattr
(
diffusers
,
"SampleObject2"
,
SampleObject2
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
# now save a config parameter that is expected by another class, but not origin class
with
open
(
os
.
path
.
join
(
tmpdirname
,
SampleObject
.
config_name
),
"r"
)
as
f
:
data
=
json
.
load
(
f
)
data
[
"f"
]
=
[
0
,
0
]
data
[
"unexpected"
]
=
True
with
open
(
os
.
path
.
join
(
tmpdirname
,
SampleObject
.
config_name
),
"w"
)
as
f
:
json
.
dump
(
data
,
f
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
assert
new_obj
.
__class__
==
SampleObject
assert
(
cap_logger
.
out
==
"The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.
\n
"
)
def
test_save_load_from_different_config_comp_schedulers
(
self
):
SampleObject3
.
_compatible_classes
=
[
"SampleObject"
,
"SampleObject2"
]
SampleObject2
.
_compatible_classes
=
[
"SampleObject"
,
"SampleObject3"
]
SampleObject
.
_compatible_classes
=
[
"SampleObject2"
,
"SampleObject3"
]
obj
=
SampleObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SampleObject"
,
SampleObject
)
setattr
(
diffusers
,
"SampleObject2"
,
SampleObject2
)
setattr
(
diffusers
,
"SampleObject3"
,
SampleObject3
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
.
setLevel
(
diffusers
.
logging
.
INFO
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_1
:
new_obj_1
=
SampleObject
.
from_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
new_obj_2
=
SampleObject2
.
from_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_3
:
new_obj_3
=
SampleObject3
.
from_config
(
tmpdirname
)
assert
new_obj_1
.
__class__
==
SampleObject
assert
new_obj_2
.
__class__
==
SampleObject2
assert
new_obj_3
.
__class__
==
SampleObject3
assert
cap_logger_1
.
out
==
""
assert
cap_logger_2
.
out
==
"{'f'} was not found in config. Values will be initialized to default values.
\n
"
assert
cap_logger_3
.
out
==
"{'f'} was not found in config. Values will be initialized to default values.
\n
"
def
test_load_ddim_from_pndm
(
self
):
def
test_load_ddim_from_pndm
(
self
):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
ddim
=
DDIMScheduler
.
from_config
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
ddim
=
DDIMScheduler
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
assert
ddim
.
__class__
==
DDIMScheduler
assert
ddim
.
__class__
==
DDIMScheduler
# no warning should be thrown
# no warning should be thrown
...
@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
euler
=
EulerDiscreteScheduler
.
from_
config
(
euler
=
EulerDiscreteScheduler
.
from_
pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
)
...
@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
euler
=
EulerAncestralDiscreteScheduler
.
from_
config
(
euler
=
EulerAncestralDiscreteScheduler
.
from_
pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
)
...
@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase):
...
@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
pndm
=
PNDMScheduler
.
from_config
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
pndm
=
PNDMScheduler
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
assert
pndm
.
__class__
==
PNDMScheduler
assert
pndm
.
__class__
==
PNDMScheduler
# no warning should be thrown
# no warning should be thrown
...
@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
ddpm
=
DDPMScheduler
.
from_
config
(
ddpm
=
DDPMScheduler
.
from_
pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
,
subfolder
=
"scheduler"
,
predict_epsilon
=
False
,
predict_epsilon
=
False
,
...
@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase):
)
)
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
ddpm_2
=
DDPMScheduler
.
from_
config
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
ddpm_2
=
DDPMScheduler
.
from_
pretrained
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
assert
ddpm
.
__class__
==
DDPMScheduler
assert
ddpm
.
__class__
==
DDPMScheduler
assert
ddpm
.
config
.
predict_epsilon
is
False
assert
ddpm
.
config
.
predict_epsilon
is
False
...
@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase):
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
dpm
=
DPMSolverMultistepScheduler
.
from_
config
(
dpm
=
DPMSolverMultistepScheduler
.
from_
pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
)
...
...
tests/test_modeling_common.py
View file @
554b374d
...
@@ -130,7 +130,7 @@ class ModelTesterMixin:
...
@@ -130,7 +130,7 @@ class ModelTesterMixin:
expected_arg_names
=
[
"sample"
,
"timestep"
]
expected_arg_names
=
[
"sample"
,
"timestep"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
def
test_model_from_
config
(
self
):
def
test_model_from_
pretrained
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
=
self
.
model_class
(
**
init_dict
)
...
@@ -140,8 +140,8 @@ class ModelTesterMixin:
...
@@ -140,8 +140,8 @@ class ModelTesterMixin:
# test if the model can be loaded from the config
# test if the model can be loaded from the config
# and has all the expected shape
# and has all the expected shape
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_
config
(
tmpdirname
)
model
.
save_
pretrained
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_
config
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_
pretrained
(
tmpdirname
)
new_model
.
to
(
torch_device
)
new_model
.
to
(
torch_device
)
new_model
.
eval
()
new_model
.
eval
()
...
...
tests/test_pipelines.py
View file @
554b374d
...
@@ -29,6 +29,10 @@ from diffusers import (
...
@@ -29,6 +29,10 @@ from diffusers import (
DDIMScheduler
,
DDIMScheduler
,
DDPMPipeline
,
DDPMPipeline
,
DDPMScheduler
,
DDPMScheduler
,
DPMSolverMultistepScheduler
,
EulerAncestralDiscreteScheduler
,
EulerDiscreteScheduler
,
LMSDiscreteScheduler
,
PNDMScheduler
,
PNDMScheduler
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionInpaintPipelineLegacy
,
StableDiffusionInpaintPipelineLegacy
,
...
@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
...
@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
assert
image_img2img
.
shape
==
(
1
,
32
,
32
,
3
)
assert
image_img2img
.
shape
==
(
1
,
32
,
32
,
3
)
assert
image_text2img
.
shape
==
(
1
,
128
,
128
,
3
)
assert
image_text2img
.
shape
==
(
1
,
128
,
128
,
3
)
def
test_set_scheduler
(
self
):
unet
=
self
.
dummy_cond_unet
scheduler
=
PNDMScheduler
(
skip_prk_steps
=
True
)
vae
=
self
.
dummy_vae
bert
=
self
.
dummy_text_encoder
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-clip"
)
sd
=
StableDiffusionPipeline
(
unet
=
unet
,
scheduler
=
scheduler
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
None
,
feature_extractor
=
self
.
dummy_extractor
,
)
sd
.
scheduler
=
DDIMScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
DDIMScheduler
)
sd
.
scheduler
=
DDPMScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
DDPMScheduler
)
sd
.
scheduler
=
PNDMScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
PNDMScheduler
)
sd
.
scheduler
=
LMSDiscreteScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
LMSDiscreteScheduler
)
sd
.
scheduler
=
EulerDiscreteScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
EulerDiscreteScheduler
)
sd
.
scheduler
=
EulerAncestralDiscreteScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
EulerAncestralDiscreteScheduler
)
sd
.
scheduler
=
DPMSolverMultistepScheduler
.
from_config
(
sd
.
scheduler
.
config
)
assert
isinstance
(
sd
.
scheduler
,
DPMSolverMultistepScheduler
)
def
test_set_scheduler_consistency
(
self
):
unet
=
self
.
dummy_cond_unet
pndm
=
PNDMScheduler
.
from_config
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
ddim
=
DDIMScheduler
.
from_config
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
)
vae
=
self
.
dummy_vae
bert
=
self
.
dummy_text_encoder
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-clip"
)
sd
=
StableDiffusionPipeline
(
unet
=
unet
,
scheduler
=
pndm
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
None
,
feature_extractor
=
self
.
dummy_extractor
,
)
pndm_config
=
sd
.
scheduler
.
config
sd
.
scheduler
=
DDPMScheduler
.
from_config
(
pndm_config
)
sd
.
scheduler
=
PNDMScheduler
.
from_config
(
sd
.
scheduler
.
config
)
pndm_config_2
=
sd
.
scheduler
.
config
pndm_config_2
=
{
k
:
v
for
k
,
v
in
pndm_config_2
.
items
()
if
k
in
pndm_config
}
assert
dict
(
pndm_config
)
==
dict
(
pndm_config_2
)
sd
=
StableDiffusionPipeline
(
unet
=
unet
,
scheduler
=
ddim
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
None
,
feature_extractor
=
self
.
dummy_extractor
,
)
ddim_config
=
sd
.
scheduler
.
config
sd
.
scheduler
=
LMSDiscreteScheduler
.
from_config
(
ddim_config
)
sd
.
scheduler
=
DDIMScheduler
.
from_config
(
sd
.
scheduler
.
config
)
ddim_config_2
=
sd
.
scheduler
.
config
ddim_config_2
=
{
k
:
v
for
k
,
v
in
ddim_config_2
.
items
()
if
k
in
ddim_config
}
assert
dict
(
ddim_config
)
==
dict
(
ddim_config_2
)
@
slow
@
slow
class
PipelineSlowTests
(
unittest
.
TestCase
):
class
PipelineSlowTests
(
unittest
.
TestCase
):
...
@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
...
@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
def
test_output_format
(
self
):
def
test_output_format
(
self
):
model_path
=
"google/ddpm-cifar10-32"
model_path
=
"google/ddpm-cifar10-32"
scheduler
=
DDIMScheduler
.
from_
config
(
model_path
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
model_path
)
pipe
=
DDIMPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
)
pipe
=
DDIMPipeline
.
from_pretrained
(
model_path
,
scheduler
=
scheduler
)
pipe
.
to
(
torch_device
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
tests/test_scheduler.py
View file @
554b374d
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
inspect
import
inspect
import
json
import
os
import
tempfile
import
tempfile
import
unittest
import
unittest
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
...
@@ -21,6 +23,7 @@ import numpy as np
...
@@ -21,6 +23,7 @@ import numpy as np
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
diffusers
from
diffusers
import
(
from
diffusers
import
(
DDIMScheduler
,
DDIMScheduler
,
DDPMScheduler
,
DDPMScheduler
,
...
@@ -32,13 +35,180 @@ from diffusers import (
...
@@ -32,13 +35,180 @@ from diffusers import (
PNDMScheduler
,
PNDMScheduler
,
ScoreSdeVeScheduler
,
ScoreSdeVeScheduler
,
VQDiffusionScheduler
,
VQDiffusionScheduler
,
logging
,
)
)
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.schedulers.scheduling_utils
import
SchedulerMixin
from
diffusers.utils
import
deprecate
,
torch_device
from
diffusers.utils
import
deprecate
,
torch_device
from
diffusers.utils.testing_utils
import
CaptureLogger
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
SchedulerObject
(
SchedulerMixin
,
ConfigMixin
):
config_name
=
"config.json"
@
register_to_config
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
):
pass
class
SchedulerObject2
(
SchedulerMixin
,
ConfigMixin
):
config_name
=
"config.json"
@
register_to_config
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
f
=
[
1
,
3
],
):
pass
class
SchedulerObject3
(
SchedulerMixin
,
ConfigMixin
):
config_name
=
"config.json"
@
register_to_config
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
f
=
[
1
,
3
],
):
pass
class
SchedulerBaseTests
(
unittest
.
TestCase
):
def
test_save_load_from_different_config
(
self
):
obj
=
SchedulerObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SchedulerObject"
,
SchedulerObject
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_1
:
config
=
SchedulerObject2
.
load_config
(
tmpdirname
)
new_obj_1
=
SchedulerObject2
.
from_config
(
config
)
# now save a config parameter that is not expected
with
open
(
os
.
path
.
join
(
tmpdirname
,
SchedulerObject
.
config_name
),
"r"
)
as
f
:
data
=
json
.
load
(
f
)
data
[
"unexpected"
]
=
True
with
open
(
os
.
path
.
join
(
tmpdirname
,
SchedulerObject
.
config_name
),
"w"
)
as
f
:
json
.
dump
(
data
,
f
)
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
config
=
SchedulerObject
.
load_config
(
tmpdirname
)
new_obj_2
=
SchedulerObject
.
from_config
(
config
)
with
CaptureLogger
(
logger
)
as
cap_logger_3
:
config
=
SchedulerObject2
.
load_config
(
tmpdirname
)
new_obj_3
=
SchedulerObject2
.
from_config
(
config
)
assert
new_obj_1
.
__class__
==
SchedulerObject2
assert
new_obj_2
.
__class__
==
SchedulerObject
assert
new_obj_3
.
__class__
==
SchedulerObject2
assert
cap_logger_1
.
out
==
""
assert
(
cap_logger_2
.
out
==
"The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
" will"
" be ignored. Please verify your config.json configuration file.
\n
"
)
assert
cap_logger_2
.
out
.
replace
(
"SchedulerObject"
,
"SchedulerObject2"
)
==
cap_logger_3
.
out
def
test_save_load_compatible_schedulers
(
self
):
SchedulerObject2
.
_compatibles
=
[
"SchedulerObject"
]
SchedulerObject
.
_compatibles
=
[
"SchedulerObject2"
]
obj
=
SchedulerObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SchedulerObject"
,
SchedulerObject
)
setattr
(
diffusers
,
"SchedulerObject2"
,
SchedulerObject2
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
# now save a config parameter that is expected by another class, but not origin class
with
open
(
os
.
path
.
join
(
tmpdirname
,
SchedulerObject
.
config_name
),
"r"
)
as
f
:
data
=
json
.
load
(
f
)
data
[
"f"
]
=
[
0
,
0
]
data
[
"unexpected"
]
=
True
with
open
(
os
.
path
.
join
(
tmpdirname
,
SchedulerObject
.
config_name
),
"w"
)
as
f
:
json
.
dump
(
data
,
f
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
config
=
SchedulerObject
.
load_config
(
tmpdirname
)
new_obj
=
SchedulerObject
.
from_config
(
config
)
assert
new_obj
.
__class__
==
SchedulerObject
assert
(
cap_logger
.
out
==
"The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
" will"
" be ignored. Please verify your config.json configuration file.
\n
"
)
def
test_save_load_from_different_config_comp_schedulers
(
self
):
SchedulerObject3
.
_compatibles
=
[
"SchedulerObject"
,
"SchedulerObject2"
]
SchedulerObject2
.
_compatibles
=
[
"SchedulerObject"
,
"SchedulerObject3"
]
SchedulerObject
.
_compatibles
=
[
"SchedulerObject2"
,
"SchedulerObject3"
]
obj
=
SchedulerObject
()
# mock add obj class to `diffusers`
setattr
(
diffusers
,
"SchedulerObject"
,
SchedulerObject
)
setattr
(
diffusers
,
"SchedulerObject2"
,
SchedulerObject2
)
setattr
(
diffusers
,
"SchedulerObject3"
,
SchedulerObject3
)
logger
=
logging
.
get_logger
(
"diffusers.configuration_utils"
)
logger
.
setLevel
(
diffusers
.
logging
.
INFO
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
with
CaptureLogger
(
logger
)
as
cap_logger_1
:
config
=
SchedulerObject
.
load_config
(
tmpdirname
)
new_obj_1
=
SchedulerObject
.
from_config
(
config
)
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
config
=
SchedulerObject2
.
load_config
(
tmpdirname
)
new_obj_2
=
SchedulerObject2
.
from_config
(
config
)
with
CaptureLogger
(
logger
)
as
cap_logger_3
:
config
=
SchedulerObject3
.
load_config
(
tmpdirname
)
new_obj_3
=
SchedulerObject3
.
from_config
(
config
)
assert
new_obj_1
.
__class__
==
SchedulerObject
assert
new_obj_2
.
__class__
==
SchedulerObject2
assert
new_obj_3
.
__class__
==
SchedulerObject3
assert
cap_logger_1
.
out
==
""
assert
cap_logger_2
.
out
==
"{'f'} was not found in config. Values will be initialized to default values.
\n
"
assert
cap_logger_3
.
out
==
"{'f'} was not found in config. Values will be initialized to default values.
\n
"
class
SchedulerCommonTest
(
unittest
.
TestCase
):
class
SchedulerCommonTest
(
unittest
.
TestCase
):
scheduler_classes
=
()
scheduler_classes
=
()
forward_default_kwargs
=
()
forward_default_kwargs
=
()
...
@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
scheduler
.
set_timesteps
(
num_inference_steps
)
...
@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
scheduler
.
set_timesteps
(
num_inference_steps
)
...
@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
scheduler
.
set_timesteps
(
num_inference_steps
)
...
@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_compatibles
(
self
):
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
assert
all
(
c
is
not
None
for
c
in
scheduler
.
compatibles
)
for
comp_scheduler_cls
in
scheduler
.
compatibles
:
comp_scheduler
=
comp_scheduler_cls
.
from_config
(
scheduler
.
config
)
assert
comp_scheduler
is
not
None
new_scheduler
=
scheduler_class
.
from_config
(
comp_scheduler
.
config
)
new_scheduler_config
=
{
k
:
v
for
k
,
v
in
new_scheduler
.
config
.
items
()
if
k
in
scheduler
.
config
}
scheduler_diff
=
{
k
:
v
for
k
,
v
in
new_scheduler
.
config
.
items
()
if
k
not
in
scheduler
.
config
}
# make sure that configs are essentially identical
assert
new_scheduler_config
==
dict
(
scheduler
.
config
)
# make sure that only differences are for configs that are not in init
init_keys
=
inspect
.
signature
(
scheduler_class
.
__init__
).
parameters
.
keys
()
assert
set
(
scheduler_diff
.
keys
()).
intersection
(
set
(
init_keys
))
==
set
()
def
test_from_pretrained
(
self
):
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_pretrained
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_pretrained
(
tmpdirname
)
assert
scheduler
.
config
==
new_scheduler
.
config
def
test_step_shape
(
self
):
def
test_step_shape
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
...
@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
...
@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
model_outputs
=
dummy_past_residuals
[:
new_scheduler
.
config
.
solver_order
]
new_scheduler
.
model_outputs
=
dummy_past_residuals
[:
new_scheduler
.
config
.
solver_order
]
...
@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
...
@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
...
@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
...
@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
...
@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
...
@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
output
=
scheduler
.
step_pred
(
output
=
scheduler
.
step_pred
(
residual
,
time_step
,
sample
,
generator
=
torch
.
manual_seed
(
0
),
**
kwargs
residual
,
time_step
,
sample
,
generator
=
torch
.
manual_seed
(
0
),
**
kwargs
...
@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
...
@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
output
=
scheduler
.
step_pred
(
output
=
scheduler
.
step_pred
(
residual
,
time_step
,
sample
,
generator
=
torch
.
manual_seed
(
0
),
**
kwargs
residual
,
time_step
,
sample
,
generator
=
torch
.
manual_seed
(
0
),
**
kwargs
...
@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
...
@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
...
@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
...
@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
...
...
tests/test_scheduler_flax.py
View file @
554b374d
...
@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
...
@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
...
@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
...
@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
state
=
scheduler
.
set_timesteps
(
state
,
num_inference_steps
)
...
@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
new_state
=
new_scheduler
.
set_timesteps
(
new_state
,
num_inference_steps
,
shape
=
sample
.
shape
)
new_state
=
new_scheduler
.
set_timesteps
(
new_state
,
num_inference_steps
,
shape
=
sample
.
shape
)
# copy over dummy past residuals
# copy over dummy past residuals
new_state
=
new_state
.
replace
(
ets
=
dummy_past_residuals
[:])
new_state
=
new_state
.
replace
(
ets
=
dummy_past_residuals
[:])
...
@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
config
(
tmpdirname
)
new_scheduler
,
new_state
=
scheduler_class
.
from_
pretrained
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_state
=
new_scheduler
.
set_timesteps
(
new_state
,
num_inference_steps
,
shape
=
sample
.
shape
)
new_state
=
new_scheduler
.
set_timesteps
(
new_state
,
num_inference_steps
,
shape
=
sample
.
shape
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment