Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
7c226264
Unverified
Commit
7c226264
authored
Oct 13, 2022
by
Patrick von Platen
Committed by
GitHub
Oct 13, 2022
Browse files
Align PT and Flax API - allow loading checkpoint from PyTorch configs (#827)
* up * finish * add more tests * up * up * finish
parent
78db11db
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
165 additions
and
31 deletions
+165
-31
src/diffusers/pipeline_flax_utils.py
src/diffusers/pipeline_flax_utils.py
+32
-23
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
...elines/stable_diffusion/pipeline_flax_stable_diffusion.py
+31
-4
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
...s/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
...nes/stable_diffusion/pipeline_stable_diffusion_img2img.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+1
-1
tests/test_pipelines_flax.py
tests/test_pipelines_flax.py
+99
-1
No files found.
src/diffusers/pipeline_flax_utils.py
View file @
7c226264
...
@@ -111,24 +111,27 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -111,24 +111,27 @@ class FlaxDiffusionPipeline(ConfigMixin):
from
diffusers
import
pipelines
from
diffusers
import
pipelines
for
name
,
module
in
kwargs
.
items
():
for
name
,
module
in
kwargs
.
items
():
# retrieve library
if
module
is
None
:
library
=
module
.
__module__
.
split
(
"."
)[
0
]
register_dict
=
{
name
:
(
None
,
None
)}
else
:
# retrieve library
library
=
module
.
__module__
.
split
(
"."
)[
0
]
# check if the module is a pipeline module
# check if the module is a pipeline module
pipeline_dir
=
module
.
__module__
.
split
(
"."
)[
-
2
]
pipeline_dir
=
module
.
__module__
.
split
(
"."
)[
-
2
]
path
=
module
.
__module__
.
split
(
"."
)
path
=
module
.
__module__
.
split
(
"."
)
is_pipeline_module
=
pipeline_dir
in
path
and
hasattr
(
pipelines
,
pipeline_dir
)
is_pipeline_module
=
pipeline_dir
in
path
and
hasattr
(
pipelines
,
pipeline_dir
)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
# folder so we set the library to module name.
if
library
not
in
LOADABLE_CLASSES
or
is_pipeline_module
:
if
library
not
in
LOADABLE_CLASSES
or
is_pipeline_module
:
library
=
pipeline_dir
library
=
pipeline_dir
# retrieve class_name
# retrieve class_name
class_name
=
module
.
__class__
.
__name__
class_name
=
module
.
__class__
.
__name__
register_dict
=
{
name
:
(
library
,
class_name
)}
register_dict
=
{
name
:
(
library
,
class_name
)}
# save model index config
# save model index config
self
.
register_to_config
(
**
register_dict
)
self
.
register_to_config
(
**
register_dict
)
...
@@ -320,6 +323,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -320,6 +323,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
pipeline_class
=
cls
pipeline_class
=
cls
else
:
else
:
diffusers_module
=
importlib
.
import_module
(
cls
.
__module__
.
split
(
"."
)[
0
])
diffusers_module
=
importlib
.
import_module
(
cls
.
__module__
.
split
(
"."
)[
0
])
class_name
=
(
config_dict
[
"_class_name"
]
if
config_dict
[
"_class_name"
].
startswith
(
"Flax"
)
else
"Flax"
+
config_dict
[
"_class_name"
]
)
pipeline_class
=
getattr
(
diffusers_module
,
config_dict
[
"_class_name"
])
pipeline_class
=
getattr
(
diffusers_module
,
config_dict
[
"_class_name"
])
# some modules can be passed directly to the init
# some modules can be passed directly to the init
...
@@ -342,6 +350,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -342,6 +350,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
loaded_sub_model
=
None
loaded_sub_model
=
None
sub_model_should_be_defined
=
True
# if the model is in a pipeline module, then we load it from the pipeline
# if the model is in a pipeline module, then we load it from the pipeline
if
name
in
passed_class_obj
:
if
name
in
passed_class_obj
:
...
@@ -362,6 +371,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -362,6 +371,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
f
"
{
passed_class_obj
[
name
]
}
is of type:
{
type
(
passed_class_obj
[
name
])
}
, but should be"
f
"
{
passed_class_obj
[
name
]
}
is of type:
{
type
(
passed_class_obj
[
name
])
}
, but should be"
f
"
{
expected_class_obj
}
"
f
"
{
expected_class_obj
}
"
)
)
elif
passed_class_obj
[
name
]
is
None
:
logger
.
warn
(
f
"You have passed `None` for
{
name
}
to disable its functionality in
{
pipeline_class
}
. Note"
f
" that this might lead to problems when using
{
pipeline_class
}
and is not recommended."
)
sub_model_should_be_defined
=
False
else
:
else
:
logger
.
warn
(
logger
.
warn
(
f
"You have passed a non-standard module
{
passed_class_obj
[
name
]
}
. We cannot verify whether it"
f
"You have passed a non-standard module
{
passed_class_obj
[
name
]
}
. We cannot verify whether it"
...
@@ -372,25 +387,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -372,25 +387,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model
=
passed_class_obj
[
name
]
loaded_sub_model
=
passed_class_obj
[
name
]
elif
is_pipeline_module
:
elif
is_pipeline_module
:
pipeline_module
=
getattr
(
pipelines
,
library_name
)
pipeline_module
=
getattr
(
pipelines
,
library_name
)
if
from_pt
:
class_obj
=
import_flax_or_no_model
(
pipeline_module
,
class_name
)
class_obj
=
import_flax_or_no_model
(
pipeline_module
,
class_name
)
else
:
class_obj
=
getattr
(
pipeline_module
,
class_name
)
importable_classes
=
ALL_IMPORTABLE_CLASSES
importable_classes
=
ALL_IMPORTABLE_CLASSES
class_candidates
=
{
c
:
class_obj
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
class_obj
for
c
in
importable_classes
.
keys
()}
else
:
else
:
# else we just import it from the library.
# else we just import it from the library.
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
if
from_pt
:
class_obj
=
import_flax_or_no_model
(
library
,
class_name
)
class_obj
=
import_flax_or_no_model
(
library
,
class_name
)
else
:
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
if
loaded_sub_model
is
None
:
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
load_method_name
=
None
load_method_name
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
if
issubclass
(
class_obj
,
class_candidate
):
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
View file @
7c226264
...
@@ -14,10 +14,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
...
@@ -14,10 +14,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from
...models
import
FlaxAutoencoderKL
,
FlaxUNet2DConditionModel
from
...models
import
FlaxAutoencoderKL
,
FlaxUNet2DConditionModel
from
...pipeline_flax_utils
import
FlaxDiffusionPipeline
from
...pipeline_flax_utils
import
FlaxDiffusionPipeline
from
...schedulers
import
FlaxDDIMScheduler
,
FlaxLMSDiscreteScheduler
,
FlaxPNDMScheduler
from
...schedulers
import
FlaxDDIMScheduler
,
FlaxLMSDiscreteScheduler
,
FlaxPNDMScheduler
from
...utils
import
logging
from
.
import
FlaxStableDiffusionPipelineOutput
from
.
import
FlaxStableDiffusionPipelineOutput
from
.safety_checker_flax
import
FlaxStableDiffusionSafetyChecker
from
.safety_checker_flax
import
FlaxStableDiffusionSafetyChecker
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
class
FlaxStableDiffusionPipeline
(
FlaxDiffusionPipeline
):
class
FlaxStableDiffusionPipeline
(
FlaxDiffusionPipeline
):
r
"""
r
"""
Pipeline for text-to-image generation using Stable Diffusion.
Pipeline for text-to-image generation using Stable Diffusion.
...
@@ -60,6 +64,16 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
...
@@ -60,6 +64,16 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
super
().
__init__
()
super
().
__init__
()
self
.
dtype
=
dtype
self
.
dtype
=
dtype
if
safety_checker
is
None
:
logger
.
warn
(
f
"You have disabled the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self
.
register_modules
(
self
.
register_modules
(
vae
=
vae
,
vae
=
vae
,
text_encoder
=
text_encoder
,
text_encoder
=
text_encoder
,
...
@@ -265,10 +279,23 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
...
@@ -265,10 +279,23 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
,
height
,
width
,
guidance_scale
,
latents
,
debug
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
,
height
,
width
,
guidance_scale
,
latents
,
debug
)
)
safety_params
=
params
[
"safety_checker"
]
if
self
.
safety_checker
is
not
None
:
images
=
(
images
*
255
).
round
().
astype
(
"uint8"
)
safety_params
=
params
[
"safety_checker"
]
images
=
np
.
asarray
(
images
).
reshape
(
-
1
,
height
,
width
,
3
)
images_uint8_casted
=
(
images
*
255
).
round
().
astype
(
"uint8"
)
images
,
has_nsfw_concept
=
self
.
_run_safety_checker
(
images
,
safety_params
,
jit
)
num_devices
,
batch_size
=
images
.
shape
[:
2
]
images_uint8_casted
=
np
.
asarray
(
images_uint8_casted
).
reshape
(
num_devices
*
batch_size
,
height
,
width
,
3
)
images_uint8_casted
,
has_nsfw_concept
=
self
.
_run_safety_checker
(
images_uint8_casted
,
safety_params
,
jit
)
images
=
np
.
asarray
(
images
)
# block images
if
any
(
has_nsfw_concept
):
for
i
,
is_nsfw
in
enumerate
(
has_nsfw_concept
):
images
[
i
]
=
np
.
asarray
(
images_uint8_casted
[
i
])
images
=
images
.
reshape
(
num_devices
,
batch_size
,
height
,
width
,
3
)
else
:
has_nsfw_concept
=
False
if
not
return_dict
:
if
not
return_dict
:
return
(
images
,
has_nsfw_concept
)
return
(
images
,
has_nsfw_concept
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
View file @
7c226264
...
@@ -73,7 +73,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
...
@@ -73,7 +73,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
if
safety_checker
is
None
:
if
safety_checker
is
None
:
logger
.
warn
(
logger
.
warn
(
f
"You have disabed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
f
"You have disab
l
ed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
View file @
7c226264
...
@@ -85,7 +85,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
...
@@ -85,7 +85,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if
safety_checker
is
None
:
if
safety_checker
is
None
:
logger
.
warn
(
logger
.
warn
(
f
"You have disabed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
f
"You have disab
l
ed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
7c226264
...
@@ -100,7 +100,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -100,7 +100,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if
safety_checker
is
None
:
if
safety_checker
is
None
:
logger
.
warn
(
logger
.
warn
(
f
"You have disabed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
f
"You have disab
l
ed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
...
...
tests/test_pipelines_flax.py
View file @
7c226264
...
@@ -23,6 +23,7 @@ from diffusers.utils.testing_utils import require_flax, slow
...
@@ -23,6 +23,7 @@ from diffusers.utils.testing_utils import require_flax, slow
if
is_flax_available
():
if
is_flax_available
():
import
jax
import
jax
import
jax.numpy
as
jnp
from
diffusers
import
FlaxStableDiffusionPipeline
from
diffusers
import
FlaxStableDiffusionPipeline
from
flax.jax_utils
import
replicate
from
flax.jax_utils
import
replicate
from
flax.training.common_utils
import
shard
from
flax.training.common_utils
import
shard
...
@@ -34,7 +35,7 @@ if is_flax_available():
...
@@ -34,7 +35,7 @@ if is_flax_available():
class
FlaxPipelineTests
(
unittest
.
TestCase
):
class
FlaxPipelineTests
(
unittest
.
TestCase
):
def
test_dummy_all_tpus
(
self
):
def
test_dummy_all_tpus
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-pipe"
"hf-internal-testing/tiny-stable-diffusion-pipe"
,
safety_checker
=
None
)
)
prompt
=
(
prompt
=
(
...
@@ -57,6 +58,103 @@ class FlaxPipelineTests(unittest.TestCase):
...
@@ -57,6 +58,103 @@ class FlaxPipelineTests(unittest.TestCase):
prompt_ids
=
shard
(
prompt_ids
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
assert
images
.
shape
==
(
8
,
1
,
64
,
64
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
4.151474
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
49947.875
))
<
1e-2
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
assert
len
(
images_pil
)
==
8
assert
len
(
images_pil
)
==
8
def
test_stable_diffusion_v1_4
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"flax"
,
safety_checker
=
None
)
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
50
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
p_sample
=
pmap
(
pipeline
.
__call__
,
static_broadcasted_argnums
=
(
3
,))
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
for
i
,
image
in
enumerate
(
images_pil
):
image
.
save
(
f
"/home/patrick/images/flax-test-
{
i
}
_fp32.png"
)
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.05652401
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2383808.2
))
<
1e-2
def
test_stable_diffusion_v1_4_bfloat_16
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"bf16"
,
dtype
=
jnp
.
bfloat16
,
safety_checker
=
None
)
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
50
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
p_sample
=
pmap
(
pipeline
.
__call__
,
static_broadcasted_argnums
=
(
3
,))
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.06652832
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2384849.8
))
<
1e-2
def
test_stable_diffusion_v1_4_bfloat_16_with_safety
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"bf16"
,
dtype
=
jnp
.
bfloat16
)
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
50
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
pipeline
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
,
jit
=
True
).
images
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.06652832
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2384849.8
))
<
1e-2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment