Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
3b48620f
Commit
3b48620f
authored
Nov 17, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
parents
3fb28c44
632dacea
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
126 additions
and
2 deletions
+126
-2
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+10
-1
tests/fixtures/custom_pipeline/what_ever.py
tests/fixtures/custom_pipeline/what_ever.py
+101
-0
tests/test_pipelines.py
tests/test_pipelines.py
+15
-1
No files found.
src/diffusers/pipeline_utils.py
View file @
3b48620f
...
@@ -18,6 +18,7 @@ import importlib
...
@@ -18,6 +18,7 @@ import importlib
import
inspect
import
inspect
import
os
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -483,8 +484,16 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -483,8 +484,16 @@ class DiffusionPipeline(ConfigMixin):
# 2. Load the pipeline class, if using custom module then load it from the hub
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
# if we load from explicit class, let's use it
if
custom_pipeline
is
not
None
:
if
custom_pipeline
is
not
None
:
if
custom_pipeline
.
endswith
(
".py"
):
path
=
Path
(
custom_pipeline
)
# decompose into folder & file
file_name
=
path
.
name
custom_pipeline
=
path
.
parent
.
absolute
()
else
:
file_name
=
CUSTOM_PIPELINE_FILE_NAME
pipeline_class
=
get_class_from_dynamic_module
(
pipeline_class
=
get_class_from_dynamic_module
(
custom_pipeline
,
module_file
=
CUSTOM_PIPELINE_FILE_NAME
,
cache_dir
=
custom_pipeline
custom_pipeline
,
module_file
=
file_name
,
cache_dir
=
custom_pipeline
)
)
elif
cls
!=
DiffusionPipeline
:
elif
cls
!=
DiffusionPipeline
:
pipeline_class
=
cls
pipeline_class
=
cls
...
...
tests/fixtures/custom_pipeline/what_ever.py
0 → 100644
View file @
3b48620f
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Optional
,
Tuple
,
Union
import
torch
from
diffusers.pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
class CustomLocalPipeline(DiffusionPipeline):
    r"""
    Minimal local test fixture pipeline for unconditional image generation.

    Inherits from [`DiffusionPipeline`]; see the superclass documentation for the generic
    methods the library implements for all the pipelines (such as downloading or saving,
    running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        # Register components so save/load and .to(device) handle them automatically.
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[torch.Generator] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Run the denoising loop and generate images.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: the generated images, plus a fixed
            marker string used by the test-suite to verify this local pipeline was actually loaded.
        """
        # Start from pure Gaussian noise shaped like the model's sample space.
        noise_shape = (
            batch_size,
            self.unet.in_channels,
            self.unet.sample_size,
            self.unet.sample_size,
        )
        sample = torch.randn(noise_shape, generator=generator)
        sample = sample.to(self.device)

        # Configure the scheduler for the requested number of steps.
        self.scheduler.set_timesteps(num_inference_steps)

        for timestep in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            noise_pred = self.unet(sample, timestep).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
            sample = self.scheduler.step(noise_pred, timestep, sample).prev_sample

        # Map from [-1, 1] to [0, 1] and move channels last for numpy/PIL output.
        sample = (sample / 2 + 0.5).clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        if not return_dict:
            return (sample,), "This is a local test"

        return ImagePipelineOutput(images=sample), "This is a local test"
tests/test_pipelines.py
View file @
3b48620f
...
@@ -192,7 +192,7 @@ class CustomPipelineTests(unittest.TestCase):
...
@@ -192,7 +192,7 @@ class CustomPipelineTests(unittest.TestCase):
# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
assert
output_str
==
"This is a test"
assert
output_str
==
"This is a test"
def
test_local_custom_pipeline
(
self
):
def
test_local_custom_pipeline
_repo
(
self
):
local_custom_pipeline_path
=
get_tests_dir
(
"fixtures/custom_pipeline"
)
local_custom_pipeline_path
=
get_tests_dir
(
"fixtures/custom_pipeline"
)
pipeline
=
DiffusionPipeline
.
from_pretrained
(
pipeline
=
DiffusionPipeline
.
from_pretrained
(
"google/ddpm-cifar10-32"
,
custom_pipeline
=
local_custom_pipeline_path
"google/ddpm-cifar10-32"
,
custom_pipeline
=
local_custom_pipeline_path
...
@@ -205,6 +205,20 @@ class CustomPipelineTests(unittest.TestCase):
...
@@ -205,6 +205,20 @@ class CustomPipelineTests(unittest.TestCase):
# compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
# compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
assert
output_str
==
"This is a local test"
assert
output_str
==
"This is a local test"
def test_local_custom_pipeline_file(self):
    """Load a custom pipeline from an explicit local .py file path (not a directory)."""
    # Build the path to the fixture *file* so the ".py"-suffix branch of
    # custom_pipeline resolution is exercised.
    fixture_dir = get_tests_dir("fixtures/custom_pipeline")
    fixture_file = os.path.join(fixture_dir, "what_ever.py")

    pipeline = DiffusionPipeline.from_pretrained(
        "google/ddpm-cifar10-32", custom_pipeline=fixture_file
    )
    pipeline = pipeline.to(torch_device)
    images, output_str = pipeline(num_inference_steps=2, output_type="np")

    # The dynamically loaded class, not the generic DiffusionPipeline, must be used.
    assert pipeline.__class__.__name__ == "CustomLocalPipeline"
    assert images[0].shape == (1, 32, 32, 3)

    # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
    assert output_str == "This is a local test"
@
slow
@
slow
@
require_torch_gpu
@
require_torch_gpu
def
test_load_pipeline_from_git
(
self
):
def
test_load_pipeline_from_git
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment