Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
6e9730d7
Unverified
Commit
6e9730d7
authored
Nov 08, 2022
by
Fazzie-Maqianli
Committed by
GitHub
Nov 08, 2022
Browse files
[example] add stable diffuser (#1825)
parent
b1263d32
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1610 additions
and
0 deletions
+1610
-0
examples/images/diffusion/scripts/img2img.py
examples/images/diffusion/scripts/img2img.py
+293
-0
examples/images/diffusion/scripts/inpaint.py
examples/images/diffusion/scripts/inpaint.py
+98
-0
examples/images/diffusion/scripts/knn2img.py
examples/images/diffusion/scripts/knn2img.py
+398
-0
examples/images/diffusion/scripts/sample_diffusion.py
examples/images/diffusion/scripts/sample_diffusion.py
+313
-0
examples/images/diffusion/scripts/train_searcher.py
examples/images/diffusion/scripts/train_searcher.py
+147
-0
examples/images/diffusion/scripts/txt2img.py
examples/images/diffusion/scripts/txt2img.py
+344
-0
examples/images/diffusion/setup.py
examples/images/diffusion/setup.py
+13
-0
examples/images/diffusion/train.sh
examples/images/diffusion/train.sh
+4
-0
No files found.
examples/images/diffusion/scripts/img2img.py
0 → 100644
View file @
6e9730d7
"""make variations of input image"""
import
argparse
,
os
,
sys
,
glob
import
PIL
import
torch
import
numpy
as
np
from
omegaconf
import
OmegaConf
from
PIL
import
Image
from
tqdm
import
tqdm
,
trange
from
itertools
import
islice
from
einops
import
rearrange
,
repeat
from
torchvision.utils
import
make_grid
from
torch
import
autocast
from
contextlib
import
nullcontext
import
time
from
pytorch_lightning
import
seed_everything
from
ldm.util
import
instantiate_from_config
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.plms
import
PLMSSampler
def
chunk
(
it
,
size
):
it
=
iter
(
it
)
return
iter
(
lambda
:
tuple
(
islice
(
it
,
size
)),
())
def
load_model_from_config
(
config
,
ckpt
,
verbose
=
False
):
print
(
f
"Loading model from
{
ckpt
}
"
)
pl_sd
=
torch
.
load
(
ckpt
,
map_location
=
"cpu"
)
if
"global_step"
in
pl_sd
:
print
(
f
"Global Step:
{
pl_sd
[
'global_step'
]
}
"
)
sd
=
pl_sd
[
"state_dict"
]
model
=
instantiate_from_config
(
config
.
model
)
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
if
len
(
m
)
>
0
and
verbose
:
print
(
"missing keys:"
)
print
(
m
)
if
len
(
u
)
>
0
and
verbose
:
print
(
"unexpected keys:"
)
print
(
u
)
model
.
cuda
()
model
.
eval
()
return
model
def
load_img
(
path
):
image
=
Image
.
open
(
path
).
convert
(
"RGB"
)
w
,
h
=
image
.
size
print
(
f
"loaded input image of size (
{
w
}
,
{
h
}
) from
{
path
}
"
)
w
,
h
=
map
(
lambda
x
:
x
-
x
%
32
,
(
w
,
h
))
# resize to integer multiple of 32
image
=
image
.
resize
((
w
,
h
),
resample
=
PIL
.
Image
.
LANCZOS
)
image
=
np
.
array
(
image
).
astype
(
np
.
float32
)
/
255.0
image
=
image
[
None
].
transpose
(
0
,
3
,
1
,
2
)
image
=
torch
.
from_numpy
(
image
)
return
2.
*
image
-
1.
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--prompt"
,
type
=
str
,
nargs
=
"?"
,
default
=
"a painting of a virus monster playing guitar"
,
help
=
"the prompt to render"
)
parser
.
add_argument
(
"--init-img"
,
type
=
str
,
nargs
=
"?"
,
help
=
"path to the input image"
)
parser
.
add_argument
(
"--outdir"
,
type
=
str
,
nargs
=
"?"
,
help
=
"dir to write results to"
,
default
=
"outputs/img2img-samples"
)
parser
.
add_argument
(
"--skip_grid"
,
action
=
'store_true'
,
help
=
"do not save a grid, only individual samples. Helpful when evaluating lots of samples"
,
)
parser
.
add_argument
(
"--skip_save"
,
action
=
'store_true'
,
help
=
"do not save indiviual samples. For speed measurements."
,
)
parser
.
add_argument
(
"--ddim_steps"
,
type
=
int
,
default
=
50
,
help
=
"number of ddim sampling steps"
,
)
parser
.
add_argument
(
"--plms"
,
action
=
'store_true'
,
help
=
"use plms sampling"
,
)
parser
.
add_argument
(
"--fixed_code"
,
action
=
'store_true'
,
help
=
"if enabled, uses the same starting code across all samples "
,
)
parser
.
add_argument
(
"--ddim_eta"
,
type
=
float
,
default
=
0.0
,
help
=
"ddim eta (eta=0.0 corresponds to deterministic sampling"
,
)
parser
.
add_argument
(
"--n_iter"
,
type
=
int
,
default
=
1
,
help
=
"sample this often"
,
)
parser
.
add_argument
(
"--C"
,
type
=
int
,
default
=
4
,
help
=
"latent channels"
,
)
parser
.
add_argument
(
"--f"
,
type
=
int
,
default
=
8
,
help
=
"downsampling factor, most often 8 or 16"
,
)
parser
.
add_argument
(
"--n_samples"
,
type
=
int
,
default
=
2
,
help
=
"how many samples to produce for each given prompt. A.k.a batch size"
,
)
parser
.
add_argument
(
"--n_rows"
,
type
=
int
,
default
=
0
,
help
=
"rows in the grid (default: n_samples)"
,
)
parser
.
add_argument
(
"--scale"
,
type
=
float
,
default
=
5.0
,
help
=
"unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))"
,
)
parser
.
add_argument
(
"--strength"
,
type
=
float
,
default
=
0.75
,
help
=
"strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image"
,
)
parser
.
add_argument
(
"--from-file"
,
type
=
str
,
help
=
"if specified, load prompts from this file"
,
)
parser
.
add_argument
(
"--config"
,
type
=
str
,
default
=
"configs/stable-diffusion/v1-inference.yaml"
,
help
=
"path to config which constructs model"
,
)
parser
.
add_argument
(
"--ckpt"
,
type
=
str
,
default
=
"models/ldm/stable-diffusion-v1/model.ckpt"
,
help
=
"path to checkpoint of model"
,
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
42
,
help
=
"the seed (for reproducible sampling)"
,
)
parser
.
add_argument
(
"--precision"
,
type
=
str
,
help
=
"evaluate at this precision"
,
choices
=
[
"full"
,
"autocast"
],
default
=
"autocast"
)
opt
=
parser
.
parse_args
()
seed_everything
(
opt
.
seed
)
config
=
OmegaConf
.
load
(
f
"
{
opt
.
config
}
"
)
model
=
load_model_from_config
(
config
,
f
"
{
opt
.
ckpt
}
"
)
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
model
=
model
.
to
(
device
)
if
opt
.
plms
:
raise
NotImplementedError
(
"PLMS sampler not (yet) supported"
)
sampler
=
PLMSSampler
(
model
)
else
:
sampler
=
DDIMSampler
(
model
)
os
.
makedirs
(
opt
.
outdir
,
exist_ok
=
True
)
outpath
=
opt
.
outdir
batch_size
=
opt
.
n_samples
n_rows
=
opt
.
n_rows
if
opt
.
n_rows
>
0
else
batch_size
if
not
opt
.
from_file
:
prompt
=
opt
.
prompt
assert
prompt
is
not
None
data
=
[
batch_size
*
[
prompt
]]
else
:
print
(
f
"reading prompts from
{
opt
.
from_file
}
"
)
with
open
(
opt
.
from_file
,
"r"
)
as
f
:
data
=
f
.
read
().
splitlines
()
data
=
list
(
chunk
(
data
,
batch_size
))
sample_path
=
os
.
path
.
join
(
outpath
,
"samples"
)
os
.
makedirs
(
sample_path
,
exist_ok
=
True
)
base_count
=
len
(
os
.
listdir
(
sample_path
))
grid_count
=
len
(
os
.
listdir
(
outpath
))
-
1
assert
os
.
path
.
isfile
(
opt
.
init_img
)
init_image
=
load_img
(
opt
.
init_img
).
to
(
device
)
init_image
=
repeat
(
init_image
,
'1 ... -> b ...'
,
b
=
batch_size
)
init_latent
=
model
.
get_first_stage_encoding
(
model
.
encode_first_stage
(
init_image
))
# move to latent space
sampler
.
make_schedule
(
ddim_num_steps
=
opt
.
ddim_steps
,
ddim_eta
=
opt
.
ddim_eta
,
verbose
=
False
)
assert
0.
<=
opt
.
strength
<=
1.
,
'can only work with strength in [0.0, 1.0]'
t_enc
=
int
(
opt
.
strength
*
opt
.
ddim_steps
)
print
(
f
"target t_enc is
{
t_enc
}
steps"
)
precision_scope
=
autocast
if
opt
.
precision
==
"autocast"
else
nullcontext
with
torch
.
no_grad
():
with
precision_scope
(
"cuda"
):
with
model
.
ema_scope
():
tic
=
time
.
time
()
all_samples
=
list
()
for
n
in
trange
(
opt
.
n_iter
,
desc
=
"Sampling"
):
for
prompts
in
tqdm
(
data
,
desc
=
"data"
):
uc
=
None
if
opt
.
scale
!=
1.0
:
uc
=
model
.
get_learned_conditioning
(
batch_size
*
[
""
])
if
isinstance
(
prompts
,
tuple
):
prompts
=
list
(
prompts
)
c
=
model
.
get_learned_conditioning
(
prompts
)
# encode (scaled latent)
z_enc
=
sampler
.
stochastic_encode
(
init_latent
,
torch
.
tensor
([
t_enc
]
*
batch_size
).
to
(
device
))
# decode it
samples
=
sampler
.
decode
(
z_enc
,
c
,
t_enc
,
unconditional_guidance_scale
=
opt
.
scale
,
unconditional_conditioning
=
uc
,)
x_samples
=
model
.
decode_first_stage
(
samples
)
x_samples
=
torch
.
clamp
((
x_samples
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
if
not
opt
.
skip_save
:
for
x_sample
in
x_samples
:
x_sample
=
255.
*
rearrange
(
x_sample
.
cpu
().
numpy
(),
'c h w -> h w c'
)
Image
.
fromarray
(
x_sample
.
astype
(
np
.
uint8
)).
save
(
os
.
path
.
join
(
sample_path
,
f
"
{
base_count
:
05
}
.png"
))
base_count
+=
1
all_samples
.
append
(
x_samples
)
if
not
opt
.
skip_grid
:
# additionally, save as grid
grid
=
torch
.
stack
(
all_samples
,
0
)
grid
=
rearrange
(
grid
,
'n b c h w -> (n b) c h w'
)
grid
=
make_grid
(
grid
,
nrow
=
n_rows
)
# to image
grid
=
255.
*
rearrange
(
grid
,
'c h w -> h w c'
).
cpu
().
numpy
()
Image
.
fromarray
(
grid
.
astype
(
np
.
uint8
)).
save
(
os
.
path
.
join
(
outpath
,
f
'grid-
{
grid_count
:
04
}
.png'
))
grid_count
+=
1
toc
=
time
.
time
()
print
(
f
"Your samples are ready and waiting for you here:
\n
{
outpath
}
\n
"
f
"
\n
Enjoy."
)
if
__name__
==
"__main__"
:
main
()
examples/images/diffusion/scripts/inpaint.py
0 → 100644
View file @
6e9730d7
import
argparse
,
os
,
sys
,
glob
from
omegaconf
import
OmegaConf
from
PIL
import
Image
from
tqdm
import
tqdm
import
numpy
as
np
import
torch
from
main
import
instantiate_from_config
from
ldm.models.diffusion.ddim
import
DDIMSampler
def
make_batch
(
image
,
mask
,
device
):
image
=
np
.
array
(
Image
.
open
(
image
).
convert
(
"RGB"
))
image
=
image
.
astype
(
np
.
float32
)
/
255.0
image
=
image
[
None
].
transpose
(
0
,
3
,
1
,
2
)
image
=
torch
.
from_numpy
(
image
)
mask
=
np
.
array
(
Image
.
open
(
mask
).
convert
(
"L"
))
mask
=
mask
.
astype
(
np
.
float32
)
/
255.0
mask
=
mask
[
None
,
None
]
mask
[
mask
<
0.5
]
=
0
mask
[
mask
>=
0.5
]
=
1
mask
=
torch
.
from_numpy
(
mask
)
masked_image
=
(
1
-
mask
)
*
image
batch
=
{
"image"
:
image
,
"mask"
:
mask
,
"masked_image"
:
masked_image
}
for
k
in
batch
:
batch
[
k
]
=
batch
[
k
].
to
(
device
=
device
)
batch
[
k
]
=
batch
[
k
]
*
2.0
-
1.0
return
batch
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--indir"
,
type
=
str
,
nargs
=
"?"
,
help
=
"dir containing image-mask pairs (`example.png` and `example_mask.png`)"
,
)
parser
.
add_argument
(
"--outdir"
,
type
=
str
,
nargs
=
"?"
,
help
=
"dir to write results to"
,
)
parser
.
add_argument
(
"--steps"
,
type
=
int
,
default
=
50
,
help
=
"number of ddim sampling steps"
,
)
opt
=
parser
.
parse_args
()
masks
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
opt
.
indir
,
"*_mask.png"
)))
images
=
[
x
.
replace
(
"_mask.png"
,
".png"
)
for
x
in
masks
]
print
(
f
"Found
{
len
(
masks
)
}
inputs."
)
config
=
OmegaConf
.
load
(
"models/ldm/inpainting_big/config.yaml"
)
model
=
instantiate_from_config
(
config
.
model
)
model
.
load_state_dict
(
torch
.
load
(
"models/ldm/inpainting_big/last.ckpt"
)[
"state_dict"
],
strict
=
False
)
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
model
=
model
.
to
(
device
)
sampler
=
DDIMSampler
(
model
)
os
.
makedirs
(
opt
.
outdir
,
exist_ok
=
True
)
with
torch
.
no_grad
():
with
model
.
ema_scope
():
for
image
,
mask
in
tqdm
(
zip
(
images
,
masks
)):
outpath
=
os
.
path
.
join
(
opt
.
outdir
,
os
.
path
.
split
(
image
)[
1
])
batch
=
make_batch
(
image
,
mask
,
device
=
device
)
# encode masked image and concat downsampled mask
c
=
model
.
cond_stage_model
.
encode
(
batch
[
"masked_image"
])
cc
=
torch
.
nn
.
functional
.
interpolate
(
batch
[
"mask"
],
size
=
c
.
shape
[
-
2
:])
c
=
torch
.
cat
((
c
,
cc
),
dim
=
1
)
shape
=
(
c
.
shape
[
1
]
-
1
,)
+
c
.
shape
[
2
:]
samples_ddim
,
_
=
sampler
.
sample
(
S
=
opt
.
steps
,
conditioning
=
c
,
batch_size
=
c
.
shape
[
0
],
shape
=
shape
,
verbose
=
False
)
x_samples_ddim
=
model
.
decode_first_stage
(
samples_ddim
)
image
=
torch
.
clamp
((
batch
[
"image"
]
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
mask
=
torch
.
clamp
((
batch
[
"mask"
]
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
predicted_image
=
torch
.
clamp
((
x_samples_ddim
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
inpainted
=
(
1
-
mask
)
*
image
+
mask
*
predicted_image
inpainted
=
inpainted
.
cpu
().
numpy
().
transpose
(
0
,
2
,
3
,
1
)[
0
]
*
255
Image
.
fromarray
(
inpainted
.
astype
(
np
.
uint8
)).
save
(
outpath
)
examples/images/diffusion/scripts/knn2img.py
0 → 100644
View file @
6e9730d7
import
argparse
,
os
,
sys
,
glob
import
clip
import
torch
import
torch.nn
as
nn
import
numpy
as
np
from
omegaconf
import
OmegaConf
from
PIL
import
Image
from
tqdm
import
tqdm
,
trange
from
itertools
import
islice
from
einops
import
rearrange
,
repeat
from
torchvision.utils
import
make_grid
import
scann
import
time
from
multiprocessing
import
cpu_count
from
ldm.util
import
instantiate_from_config
,
parallel_data_prefetch
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.plms
import
PLMSSampler
from
ldm.modules.encoders.modules
import
FrozenClipImageEmbedder
,
FrozenCLIPTextEmbedder
DATABASES
=
[
"openimages"
,
"artbench-art_nouveau"
,
"artbench-baroque"
,
"artbench-expressionism"
,
"artbench-impressionism"
,
"artbench-post_impressionism"
,
"artbench-realism"
,
"artbench-romanticism"
,
"artbench-renaissance"
,
"artbench-surrealism"
,
"artbench-ukiyo_e"
,
]
def
chunk
(
it
,
size
):
it
=
iter
(
it
)
return
iter
(
lambda
:
tuple
(
islice
(
it
,
size
)),
())
def
load_model_from_config
(
config
,
ckpt
,
verbose
=
False
):
print
(
f
"Loading model from
{
ckpt
}
"
)
pl_sd
=
torch
.
load
(
ckpt
,
map_location
=
"cpu"
)
if
"global_step"
in
pl_sd
:
print
(
f
"Global Step:
{
pl_sd
[
'global_step'
]
}
"
)
sd
=
pl_sd
[
"state_dict"
]
model
=
instantiate_from_config
(
config
.
model
)
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
if
len
(
m
)
>
0
and
verbose
:
print
(
"missing keys:"
)
print
(
m
)
if
len
(
u
)
>
0
and
verbose
:
print
(
"unexpected keys:"
)
print
(
u
)
model
.
cuda
()
model
.
eval
()
return
model
class
Searcher
(
object
):
def
__init__
(
self
,
database
,
retriever_version
=
'ViT-L/14'
):
assert
database
in
DATABASES
# self.database = self.load_database(database)
self
.
database_name
=
database
self
.
searcher_savedir
=
f
'data/rdm/searchers/
{
self
.
database_name
}
'
self
.
database_path
=
f
'data/rdm/retrieval_databases/
{
self
.
database_name
}
'
self
.
retriever
=
self
.
load_retriever
(
version
=
retriever_version
)
self
.
database
=
{
'embedding'
:
[],
'img_id'
:
[],
'patch_coords'
:
[]}
self
.
load_database
()
self
.
load_searcher
()
def
train_searcher
(
self
,
k
,
metric
=
'dot_product'
,
searcher_savedir
=
None
):
print
(
'Start training searcher'
)
searcher
=
scann
.
scann_ops_pybind
.
builder
(
self
.
database
[
'embedding'
]
/
np
.
linalg
.
norm
(
self
.
database
[
'embedding'
],
axis
=
1
)[:,
np
.
newaxis
],
k
,
metric
)
self
.
searcher
=
searcher
.
score_brute_force
().
build
()
print
(
'Finish training searcher'
)
if
searcher_savedir
is
not
None
:
print
(
f
'Save trained searcher under "
{
searcher_savedir
}
"'
)
os
.
makedirs
(
searcher_savedir
,
exist_ok
=
True
)
self
.
searcher
.
serialize
(
searcher_savedir
)
def
load_single_file
(
self
,
saved_embeddings
):
compressed
=
np
.
load
(
saved_embeddings
)
self
.
database
=
{
key
:
compressed
[
key
]
for
key
in
compressed
.
files
}
print
(
'Finished loading of clip embeddings.'
)
def
load_multi_files
(
self
,
data_archive
):
out_data
=
{
key
:
[]
for
key
in
self
.
database
}
for
d
in
tqdm
(
data_archive
,
desc
=
f
'Loading datapool from
{
len
(
data_archive
)
}
individual files.'
):
for
key
in
d
.
files
:
out_data
[
key
].
append
(
d
[
key
])
return
out_data
def
load_database
(
self
):
print
(
f
'Load saved patch embedding from "
{
self
.
database_path
}
"'
)
file_content
=
glob
.
glob
(
os
.
path
.
join
(
self
.
database_path
,
'*.npz'
))
if
len
(
file_content
)
==
1
:
self
.
load_single_file
(
file_content
[
0
])
elif
len
(
file_content
)
>
1
:
data
=
[
np
.
load
(
f
)
for
f
in
file_content
]
prefetched_data
=
parallel_data_prefetch
(
self
.
load_multi_files
,
data
,
n_proc
=
min
(
len
(
data
),
cpu_count
()),
target_data_type
=
'dict'
)
self
.
database
=
{
key
:
np
.
concatenate
([
od
[
key
]
for
od
in
prefetched_data
],
axis
=
1
)[
0
]
for
key
in
self
.
database
}
else
:
raise
ValueError
(
f
'No npz-files in specified path "
{
self
.
database_path
}
" is this directory existing?'
)
print
(
f
'Finished loading of retrieval database of length
{
self
.
database
[
"embedding"
].
shape
[
0
]
}
.'
)
def
load_retriever
(
self
,
version
=
'ViT-L/14'
,
):
model
=
FrozenClipImageEmbedder
(
model
=
version
)
if
torch
.
cuda
.
is_available
():
model
.
cuda
()
model
.
eval
()
return
model
def
load_searcher
(
self
):
print
(
f
'load searcher for database
{
self
.
database_name
}
from
{
self
.
searcher_savedir
}
'
)
self
.
searcher
=
scann
.
scann_ops_pybind
.
load_searcher
(
self
.
searcher_savedir
)
print
(
'Finished loading searcher.'
)
def
search
(
self
,
x
,
k
):
if
self
.
searcher
is
None
and
self
.
database
[
'embedding'
].
shape
[
0
]
<
2e4
:
self
.
train_searcher
(
k
)
# quickly fit searcher on the fly for small databases
assert
self
.
searcher
is
not
None
,
'Cannot search with uninitialized searcher'
if
isinstance
(
x
,
torch
.
Tensor
):
x
=
x
.
detach
().
cpu
().
numpy
()
if
len
(
x
.
shape
)
==
3
:
x
=
x
[:,
0
]
query_embeddings
=
x
/
np
.
linalg
.
norm
(
x
,
axis
=
1
)[:,
np
.
newaxis
]
start
=
time
.
time
()
nns
,
distances
=
self
.
searcher
.
search_batched
(
query_embeddings
,
final_num_neighbors
=
k
)
end
=
time
.
time
()
out_embeddings
=
self
.
database
[
'embedding'
][
nns
]
out_img_ids
=
self
.
database
[
'img_id'
][
nns
]
out_pc
=
self
.
database
[
'patch_coords'
][
nns
]
out
=
{
'nn_embeddings'
:
out_embeddings
/
np
.
linalg
.
norm
(
out_embeddings
,
axis
=-
1
)[...,
np
.
newaxis
],
'img_ids'
:
out_img_ids
,
'patch_coords'
:
out_pc
,
'queries'
:
x
,
'exec_time'
:
end
-
start
,
'nns'
:
nns
,
'q_embeddings'
:
query_embeddings
}
return
out
def
__call__
(
self
,
x
,
n
):
return
self
.
search
(
x
,
n
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
# TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
# TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
parser
.
add_argument
(
"--prompt"
,
type
=
str
,
nargs
=
"?"
,
default
=
"a painting of a virus monster playing guitar"
,
help
=
"the prompt to render"
)
parser
.
add_argument
(
"--outdir"
,
type
=
str
,
nargs
=
"?"
,
help
=
"dir to write results to"
,
default
=
"outputs/txt2img-samples"
)
parser
.
add_argument
(
"--skip_grid"
,
action
=
'store_true'
,
help
=
"do not save a grid, only individual samples. Helpful when evaluating lots of samples"
,
)
parser
.
add_argument
(
"--ddim_steps"
,
type
=
int
,
default
=
50
,
help
=
"number of ddim sampling steps"
,
)
parser
.
add_argument
(
"--n_repeat"
,
type
=
int
,
default
=
1
,
help
=
"number of repeats in CLIP latent space"
,
)
parser
.
add_argument
(
"--plms"
,
action
=
'store_true'
,
help
=
"use plms sampling"
,
)
parser
.
add_argument
(
"--ddim_eta"
,
type
=
float
,
default
=
0.0
,
help
=
"ddim eta (eta=0.0 corresponds to deterministic sampling"
,
)
parser
.
add_argument
(
"--n_iter"
,
type
=
int
,
default
=
1
,
help
=
"sample this often"
,
)
parser
.
add_argument
(
"--H"
,
type
=
int
,
default
=
768
,
help
=
"image height, in pixel space"
,
)
parser
.
add_argument
(
"--W"
,
type
=
int
,
default
=
768
,
help
=
"image width, in pixel space"
,
)
parser
.
add_argument
(
"--n_samples"
,
type
=
int
,
default
=
3
,
help
=
"how many samples to produce for each given prompt. A.k.a batch size"
,
)
parser
.
add_argument
(
"--n_rows"
,
type
=
int
,
default
=
0
,
help
=
"rows in the grid (default: n_samples)"
,
)
parser
.
add_argument
(
"--scale"
,
type
=
float
,
default
=
5.0
,
help
=
"unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))"
,
)
parser
.
add_argument
(
"--from-file"
,
type
=
str
,
help
=
"if specified, load prompts from this file"
,
)
parser
.
add_argument
(
"--config"
,
type
=
str
,
default
=
"configs/retrieval-augmented-diffusion/768x768.yaml"
,
help
=
"path to config which constructs model"
,
)
parser
.
add_argument
(
"--ckpt"
,
type
=
str
,
default
=
"models/rdm/rdm768x768/model.ckpt"
,
help
=
"path to checkpoint of model"
,
)
parser
.
add_argument
(
"--clip_type"
,
type
=
str
,
default
=
"ViT-L/14"
,
help
=
"which CLIP model to use for retrieval and NN encoding"
,
)
parser
.
add_argument
(
"--database"
,
type
=
str
,
default
=
'artbench-surrealism'
,
choices
=
DATABASES
,
help
=
"The database used for the search, only applied when --use_neighbors=True"
,
)
parser
.
add_argument
(
"--use_neighbors"
,
default
=
False
,
action
=
'store_true'
,
help
=
"Include neighbors in addition to text prompt for conditioning"
,
)
parser
.
add_argument
(
"--knn"
,
default
=
10
,
type
=
int
,
help
=
"The number of included neighbors, only applied when --use_neighbors=True"
,
)
opt
=
parser
.
parse_args
()
config
=
OmegaConf
.
load
(
f
"
{
opt
.
config
}
"
)
model
=
load_model_from_config
(
config
,
f
"
{
opt
.
ckpt
}
"
)
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
model
=
model
.
to
(
device
)
clip_text_encoder
=
FrozenCLIPTextEmbedder
(
opt
.
clip_type
).
to
(
device
)
if
opt
.
plms
:
sampler
=
PLMSSampler
(
model
)
else
:
sampler
=
DDIMSampler
(
model
)
os
.
makedirs
(
opt
.
outdir
,
exist_ok
=
True
)
outpath
=
opt
.
outdir
batch_size
=
opt
.
n_samples
n_rows
=
opt
.
n_rows
if
opt
.
n_rows
>
0
else
batch_size
if
not
opt
.
from_file
:
prompt
=
opt
.
prompt
assert
prompt
is
not
None
data
=
[
batch_size
*
[
prompt
]]
else
:
print
(
f
"reading prompts from
{
opt
.
from_file
}
"
)
with
open
(
opt
.
from_file
,
"r"
)
as
f
:
data
=
f
.
read
().
splitlines
()
data
=
list
(
chunk
(
data
,
batch_size
))
sample_path
=
os
.
path
.
join
(
outpath
,
"samples"
)
os
.
makedirs
(
sample_path
,
exist_ok
=
True
)
base_count
=
len
(
os
.
listdir
(
sample_path
))
grid_count
=
len
(
os
.
listdir
(
outpath
))
-
1
print
(
f
"sampling scale for cfg is
{
opt
.
scale
:.
2
f
}
"
)
searcher
=
None
if
opt
.
use_neighbors
:
searcher
=
Searcher
(
opt
.
database
)
with
torch
.
no_grad
():
with
model
.
ema_scope
():
for
n
in
trange
(
opt
.
n_iter
,
desc
=
"Sampling"
):
all_samples
=
list
()
for
prompts
in
tqdm
(
data
,
desc
=
"data"
):
print
(
"sampling prompts:"
,
prompts
)
if
isinstance
(
prompts
,
tuple
):
prompts
=
list
(
prompts
)
c
=
clip_text_encoder
.
encode
(
prompts
)
uc
=
None
if
searcher
is
not
None
:
nn_dict
=
searcher
(
c
,
opt
.
knn
)
c
=
torch
.
cat
([
c
,
torch
.
from_numpy
(
nn_dict
[
'nn_embeddings'
]).
cuda
()],
dim
=
1
)
if
opt
.
scale
!=
1.0
:
uc
=
torch
.
zeros_like
(
c
)
if
isinstance
(
prompts
,
tuple
):
prompts
=
list
(
prompts
)
shape
=
[
16
,
opt
.
H
//
16
,
opt
.
W
//
16
]
# note: currently hardcoded for f16 model
samples_ddim
,
_
=
sampler
.
sample
(
S
=
opt
.
ddim_steps
,
conditioning
=
c
,
batch_size
=
c
.
shape
[
0
],
shape
=
shape
,
verbose
=
False
,
unconditional_guidance_scale
=
opt
.
scale
,
unconditional_conditioning
=
uc
,
eta
=
opt
.
ddim_eta
,
)
x_samples_ddim
=
model
.
decode_first_stage
(
samples_ddim
)
x_samples_ddim
=
torch
.
clamp
((
x_samples_ddim
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
for
x_sample
in
x_samples_ddim
:
x_sample
=
255.
*
rearrange
(
x_sample
.
cpu
().
numpy
(),
'c h w -> h w c'
)
Image
.
fromarray
(
x_sample
.
astype
(
np
.
uint8
)).
save
(
os
.
path
.
join
(
sample_path
,
f
"
{
base_count
:
05
}
.png"
))
base_count
+=
1
all_samples
.
append
(
x_samples_ddim
)
if
not
opt
.
skip_grid
:
# additionally, save as grid
grid
=
torch
.
stack
(
all_samples
,
0
)
grid
=
rearrange
(
grid
,
'n b c h w -> (n b) c h w'
)
grid
=
make_grid
(
grid
,
nrow
=
n_rows
)
# to image
grid
=
255.
*
rearrange
(
grid
,
'c h w -> h w c'
).
cpu
().
numpy
()
Image
.
fromarray
(
grid
.
astype
(
np
.
uint8
)).
save
(
os
.
path
.
join
(
outpath
,
f
'grid-
{
grid_count
:
04
}
.png'
))
grid_count
+=
1
print
(
f
"Your samples are ready and waiting for you here:
\n
{
outpath
}
\n
Enjoy."
)
examples/images/diffusion/scripts/sample_diffusion.py
0 → 100644
View file @
6e9730d7
import
argparse
,
os
,
sys
,
glob
,
datetime
,
yaml
import
torch
import
time
import
numpy
as
np
from
tqdm
import
trange
from
omegaconf
import
OmegaConf
from
PIL
import
Image
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.util
import
instantiate_from_config
rescale
=
lambda
x
:
(
x
+
1.
)
/
2.
def
custom_to_pil
(
x
):
x
=
x
.
detach
().
cpu
()
x
=
torch
.
clamp
(
x
,
-
1.
,
1.
)
x
=
(
x
+
1.
)
/
2.
x
=
x
.
permute
(
1
,
2
,
0
).
numpy
()
x
=
(
255
*
x
).
astype
(
np
.
uint8
)
x
=
Image
.
fromarray
(
x
)
if
not
x
.
mode
==
"RGB"
:
x
=
x
.
convert
(
"RGB"
)
return
x
def
custom_to_np
(
x
):
# saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
sample
=
x
.
detach
().
cpu
()
sample
=
((
sample
+
1
)
*
127.5
).
clamp
(
0
,
255
).
to
(
torch
.
uint8
)
sample
=
sample
.
permute
(
0
,
2
,
3
,
1
)
sample
=
sample
.
contiguous
()
return
sample
def
logs2pil
(
logs
,
keys
=
[
"sample"
]):
imgs
=
dict
()
for
k
in
logs
:
try
:
if
len
(
logs
[
k
].
shape
)
==
4
:
img
=
custom_to_pil
(
logs
[
k
][
0
,
...])
elif
len
(
logs
[
k
].
shape
)
==
3
:
img
=
custom_to_pil
(
logs
[
k
])
else
:
print
(
f
"Unknown format for key
{
k
}
. "
)
img
=
None
except
:
img
=
None
imgs
[
k
]
=
img
return
imgs
@
torch
.
no_grad
()
def
convsample
(
model
,
shape
,
return_intermediates
=
True
,
verbose
=
True
,
make_prog_row
=
False
):
if
not
make_prog_row
:
return
model
.
p_sample_loop
(
None
,
shape
,
return_intermediates
=
return_intermediates
,
verbose
=
verbose
)
else
:
return
model
.
progressive_denoising
(
None
,
shape
,
verbose
=
True
)
@
torch
.
no_grad
()
def
convsample_ddim
(
model
,
steps
,
shape
,
eta
=
1.0
):
ddim
=
DDIMSampler
(
model
)
bs
=
shape
[
0
]
shape
=
shape
[
1
:]
samples
,
intermediates
=
ddim
.
sample
(
steps
,
batch_size
=
bs
,
shape
=
shape
,
eta
=
eta
,
verbose
=
False
,)
return
samples
,
intermediates
@
torch
.
no_grad
()
def
make_convolutional_sample
(
model
,
batch_size
,
vanilla
=
False
,
custom_steps
=
None
,
eta
=
1.0
,):
log
=
dict
()
shape
=
[
batch_size
,
model
.
model
.
diffusion_model
.
in_channels
,
model
.
model
.
diffusion_model
.
image_size
,
model
.
model
.
diffusion_model
.
image_size
]
with
model
.
ema_scope
(
"Plotting"
):
t0
=
time
.
time
()
if
vanilla
:
sample
,
progrow
=
convsample
(
model
,
shape
,
make_prog_row
=
True
)
else
:
sample
,
intermediates
=
convsample_ddim
(
model
,
steps
=
custom_steps
,
shape
=
shape
,
eta
=
eta
)
t1
=
time
.
time
()
x_sample
=
model
.
decode_first_stage
(
sample
)
log
[
"sample"
]
=
x_sample
log
[
"time"
]
=
t1
-
t0
log
[
'throughput'
]
=
sample
.
shape
[
0
]
/
(
t1
-
t0
)
print
(
f
'Throughput for this batch:
{
log
[
"throughput"
]
}
'
)
return
log
def
run
(
model
,
logdir
,
batch_size
=
50
,
vanilla
=
False
,
custom_steps
=
None
,
eta
=
None
,
n_samples
=
50000
,
nplog
=
None
):
if
vanilla
:
print
(
f
'Using Vanilla DDPM sampling with
{
model
.
num_timesteps
}
sampling steps.'
)
else
:
print
(
f
'Using DDIM sampling with
{
custom_steps
}
sampling steps and eta=
{
eta
}
'
)
tstart
=
time
.
time
()
n_saved
=
len
(
glob
.
glob
(
os
.
path
.
join
(
logdir
,
'*.png'
)))
-
1
# path = logdir
if
model
.
cond_stage_model
is
None
:
all_images
=
[]
print
(
f
"Running unconditional sampling for
{
n_samples
}
samples"
)
for
_
in
trange
(
n_samples
//
batch_size
,
desc
=
"Sampling Batches (unconditional)"
):
logs
=
make_convolutional_sample
(
model
,
batch_size
=
batch_size
,
vanilla
=
vanilla
,
custom_steps
=
custom_steps
,
eta
=
eta
)
n_saved
=
save_logs
(
logs
,
logdir
,
n_saved
=
n_saved
,
key
=
"sample"
)
all_images
.
extend
([
custom_to_np
(
logs
[
"sample"
])])
if
n_saved
>=
n_samples
:
print
(
f
'Finish after generating
{
n_saved
}
samples'
)
break
all_img
=
np
.
concatenate
(
all_images
,
axis
=
0
)
all_img
=
all_img
[:
n_samples
]
shape_str
=
"x"
.
join
([
str
(
x
)
for
x
in
all_img
.
shape
])
nppath
=
os
.
path
.
join
(
nplog
,
f
"
{
shape_str
}
-samples.npz"
)
np
.
savez
(
nppath
,
all_img
)
else
:
raise
NotImplementedError
(
'Currently only sampling for unconditional models supported.'
)
print
(
f
"sampling of
{
n_saved
}
images finished in
{
(
time
.
time
()
-
tstart
)
/
60.
:.
2
f
}
minutes."
)
def
save_logs
(
logs
,
path
,
n_saved
=
0
,
key
=
"sample"
,
np_path
=
None
):
for
k
in
logs
:
if
k
==
key
:
batch
=
logs
[
key
]
if
np_path
is
None
:
for
x
in
batch
:
img
=
custom_to_pil
(
x
)
imgpath
=
os
.
path
.
join
(
path
,
f
"
{
key
}
_
{
n_saved
:
06
}
.png"
)
img
.
save
(
imgpath
)
n_saved
+=
1
else
:
npbatch
=
custom_to_np
(
batch
)
shape_str
=
"x"
.
join
([
str
(
x
)
for
x
in
npbatch
.
shape
])
nppath
=
os
.
path
.
join
(
np_path
,
f
"
{
n_saved
}
-
{
shape_str
}
-samples.npz"
)
np
.
savez
(
nppath
,
npbatch
)
n_saved
+=
npbatch
.
shape
[
0
]
return
n_saved
def
get_parser
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-r"
,
"--resume"
,
type
=
str
,
nargs
=
"?"
,
help
=
"load from logdir or checkpoint in logdir"
,
)
parser
.
add_argument
(
"-n"
,
"--n_samples"
,
type
=
int
,
nargs
=
"?"
,
help
=
"number of samples to draw"
,
default
=
50000
)
parser
.
add_argument
(
"-e"
,
"--eta"
,
type
=
float
,
nargs
=
"?"
,
help
=
"eta for ddim sampling (0.0 yields deterministic sampling)"
,
default
=
1.0
)
parser
.
add_argument
(
"-v"
,
"--vanilla_sample"
,
default
=
False
,
action
=
'store_true'
,
help
=
"vanilla sampling (default option is DDIM sampling)?"
,
)
parser
.
add_argument
(
"-l"
,
"--logdir"
,
type
=
str
,
nargs
=
"?"
,
help
=
"extra logdir"
,
default
=
"none"
)
parser
.
add_argument
(
"-c"
,
"--custom_steps"
,
type
=
int
,
nargs
=
"?"
,
help
=
"number of steps for ddim and fastdpm sampling"
,
default
=
50
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
nargs
=
"?"
,
help
=
"the bs"
,
default
=
10
)
return
parser
def
load_model_from_config
(
config
,
sd
):
model
=
instantiate_from_config
(
config
)
model
.
load_state_dict
(
sd
,
strict
=
False
)
model
.
cuda
()
model
.
eval
()
return
model
def
load_model
(
config
,
ckpt
,
gpu
,
eval_mode
):
if
ckpt
:
print
(
f
"Loading model from
{
ckpt
}
"
)
pl_sd
=
torch
.
load
(
ckpt
,
map_location
=
"cpu"
)
global_step
=
pl_sd
[
"global_step"
]
else
:
pl_sd
=
{
"state_dict"
:
None
}
global_step
=
None
model
=
load_model_from_config
(
config
.
model
,
pl_sd
[
"state_dict"
])
return
model
,
global_step
if
__name__
==
"__main__"
:
now
=
datetime
.
datetime
.
now
().
strftime
(
"%Y-%m-%d-%H-%M-%S"
)
sys
.
path
.
append
(
os
.
getcwd
())
command
=
" "
.
join
(
sys
.
argv
)
parser
=
get_parser
()
opt
,
unknown
=
parser
.
parse_known_args
()
ckpt
=
None
if
not
os
.
path
.
exists
(
opt
.
resume
):
raise
ValueError
(
"Cannot find {}"
.
format
(
opt
.
resume
))
if
os
.
path
.
isfile
(
opt
.
resume
):
# paths = opt.resume.split("/")
try
:
logdir
=
'/'
.
join
(
opt
.
resume
.
split
(
'/'
)[:
-
1
])
# idx = len(paths)-paths[::-1].index("logs")+1
print
(
f
'Logdir is
{
logdir
}
'
)
except
ValueError
:
paths
=
opt
.
resume
.
split
(
"/"
)
idx
=
-
2
# take a guess: path/to/logdir/checkpoints/model.ckpt
logdir
=
"/"
.
join
(
paths
[:
idx
])
ckpt
=
opt
.
resume
else
:
assert
os
.
path
.
isdir
(
opt
.
resume
),
f
"
{
opt
.
resume
}
is not a directory"
logdir
=
opt
.
resume
.
rstrip
(
"/"
)
ckpt
=
os
.
path
.
join
(
logdir
,
"model.ckpt"
)
base_configs
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
logdir
,
"config.yaml"
)))
opt
.
base
=
base_configs
configs
=
[
OmegaConf
.
load
(
cfg
)
for
cfg
in
opt
.
base
]
cli
=
OmegaConf
.
from_dotlist
(
unknown
)
config
=
OmegaConf
.
merge
(
*
configs
,
cli
)
gpu
=
True
eval_mode
=
True
if
opt
.
logdir
!=
"none"
:
locallog
=
logdir
.
split
(
os
.
sep
)[
-
1
]
if
locallog
==
""
:
locallog
=
logdir
.
split
(
os
.
sep
)[
-
2
]
print
(
f
"Switching logdir from '
{
logdir
}
' to '
{
os
.
path
.
join
(
opt
.
logdir
,
locallog
)
}
'"
)
logdir
=
os
.
path
.
join
(
opt
.
logdir
,
locallog
)
print
(
config
)
model
,
global_step
=
load_model
(
config
,
ckpt
,
gpu
,
eval_mode
)
print
(
f
"global step:
{
global_step
}
"
)
print
(
75
*
"="
)
print
(
"logging to:"
)
logdir
=
os
.
path
.
join
(
logdir
,
"samples"
,
f
"
{
global_step
:
08
}
"
,
now
)
imglogdir
=
os
.
path
.
join
(
logdir
,
"img"
)
numpylogdir
=
os
.
path
.
join
(
logdir
,
"numpy"
)
os
.
makedirs
(
imglogdir
)
os
.
makedirs
(
numpylogdir
)
print
(
logdir
)
print
(
75
*
"="
)
# write config out
sampling_file
=
os
.
path
.
join
(
logdir
,
"sampling_config.yaml"
)
sampling_conf
=
vars
(
opt
)
with
open
(
sampling_file
,
'w'
)
as
f
:
yaml
.
dump
(
sampling_conf
,
f
,
default_flow_style
=
False
)
print
(
sampling_conf
)
run
(
model
,
imglogdir
,
eta
=
opt
.
eta
,
vanilla
=
opt
.
vanilla_sample
,
n_samples
=
opt
.
n_samples
,
custom_steps
=
opt
.
custom_steps
,
batch_size
=
opt
.
batch_size
,
nplog
=
numpylogdir
)
print
(
"done."
)
examples/images/diffusion/scripts/train_searcher.py
0 → 100644
View file @
6e9730d7
import
os
,
sys
import
numpy
as
np
import
scann
import
argparse
import
glob
from
multiprocessing
import
cpu_count
from
tqdm
import
tqdm
from
ldm.util
import
parallel_data_prefetch
def
search_bruteforce
(
searcher
):
return
searcher
.
score_brute_force
().
build
()
def
search_partioned_ah
(
searcher
,
dims_per_block
,
aiq_threshold
,
reorder_k
,
partioning_trainsize
,
num_leaves
,
num_leaves_to_search
):
return
searcher
.
tree
(
num_leaves
=
num_leaves
,
num_leaves_to_search
=
num_leaves_to_search
,
training_sample_size
=
partioning_trainsize
).
\
score_ah
(
dims_per_block
,
anisotropic_quantization_threshold
=
aiq_threshold
).
reorder
(
reorder_k
).
build
()
def
search_ah
(
searcher
,
dims_per_block
,
aiq_threshold
,
reorder_k
):
return
searcher
.
score_ah
(
dims_per_block
,
anisotropic_quantization_threshold
=
aiq_threshold
).
reorder
(
reorder_k
).
build
()
def
load_datapool
(
dpath
):
def
load_single_file
(
saved_embeddings
):
compressed
=
np
.
load
(
saved_embeddings
)
database
=
{
key
:
compressed
[
key
]
for
key
in
compressed
.
files
}
return
database
def
load_multi_files
(
data_archive
):
database
=
{
key
:
[]
for
key
in
data_archive
[
0
].
files
}
for
d
in
tqdm
(
data_archive
,
desc
=
f
'Loading datapool from
{
len
(
data_archive
)
}
individual files.'
):
for
key
in
d
.
files
:
database
[
key
].
append
(
d
[
key
])
return
database
print
(
f
'Load saved patch embedding from "
{
dpath
}
"'
)
file_content
=
glob
.
glob
(
os
.
path
.
join
(
dpath
,
'*.npz'
))
if
len
(
file_content
)
==
1
:
data_pool
=
load_single_file
(
file_content
[
0
])
elif
len
(
file_content
)
>
1
:
data
=
[
np
.
load
(
f
)
for
f
in
file_content
]
prefetched_data
=
parallel_data_prefetch
(
load_multi_files
,
data
,
n_proc
=
min
(
len
(
data
),
cpu_count
()),
target_data_type
=
'dict'
)
data_pool
=
{
key
:
np
.
concatenate
([
od
[
key
]
for
od
in
prefetched_data
],
axis
=
1
)[
0
]
for
key
in
prefetched_data
[
0
].
keys
()}
else
:
raise
ValueError
(
f
'No npz-files in specified path "
{
dpath
}
" is this directory existing?'
)
print
(
f
'Finished loading of retrieval database of length
{
data_pool
[
"embedding"
].
shape
[
0
]
}
.'
)
return
data_pool
def
train_searcher
(
opt
,
metric
=
'dot_product'
,
partioning_trainsize
=
None
,
reorder_k
=
None
,
# todo tune
aiq_thld
=
0.2
,
dims_per_block
=
2
,
num_leaves
=
None
,
num_leaves_to_search
=
None
,):
data_pool
=
load_datapool
(
opt
.
database
)
k
=
opt
.
knn
if
not
reorder_k
:
reorder_k
=
2
*
k
# normalize
# embeddings =
searcher
=
scann
.
scann_ops_pybind
.
builder
(
data_pool
[
'embedding'
]
/
np
.
linalg
.
norm
(
data_pool
[
'embedding'
],
axis
=
1
)[:,
np
.
newaxis
],
k
,
metric
)
pool_size
=
data_pool
[
'embedding'
].
shape
[
0
]
print
(
*
([
'#'
]
*
100
))
print
(
'Initializing scaNN searcher with the following values:'
)
print
(
f
'k:
{
k
}
'
)
print
(
f
'metric:
{
metric
}
'
)
print
(
f
'reorder_k:
{
reorder_k
}
'
)
print
(
f
'anisotropic_quantization_threshold:
{
aiq_thld
}
'
)
print
(
f
'dims_per_block:
{
dims_per_block
}
'
)
print
(
*
([
'#'
]
*
100
))
print
(
'Start training searcher....'
)
print
(
f
'N samples in pool is
{
pool_size
}
'
)
# this reflects the recommended design choices proposed at
# https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
if
pool_size
<
2e4
:
print
(
'Using brute force search.'
)
searcher
=
search_bruteforce
(
searcher
)
elif
2e4
<=
pool_size
and
pool_size
<
1e5
:
print
(
'Using asymmetric hashing search and reordering.'
)
searcher
=
search_ah
(
searcher
,
dims_per_block
,
aiq_thld
,
reorder_k
)
else
:
print
(
'Using using partioning, asymmetric hashing search and reordering.'
)
if
not
partioning_trainsize
:
partioning_trainsize
=
data_pool
[
'embedding'
].
shape
[
0
]
//
10
if
not
num_leaves
:
num_leaves
=
int
(
np
.
sqrt
(
pool_size
))
if
not
num_leaves_to_search
:
num_leaves_to_search
=
max
(
num_leaves
//
20
,
1
)
print
(
'Partitioning params:'
)
print
(
f
'num_leaves:
{
num_leaves
}
'
)
print
(
f
'num_leaves_to_search:
{
num_leaves_to_search
}
'
)
# self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
searcher
=
search_partioned_ah
(
searcher
,
dims_per_block
,
aiq_thld
,
reorder_k
,
partioning_trainsize
,
num_leaves
,
num_leaves_to_search
)
print
(
'Finish training searcher'
)
searcher_savedir
=
opt
.
target_path
os
.
makedirs
(
searcher_savedir
,
exist_ok
=
True
)
searcher
.
serialize
(
searcher_savedir
)
print
(
f
'Saved trained searcher under "
{
searcher_savedir
}
"'
)
if
__name__
==
'__main__'
:
sys
.
path
.
append
(
os
.
getcwd
())
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--database'
,
'-d'
,
default
=
'data/rdm/retrieval_databases/openimages'
,
type
=
str
,
help
=
'path to folder containing the clip feature of the database'
)
parser
.
add_argument
(
'--target_path'
,
'-t'
,
default
=
'data/rdm/searchers/openimages'
,
type
=
str
,
help
=
'path to the target folder where the searcher shall be stored.'
)
parser
.
add_argument
(
'--knn'
,
'-k'
,
default
=
20
,
type
=
int
,
help
=
'number of nearest neighbors, for which the searcher shall be optimized'
)
opt
,
_
=
parser
.
parse_known_args
()
train_searcher
(
opt
,)
\ No newline at end of file
examples/images/diffusion/scripts/txt2img.py
0 → 100644
View file @
6e9730d7
import
argparse
,
os
,
sys
,
glob
import
cv2
import
torch
import
numpy
as
np
from
omegaconf
import
OmegaConf
from
PIL
import
Image
from
tqdm
import
tqdm
,
trange
from
imwatermark
import
WatermarkEncoder
from
itertools
import
islice
from
einops
import
rearrange
from
torchvision.utils
import
make_grid
import
time
from
pytorch_lightning
import
seed_everything
from
torch
import
autocast
from
contextlib
import
contextmanager
,
nullcontext
from
ldm.util
import
instantiate_from_config
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.plms
import
PLMSSampler
from
diffusers.pipelines.stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
transformers
import
AutoFeatureExtractor
# load safety model
safety_model_id
=
"CompVis/stable-diffusion-safety-checker"
safety_feature_extractor
=
AutoFeatureExtractor
.
from_pretrained
(
safety_model_id
)
safety_checker
=
StableDiffusionSafetyChecker
.
from_pretrained
(
safety_model_id
)
def
chunk
(
it
,
size
):
it
=
iter
(
it
)
return
iter
(
lambda
:
tuple
(
islice
(
it
,
size
)),
())
def
numpy_to_pil
(
images
):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if
images
.
ndim
==
3
:
images
=
images
[
None
,
...]
images
=
(
images
*
255
).
round
().
astype
(
"uint8"
)
pil_images
=
[
Image
.
fromarray
(
image
)
for
image
in
images
]
return
pil_images
def
load_model_from_config
(
config
,
ckpt
,
verbose
=
False
):
print
(
f
"Loading model from
{
ckpt
}
"
)
pl_sd
=
torch
.
load
(
ckpt
,
map_location
=
"cpu"
)
if
"global_step"
in
pl_sd
:
print
(
f
"Global Step:
{
pl_sd
[
'global_step'
]
}
"
)
sd
=
pl_sd
[
"state_dict"
]
model
=
instantiate_from_config
(
config
.
model
)
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
if
len
(
m
)
>
0
and
verbose
:
print
(
"missing keys:"
)
print
(
m
)
if
len
(
u
)
>
0
and
verbose
:
print
(
"unexpected keys:"
)
print
(
u
)
model
.
cuda
()
model
.
eval
()
return
model
def
put_watermark
(
img
,
wm_encoder
=
None
):
if
wm_encoder
is
not
None
:
img
=
cv2
.
cvtColor
(
np
.
array
(
img
),
cv2
.
COLOR_RGB2BGR
)
img
=
wm_encoder
.
encode
(
img
,
'dwtDct'
)
img
=
Image
.
fromarray
(
img
[:,
:,
::
-
1
])
return
img
def
load_replacement
(
x
):
try
:
hwc
=
x
.
shape
y
=
Image
.
open
(
"assets/rick.jpeg"
).
convert
(
"RGB"
).
resize
((
hwc
[
1
],
hwc
[
0
]))
y
=
(
np
.
array
(
y
)
/
255.0
).
astype
(
x
.
dtype
)
assert
y
.
shape
==
x
.
shape
return
y
except
Exception
:
return
x
def
check_safety
(
x_image
):
safety_checker_input
=
safety_feature_extractor
(
numpy_to_pil
(
x_image
),
return_tensors
=
"pt"
)
x_checked_image
,
has_nsfw_concept
=
safety_checker
(
images
=
x_image
,
clip_input
=
safety_checker_input
.
pixel_values
)
assert
x_checked_image
.
shape
[
0
]
==
len
(
has_nsfw_concept
)
for
i
in
range
(
len
(
has_nsfw_concept
)):
if
has_nsfw_concept
[
i
]:
x_checked_image
[
i
]
=
load_replacement
(
x_checked_image
[
i
])
return
x_checked_image
,
has_nsfw_concept
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--prompt"
,
type
=
str
,
nargs
=
"?"
,
default
=
"a painting of a virus monster playing guitar"
,
help
=
"the prompt to render"
)
parser
.
add_argument
(
"--outdir"
,
type
=
str
,
nargs
=
"?"
,
help
=
"dir to write results to"
,
default
=
"outputs/txt2img-samples"
)
parser
.
add_argument
(
"--skip_grid"
,
action
=
'store_true'
,
help
=
"do not save a grid, only individual samples. Helpful when evaluating lots of samples"
,
)
parser
.
add_argument
(
"--skip_save"
,
action
=
'store_true'
,
help
=
"do not save individual samples. For speed measurements."
,
)
parser
.
add_argument
(
"--ddim_steps"
,
type
=
int
,
default
=
50
,
help
=
"number of ddim sampling steps"
,
)
parser
.
add_argument
(
"--plms"
,
action
=
'store_true'
,
help
=
"use plms sampling"
,
)
parser
.
add_argument
(
"--laion400m"
,
action
=
'store_true'
,
help
=
"uses the LAION400M model"
,
)
parser
.
add_argument
(
"--fixed_code"
,
action
=
'store_true'
,
help
=
"if enabled, uses the same starting code across samples "
,
)
parser
.
add_argument
(
"--ddim_eta"
,
type
=
float
,
default
=
0.0
,
help
=
"ddim eta (eta=0.0 corresponds to deterministic sampling"
,
)
parser
.
add_argument
(
"--n_iter"
,
type
=
int
,
default
=
2
,
help
=
"sample this often"
,
)
parser
.
add_argument
(
"--H"
,
type
=
int
,
default
=
512
,
help
=
"image height, in pixel space"
,
)
parser
.
add_argument
(
"--W"
,
type
=
int
,
default
=
512
,
help
=
"image width, in pixel space"
,
)
parser
.
add_argument
(
"--C"
,
type
=
int
,
default
=
4
,
help
=
"latent channels"
,
)
parser
.
add_argument
(
"--f"
,
type
=
int
,
default
=
8
,
help
=
"downsampling factor"
,
)
parser
.
add_argument
(
"--n_samples"
,
type
=
int
,
default
=
3
,
help
=
"how many samples to produce for each given prompt. A.k.a. batch size"
,
)
parser
.
add_argument
(
"--n_rows"
,
type
=
int
,
default
=
0
,
help
=
"rows in the grid (default: n_samples)"
,
)
parser
.
add_argument
(
"--scale"
,
type
=
float
,
default
=
7.5
,
help
=
"unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))"
,
)
parser
.
add_argument
(
"--from-file"
,
type
=
str
,
help
=
"if specified, load prompts from this file"
,
)
parser
.
add_argument
(
"--config"
,
type
=
str
,
default
=
"configs/stable-diffusion/v1-inference.yaml"
,
help
=
"path to config which constructs model"
,
)
parser
.
add_argument
(
"--ckpt"
,
type
=
str
,
default
=
"models/ldm/stable-diffusion-v1/model.ckpt"
,
help
=
"path to checkpoint of model"
,
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
42
,
help
=
"the seed (for reproducible sampling)"
,
)
parser
.
add_argument
(
"--precision"
,
type
=
str
,
help
=
"evaluate at this precision"
,
choices
=
[
"full"
,
"autocast"
],
default
=
"autocast"
)
opt
=
parser
.
parse_args
()
if
opt
.
laion400m
:
print
(
"Falling back to LAION 400M model..."
)
opt
.
config
=
"configs/latent-diffusion/txt2img-1p4B-eval.yaml"
opt
.
ckpt
=
"models/ldm/text2img-large/model.ckpt"
opt
.
outdir
=
"outputs/txt2img-samples-laion400m"
seed_everything
(
opt
.
seed
)
config
=
OmegaConf
.
load
(
f
"
{
opt
.
config
}
"
)
model
=
load_model_from_config
(
config
,
f
"
{
opt
.
ckpt
}
"
)
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
model
=
model
.
to
(
device
)
if
opt
.
plms
:
sampler
=
PLMSSampler
(
model
)
else
:
sampler
=
DDIMSampler
(
model
)
os
.
makedirs
(
opt
.
outdir
,
exist_ok
=
True
)
outpath
=
opt
.
outdir
print
(
"Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)..."
)
wm
=
"StableDiffusionV1"
wm_encoder
=
WatermarkEncoder
()
wm_encoder
.
set_watermark
(
'bytes'
,
wm
.
encode
(
'utf-8'
))
batch_size
=
opt
.
n_samples
n_rows
=
opt
.
n_rows
if
opt
.
n_rows
>
0
else
batch_size
if
not
opt
.
from_file
:
prompt
=
opt
.
prompt
assert
prompt
is
not
None
data
=
[
batch_size
*
[
prompt
]]
else
:
print
(
f
"reading prompts from
{
opt
.
from_file
}
"
)
with
open
(
opt
.
from_file
,
"r"
)
as
f
:
data
=
f
.
read
().
splitlines
()
data
=
list
(
chunk
(
data
,
batch_size
))
sample_path
=
os
.
path
.
join
(
outpath
,
"samples"
)
os
.
makedirs
(
sample_path
,
exist_ok
=
True
)
base_count
=
len
(
os
.
listdir
(
sample_path
))
grid_count
=
len
(
os
.
listdir
(
outpath
))
-
1
start_code
=
None
if
opt
.
fixed_code
:
start_code
=
torch
.
randn
([
opt
.
n_samples
,
opt
.
C
,
opt
.
H
//
opt
.
f
,
opt
.
W
//
opt
.
f
],
device
=
device
)
precision_scope
=
autocast
if
opt
.
precision
==
"autocast"
else
nullcontext
with
torch
.
no_grad
():
with
precision_scope
(
"cuda"
):
with
model
.
ema_scope
():
tic
=
time
.
time
()
all_samples
=
list
()
for
n
in
trange
(
opt
.
n_iter
,
desc
=
"Sampling"
):
for
prompts
in
tqdm
(
data
,
desc
=
"data"
):
uc
=
None
if
opt
.
scale
!=
1.0
:
uc
=
model
.
get_learned_conditioning
(
batch_size
*
[
""
])
if
isinstance
(
prompts
,
tuple
):
prompts
=
list
(
prompts
)
c
=
model
.
get_learned_conditioning
(
prompts
)
shape
=
[
opt
.
C
,
opt
.
H
//
opt
.
f
,
opt
.
W
//
opt
.
f
]
samples_ddim
,
_
=
sampler
.
sample
(
S
=
opt
.
ddim_steps
,
conditioning
=
c
,
batch_size
=
opt
.
n_samples
,
shape
=
shape
,
verbose
=
False
,
unconditional_guidance_scale
=
opt
.
scale
,
unconditional_conditioning
=
uc
,
eta
=
opt
.
ddim_eta
,
x_T
=
start_code
)
x_samples_ddim
=
model
.
decode_first_stage
(
samples_ddim
)
x_samples_ddim
=
torch
.
clamp
((
x_samples_ddim
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
x_samples_ddim
=
x_samples_ddim
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
x_checked_image
,
has_nsfw_concept
=
check_safety
(
x_samples_ddim
)
x_checked_image_torch
=
torch
.
from_numpy
(
x_checked_image
).
permute
(
0
,
3
,
1
,
2
)
if
not
opt
.
skip_save
:
for
x_sample
in
x_checked_image_torch
:
x_sample
=
255.
*
rearrange
(
x_sample
.
cpu
().
numpy
(),
'c h w -> h w c'
)
img
=
Image
.
fromarray
(
x_sample
.
astype
(
np
.
uint8
))
img
=
put_watermark
(
img
,
wm_encoder
)
img
.
save
(
os
.
path
.
join
(
sample_path
,
f
"
{
base_count
:
05
}
.png"
))
base_count
+=
1
if
not
opt
.
skip_grid
:
all_samples
.
append
(
x_checked_image_torch
)
if
not
opt
.
skip_grid
:
# additionally, save as grid
grid
=
torch
.
stack
(
all_samples
,
0
)
grid
=
rearrange
(
grid
,
'n b c h w -> (n b) c h w'
)
grid
=
make_grid
(
grid
,
nrow
=
n_rows
)
# to image
grid
=
255.
*
rearrange
(
grid
,
'c h w -> h w c'
).
cpu
().
numpy
()
img
=
Image
.
fromarray
(
grid
.
astype
(
np
.
uint8
))
img
=
put_watermark
(
img
,
wm_encoder
)
img
.
save
(
os
.
path
.
join
(
outpath
,
f
'grid-
{
grid_count
:
04
}
.png'
))
grid_count
+=
1
toc
=
time
.
time
()
print
(
f
"Your samples are ready and waiting for you here:
\n
{
outpath
}
\n
"
f
"
\n
Enjoy."
)
if
__name__
==
"__main__"
:
main
()
examples/images/diffusion/setup.py
0 → 100644
View file @
6e9730d7
from
setuptools
import
setup
,
find_packages
setup
(
name
=
'latent-diffusion'
,
version
=
'0.0.1'
,
description
=
''
,
packages
=
find_packages
(),
install_requires
=
[
'torch'
,
'numpy'
,
'tqdm'
,
],
)
\ No newline at end of file
examples/images/diffusion/train.sh
0 → 100755
View file @
6e9730d7
HF_DATASETS_OFFLINE
=
1
TRANSFORMERS_OFFLINE
=
1
python main.py
--logdir
/tmp
-t
--postfix
test
-b
configs/train_colossalai.yaml
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment