GitLab project: renzhc / diffusers_dcu

Commit c2a38ef9 (unverified signature)
Authored Dec 18, 2022 by Anton Lozhkov; committed by GitHub on Dec 18, 2022

Fix/update the LDM pipeline and tests (#1743)

* Fix/update LDM tests
* batched generators
Parent: 08cc36dd

Showing 2 changed files with 145 additions and 107 deletions:

    src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py  (+22 / -9)
    tests/pipelines/latent_diffusion/test_latent_diffusion.py  (+123 / -98)
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

@@ -128,29 +128,42 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
 
         # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
         text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
 
         # get the initial random noise unless the user supplied it
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         if latents is None:
-            if self.device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu").to(self.device)
-            else:
-                latents = torch.randn(
-                    latents_shape,
-                    generator=generator,
-                    device=self.device,
-                )
+            rand_device = "cpu" if self.device.type == "mps" else self.device
+
+            if isinstance(generator, list):
+                latents_shape = (1,) + latents_shape[1:]
+                latents = [
+                    torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
+                    for i in range(batch_size)
+                ]
+                latents = torch.cat(latents, dim=0)
+            else:
+                latents = torch.randn(
+                    latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
+                )
+            latents = latents.to(self.device)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
             latents = latents.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
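The batched-generator change above is user-facing: `generator` may now be a list of `torch.Generator` objects, one per prompt, and the pipeline raises a `ValueError` when the list length does not match the batch size. A minimal sketch of how a caller might use this (the two-prompt batch and the seeds are illustrative, not part of the commit):

import torch
from diffusers import LDMTextToImagePipeline

pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

# One generator per prompt: each sample's initial latents are drawn
# independently, so every image in the batch stays individually reproducible.
prompts = ["A painting of a squirrel eating a burger"] * 2
generators = [torch.Generator(device="cpu").manual_seed(s) for s in (0, 1)]

images = pipe(prompts, generator=generators, num_inference_steps=50).images

# A mismatched list now fails fast with the ValueError added above:
# pipe(prompts, generator=generators[:1])  # 2 prompts, 1 generator -> raises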
tests/pipelines/latent_diffusion/test_latent_diffusion.py

@@ -13,24 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import unittest
 
 import numpy as np
 import torch
 
 from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
+from ...test_pipelines_common import PipelineTesterMixin
+
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-class LDMTextToImagePipelineFastTests(unittest.TestCase):
-    @property
-    def dummy_cond_unet(self):
+class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = LDMTextToImagePipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
         torch.manual_seed(0)
-        model = UNet2DConditionModel(
+        unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
             sample_size=32,
@@ -40,25 +45,24 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
         )
-        return model
-
-    @property
-    def dummy_vae(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
         torch.manual_seed(0)
-        model = AutoencoderKL(
-            block_out_channels=[32, 64],
+        vae = AutoencoderKL(
+            block_out_channels=(32, 64),
             in_channels=3,
             out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
+            up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
             latent_channels=4,
         )
-        return model
-
-    @property
-    def dummy_text_encoder(self):
         torch.manual_seed(0)
-        config = CLIPTextConfig(
+        text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
             hidden_size=32,
@@ -69,96 +73,117 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
             pad_token_id=1,
             vocab_size=1000,
         )
-        return CLIPTextModel(config)
-
-    def test_inference_text2img(self):
-        if torch_device != "cpu":
-            return
-
-        unet = self.dummy_cond_unet
-        scheduler = DDIMScheduler()
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
-        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
-        ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(
-                [prompt],
-                generator=generator,
-                guidance_scale=6.0,
-                num_inference_steps=1,
-                output_type="numpy",
-            ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image = ldm(
-            [prompt],
-            generator=generator,
-            guidance_scale=6.0,
-            num_inference_steps=2,
-            output_type="numpy",
-        ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image_from_tuple = ldm(
-            [prompt],
-            generator=generator,
-            guidance_scale=6.0,
-            num_inference_steps=2,
-            output_type="numpy",
-            return_dict=False,
-        )[0]
-
-        image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 16, 16, 3)
-        expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
-
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
-
-@slow
-@require_torch
-class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
-    def test_inference_text2img(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-        image = ldm(
-            [prompt],
-            generator=generator,
-            guidance_scale=6.0,
-            num_inference_steps=20,
-            output_type="numpy",
-        ).images
-
-        image_slice = image[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
-    def test_inference_text2img_fast(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
-
-        image_slice = image[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vqvae": vae,
+            "bert": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_inference_text2img(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components()
+        pipe = LDMTextToImagePipeline(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 16, 16, 3)
+        expected_slice = np.array([0.59450, 0.64078, 0.55509, 0.51229, 0.69640, 0.36960, 0.59296, 0.60801, 0.49332])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+
+@slow
+@require_torch_gpu
+class LDMTextToImagePipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.51825, 0.52850, 0.52543, 0.54258, 0.52304, 0.52569, 0.54363, 0.55276, 0.56878])
+        max_diff = np.abs(expected_slice - image_slice).max()
+        assert max_diff < 1e-3
+
+
+@nightly
+@require_torch_gpu
+class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 50,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images[0]
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/ldm_text2img/ldm_large_256_ddim.npy"
+        )
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 1e-3
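The new slow and nightly tests pin the initial noise through the pipeline's `latents` argument instead of relying on device-dependent `torch.randn`. A sketch of the same setup outside the test harness (the CUDA device is an assumption; the shape `(1, 4, 32, 32)` follows the pipeline's `(batch_size, unet.in_channels, height // 8, width // 8)` for the 256x256 checkpoint):

import numpy as np
import torch
from diffusers import LDMTextToImagePipeline

pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to("cuda")

# Fixed latents: NumPy's RandomState produces the same noise on every machine,
# sidestepping cross-device differences in torch's RNG.
latents = np.random.RandomState(0).standard_normal((1, 4, 32, 32))
latents = torch.from_numpy(latents).to(device="cuda", dtype=torch.float32)

image = pipe(
    "A painting of a squirrel eating a burger",
    latents=latents,  # must match latents_shape or the pipeline raises ValueError
    generator=torch.Generator(device="cuda").manual_seed(0),
    num_inference_steps=3,
    guidance_scale=6.0,
    output_type="numpy",
).images[0]

assert image.shape == (256, 256, 3)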