Commit c2a38ef9, authored Dec 18, 2022 by Anton Lozhkov, committed by GitHub Dec 18, 2022 (parent 08cc36dd)

Fix/update the LDM pipeline and tests (#1743)

* Fix/update LDM tests
* batched generators
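The pipeline fix adds `truncation=True` to both tokenizer calls. A minimal sketch of why that matters, using a stand-in CLIP checkpoint (the checkpoint name and prompt here are illustrative, not from the commit): `padding="max_length"` alone pads short prompts but does not cap long ones, so an over-long prompt hands the text encoder more positions than it supports.

```python
from transformers import CLIPTokenizer

# hypothetical stand-in tokenizer; the LDM pipeline uses its own tokenizer
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
long_prompt = "a painting of a squirrel eating a burger " * 20

# without truncation, long prompts keep their full token length
untruncated = tokenizer(long_prompt, padding="max_length", max_length=77, return_tensors="pt")
# with truncation, the sequence is capped at max_length
truncated = tokenizer(long_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")

print(untruncated.input_ids.shape)  # (1, N) with N > 77 for this prompt
print(truncated.input_ids.shape)    # (1, 77), safe for the text encoder
```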
Showing 2 changed files with 145 additions and 107 deletions:

* src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py (+22 / -9)
* tests/pipelines/latent_diffusion/test_latent_diffusion.py (+123 / -98)
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

@@ -128,29 +128,42 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
 
         # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
         text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
 
         # get the initial random noise unless the user supplied it
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         if latents is None:
-            if self.device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu").to(self.device)
+            rand_device = "cpu" if self.device.type == "mps" else self.device
+
+            if isinstance(generator, list):
+                latents_shape = (1,) + latents_shape[1:]
+                latents = [
+                    torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
+                    for i in range(batch_size)
+                ]
+                latents = torch.cat(latents, dim=0)
             else:
                 latents = torch.randn(
                     latents_shape,
                     generator=generator,
-                    device=self.device,
+                    device=rand_device,
+                    dtype=text_embeddings.dtype,
                 )
+            latents = latents.to(self.device)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        latents = latents.to(self.device)
+            latents = latents.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
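A hedged usage sketch of the "batched generators" behavior this diff adds (assumes a diffusers version that includes this commit): passing one `torch.Generator` per prompt makes each sample reproducible on its own, independent of what else is in the batch, because each latent is drawn from its own generator and the results are concatenated.

```python
import torch
from diffusers import LDMTextToImagePipeline

pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

prompts = ["a painting of a squirrel eating a burger"] * 2
# one seeded generator per batch entry; the list length must match the batch
# size, or the pipeline raises the ValueError added in the diff above
generators = [torch.Generator(device="cpu").manual_seed(seed) for seed in (0, 1)]

images = pipe(prompts, generator=generators, guidance_scale=6.0, num_inference_steps=50).images
```

Re-running with `generators[0]` alone and `batch_size=1` should reproduce the first image, which is the point of drawing per-sample latents with shape `(1,) + latents_shape[1:]` and concatenating along the batch dimension.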
tests/pipelines/latent_diffusion/test_latent_diffusion.py

@@ -13,24 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import unittest
 
 import numpy as np
 import torch
 
 from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
+from ...test_pipelines_common import PipelineTesterMixin
+
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-class LDMTextToImagePipelineFastTests(unittest.TestCase):
-    @property
-    def dummy_cond_unet(self):
+class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = LDMTextToImagePipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
         torch.manual_seed(0)
-        model = UNet2DConditionModel(
+        unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
             sample_size=32,

@@ -40,25 +45,24 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
         )
-        return model
-
-    @property
-    def dummy_vae(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
         torch.manual_seed(0)
-        model = AutoencoderKL(
-            block_out_channels=[32, 64],
+        vae = AutoencoderKL(
+            block_out_channels=(32, 64),
             in_channels=3,
             out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
+            up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
             latent_channels=4,
         )
-        return model
-
-    @property
-    def dummy_text_encoder(self):
         torch.manual_seed(0)
-        config = CLIPTextConfig(
+        text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
             hidden_size=32,

@@ -69,96 +73,117 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
             pad_token_id=1,
             vocab_size=1000,
         )
-        return CLIPTextModel(config)
-
-    def test_inference_text2img(self):
-        if torch_device != "cpu":
-            return
-        unet = self.dummy_cond_unet
-        scheduler = DDIMScheduler()
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
 
-        ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(
-                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
-            ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
-        ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image_from_tuple = ldm(
-            [prompt],
-            generator=generator,
-            guidance_scale=6.0,
-            num_inference_steps=2,
-            output_type="numpy",
-            return_dict=False,
-        )[0]
-
-        image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 16, 16, 3)
-        expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
-
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vqvae": vae,
+            "bert": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_inference_text2img(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components()
+        pipe = LDMTextToImagePipeline(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 16, 16, 3)
+        expected_slice = np.array([0.59450, 0.64078, 0.55509, 0.51229, 0.69640, 0.36960, 0.59296, 0.60801, 0.49332])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
 
 @slow
-@require_torch
-class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
-    def test_inference_text2img(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
-        ).images
-
-        image_slice = image[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
-
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
-    def test_inference_text2img_fast(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
-
-        image_slice = image[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+@require_torch_gpu
+class LDMTextToImagePipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.51825, 0.52850, 0.52543, 0.54258, 0.52304, 0.52569, 0.54363, 0.55276, 0.56878])
+        max_diff = np.abs(expected_slice - image_slice).max()
+        assert max_diff < 1e-3
+
+
+@nightly
+@require_torch_gpu
+class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 50,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images[0]
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/ldm_text2img/ldm_large_256_ddim.npy"
+        )
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 1e-3
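A side note on the new `get_inputs` helpers, shown as a standalone sketch (the helper name below is hypothetical, mirroring the test diff): the slow and nightly tests build latents from a NumPy `RandomState` instead of `torch.randn` because the NumPy stream is identical on every platform, so the fixture stays deterministic across CPU, CUDA, and MPS runners regardless of torch's device-dependent RNG.

```python
import numpy as np
import torch

def reference_latents(seed=0, dtype=torch.float32, device="cpu"):
    # hypothetical helper mirroring get_inputs: NumPy draws the noise,
    # torch only converts and moves it to the target device/dtype
    latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
    return torch.from_numpy(latents).to(device=device, dtype=dtype)

fixed = reference_latents()
print(fixed.shape, fixed.dtype)  # torch.Size([1, 4, 32, 32]) torch.float32
```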