OpenDAS / ColossalAI · Commits

Commit 6c4c6a04 (unverified)
Authored Dec 13, 2022 by Fazzie-Maqianli; committed by GitHub on Dec 13, 2022
Merge pull request #2120 from Fazziekey/example/stablediffusion-v2

[example] support stable diffusion v2

Parents: 5efda697, cea4292a
Showing 11 changed files with 1421 additions and 1097 deletions:
examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py  +128 -0
examples/images/diffusion/ldm/modules/midas/midas/transforms.py  +234 -0
examples/images/diffusion/ldm/modules/midas/midas/vit.py  +491 -0
examples/images/diffusion/ldm/modules/midas/utils.py  +189 -0
examples/images/diffusion/ldm/modules/x_transformer.py  +0 -641
examples/images/diffusion/ldm/util.py  +111 -117
examples/images/diffusion/main.py  +121 -125
examples/images/diffusion/requirements.txt  +8 -13
examples/images/diffusion/scripts/img2img.py  +43 -54
examples/images/diffusion/scripts/txt2img.py  +92 -144
examples/images/diffusion/train.sh  +4 -3
examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py (new file, mode 0 → 100644)
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import
torch
import
torch.nn
as
nn
from
.base_model
import
BaseModel
from
.blocks
import
FeatureFusionBlock
,
FeatureFusionBlock_custom
,
Interpolate
,
_make_encoder
class
MidasNet_small
(
BaseModel
):
"""Network for monocular depth estimation.
"""
def
__init__
(
self
,
path
=
None
,
features
=
64
,
backbone
=
"efficientnet_lite3"
,
non_negative
=
True
,
exportable
=
True
,
channels_last
=
False
,
align_corners
=
True
,
blocks
=
{
'expand'
:
True
}):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print
(
"Loading weights: "
,
path
)
super
(
MidasNet_small
,
self
).
__init__
()
use_pretrained
=
False
if
path
else
True
self
.
channels_last
=
channels_last
self
.
blocks
=
blocks
self
.
backbone
=
backbone
self
.
groups
=
1
features1
=
features
features2
=
features
features3
=
features
features4
=
features
self
.
expand
=
False
if
"expand"
in
self
.
blocks
and
self
.
blocks
[
'expand'
]
==
True
:
self
.
expand
=
True
features1
=
features
features2
=
features
*
2
features3
=
features
*
4
features4
=
features
*
8
self
.
pretrained
,
self
.
scratch
=
_make_encoder
(
self
.
backbone
,
features
,
use_pretrained
,
groups
=
self
.
groups
,
expand
=
self
.
expand
,
exportable
=
exportable
)
self
.
scratch
.
activation
=
nn
.
ReLU
(
False
)
self
.
scratch
.
refinenet4
=
FeatureFusionBlock_custom
(
features4
,
self
.
scratch
.
activation
,
deconv
=
False
,
bn
=
False
,
expand
=
self
.
expand
,
align_corners
=
align_corners
)
self
.
scratch
.
refinenet3
=
FeatureFusionBlock_custom
(
features3
,
self
.
scratch
.
activation
,
deconv
=
False
,
bn
=
False
,
expand
=
self
.
expand
,
align_corners
=
align_corners
)
self
.
scratch
.
refinenet2
=
FeatureFusionBlock_custom
(
features2
,
self
.
scratch
.
activation
,
deconv
=
False
,
bn
=
False
,
expand
=
self
.
expand
,
align_corners
=
align_corners
)
self
.
scratch
.
refinenet1
=
FeatureFusionBlock_custom
(
features1
,
self
.
scratch
.
activation
,
deconv
=
False
,
bn
=
False
,
align_corners
=
align_corners
)
self
.
scratch
.
output_conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
features
,
features
//
2
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
groups
=
self
.
groups
),
Interpolate
(
scale_factor
=
2
,
mode
=
"bilinear"
),
nn
.
Conv2d
(
features
//
2
,
32
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
self
.
scratch
.
activation
,
nn
.
Conv2d
(
32
,
1
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
),
nn
.
ReLU
(
True
)
if
non_negative
else
nn
.
Identity
(),
nn
.
Identity
(),
)
if
path
:
self
.
load
(
path
)
def
forward
(
self
,
x
):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
if
self
.
channels_last
==
True
:
print
(
"self.channels_last = "
,
self
.
channels_last
)
x
.
contiguous
(
memory_format
=
torch
.
channels_last
)
layer_1
=
self
.
pretrained
.
layer1
(
x
)
layer_2
=
self
.
pretrained
.
layer2
(
layer_1
)
layer_3
=
self
.
pretrained
.
layer3
(
layer_2
)
layer_4
=
self
.
pretrained
.
layer4
(
layer_3
)
layer_1_rn
=
self
.
scratch
.
layer1_rn
(
layer_1
)
layer_2_rn
=
self
.
scratch
.
layer2_rn
(
layer_2
)
layer_3_rn
=
self
.
scratch
.
layer3_rn
(
layer_3
)
layer_4_rn
=
self
.
scratch
.
layer4_rn
(
layer_4
)
path_4
=
self
.
scratch
.
refinenet4
(
layer_4_rn
)
path_3
=
self
.
scratch
.
refinenet3
(
path_4
,
layer_3_rn
)
path_2
=
self
.
scratch
.
refinenet2
(
path_3
,
layer_2_rn
)
path_1
=
self
.
scratch
.
refinenet1
(
path_2
,
layer_1_rn
)
out
=
self
.
scratch
.
output_conv
(
path_1
)
return
torch
.
squeeze
(
out
,
dim
=
1
)
def
fuse_model
(
m
):
prev_previous_type
=
nn
.
Identity
()
prev_previous_name
=
''
previous_type
=
nn
.
Identity
()
previous_name
=
''
for
name
,
module
in
m
.
named_modules
():
if
prev_previous_type
==
nn
.
Conv2d
and
previous_type
==
nn
.
BatchNorm2d
and
type
(
module
)
==
nn
.
ReLU
:
# print("FUSED ", prev_previous_name, previous_name, name)
torch
.
quantization
.
fuse_modules
(
m
,
[
prev_previous_name
,
previous_name
,
name
],
inplace
=
True
)
elif
prev_previous_type
==
nn
.
Conv2d
and
previous_type
==
nn
.
BatchNorm2d
:
# print("FUSED ", prev_previous_name, previous_name)
torch
.
quantization
.
fuse_modules
(
m
,
[
prev_previous_name
,
previous_name
],
inplace
=
True
)
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
# print("FUSED ", previous_name, name)
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
prev_previous_type
=
previous_type
prev_previous_name
=
previous_name
previous_type
=
type
(
module
)
previous_name
=
name
\ No newline at end of file
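For orientation, a minimal usage sketch of the class above, not part of the commit. It assumes the sibling MiDaS modules (base_model.py, blocks.py) added alongside this file are importable under ldm.modules.midas.midas, and that timm can fetch the efficientnet_lite3 encoder weights (path=None means the pretrained encoder is used).

    # Hypothetical smoke test: run one random image through MidasNet_small.
    import torch
    from ldm.modules.midas.midas.midas_net_custom import MidasNet_small

    model = MidasNet_small(path=None, features=64, backbone="efficientnet_lite3")
    model.eval()
    with torch.no_grad():
        depth = model(torch.randn(1, 3, 384, 384))    # forward() squeezes the channel dim
    print(depth.shape)    # expected: torch.Size([1, 384, 384])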
examples/images/diffusion/ldm/modules/midas/midas/transforms.py (new file, mode 0 → 100644)
import numpy as np
import cv2
import math


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Rezise the sample to ensure the given size. Keeps aspect ratio.
    Args:
        sample (dict): sample
        size (tuple): image size
    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)

    sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height).
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.
        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as least as possbile
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(f"resize_method {self.__resize_method} not implemented")

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)

            sample["mask"] = cv2.resize(
                sample["mask"].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample["mask"] = sample["mask"].astype(bool)

        return sample


class NormalizeImage(object):
    """Normlize image by given mean and std.
    """

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "disparity" in sample:
            disparity = sample["disparity"].astype(np.float32)
            sample["disparity"] = np.ascontiguousarray(disparity)

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        return sample
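These transforms operate on dict samples rather than bare arrays. Below is a minimal sketch of the usual MiDaS-style preprocessing pipeline built from them, assuming torchvision's Compose is available; the mean/std values are the standard ImageNet statistics, an assumption rather than anything this file fixes.

    # Hypothetical preprocessing pipeline composed from the transforms above.
    import numpy as np
    from torchvision.transforms import Compose
    from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

    transform = Compose([
        Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
               ensure_multiple_of=32, resize_method="upper_bound"),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ])

    img = np.random.rand(480, 640, 3).astype(np.float32)    # HWC image in [0, 1]
    sample = transform({"image": img})
    print(sample["image"].shape)    # CHW float32 array, both sides multiples of 32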
examples/images/diffusion/ldm/modules/midas/midas/vit.py (new file, mode 0 → 100644)
import torch
import torch.nn as nn
import timm
import types
import math
import torch.nn.functional as F


class Slice(nn.Module):

    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index:]


class AddReadout(nn.Module):

    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index:] + readout.unsqueeze(1)


class ProjectReadout(nn.Module):

    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
        features = torch.cat((x[:, self.start_index:], readout), -1)

        return self.project(features)


class Transpose(nn.Module):

    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x


def forward_vit(pretrained, x):
    b, c, h, w = x.shape

    glob = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size([
                h // pretrained.model.patch_size[1],
                w // pretrained.model.patch_size[0],
            ]),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4


def _resize_pos_embed(self, posemb, gs_h, gs_w):
    posemb_tok, posemb_grid = (
        posemb[:, :self.start_index],
        posemb[0, self.start_index:],
    )

    gs_old = int(math.sqrt(len(posemb_grid)))

    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb


def forward_flex(self, x):
    b, c, h, w = x.shape

    pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], w // self.patch_size[0])

    B = x.shape[0]

    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]    # last feature if backbone outputs list/tuple of features

    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    if getattr(self, "dist_token", None) is not None:
        cls_tokens = self.cls_token.expand(B, -1, -1)    # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(B, -1, -1)    # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x


activations = {}


def get_activation(name):

    def hook(model, input, output):
        activations[name] = output

    return hook


def get_readout_oper(vit_features, features, use_readout, start_index=1):
    if use_readout == "ignore":
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == "add":
        readout_oper = [AddReadout(start_index)] * len(features)
    elif use_readout == "project":
        readout_oper = [ProjectReadout(vit_features, start_index) for out_feat in features]
    else:
        assert (False), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"

    return readout_oper


def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model)

    return pretrained


def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )


def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout)


def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout)


def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_deit_base_distilled_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        start_index=2,
    )


def _make_vit_b_rn50_backbone(
    model,
    features=[256, 512, 768, 768],
    size=[384, 384],
    hooks=[0, 1, 8, 11],
    vit_features=768,
    use_vit_only=False,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model

    if use_vit_only == True:
        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    else:
        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(get_activation("1"))
        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(get_activation("2"))

    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    if use_vit_only == True:
        pretrained.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        pretrained.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )
    else:
        pretrained.act_postprocess1 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())
        pretrained.act_postprocess2 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model)

    return pretrained


def _make_pretrained_vitb_rn50_384(pretrained, use_readout="ignore", hooks=None, use_vit_only=False):
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    hooks = [0, 1, 8, 11] if hooks == None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )
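A sketch of how these builders are meant to be driven, assuming the timm version in use provides vit_large_patch16_384 with the internals forward_flex expects (pass pretrained=False to skip the weight download); input sides must be multiples of the 16-pixel patch.

    # Hypothetical example: hook a ViT-L/16 and pull the four multi-scale feature maps.
    import torch
    from ldm.modules.midas.midas.vit import _make_pretrained_vitl16_384, forward_vit

    backbone = _make_pretrained_vitl16_384(pretrained=False, use_readout="project")
    x = torch.randn(1, 3, 384, 384)
    with torch.no_grad():
        l1, l2, l3, l4 = forward_vit(backbone, x)
    for f in (l1, l2, l3, l4):
        print(f.shape)    # strides 4, 8, 16, 32; channels 256, 512, 1024, 1024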
examples/images/diffusion/ldm/modules/midas/utils.py (new file, mode 0 → 100644)
"""Utils for monoDepth."""
import
sys
import
re
import
numpy
as
np
import
cv2
import
torch
def
read_pfm
(
path
):
"""Read pfm file.
Args:
path (str): path to file
Returns:
tuple: (data, scale)
"""
with
open
(
path
,
"rb"
)
as
file
:
color
=
None
width
=
None
height
=
None
scale
=
None
endian
=
None
header
=
file
.
readline
().
rstrip
()
if
header
.
decode
(
"ascii"
)
==
"PF"
:
color
=
True
elif
header
.
decode
(
"ascii"
)
==
"Pf"
:
color
=
False
else
:
raise
Exception
(
"Not a PFM file: "
+
path
)
dim_match
=
re
.
match
(
r
"^(\d+)\s(\d+)\s$"
,
file
.
readline
().
decode
(
"ascii"
))
if
dim_match
:
width
,
height
=
list
(
map
(
int
,
dim_match
.
groups
()))
else
:
raise
Exception
(
"Malformed PFM header."
)
scale
=
float
(
file
.
readline
().
decode
(
"ascii"
).
rstrip
())
if
scale
<
0
:
# little-endian
endian
=
"<"
scale
=
-
scale
else
:
# big-endian
endian
=
">"
data
=
np
.
fromfile
(
file
,
endian
+
"f"
)
shape
=
(
height
,
width
,
3
)
if
color
else
(
height
,
width
)
data
=
np
.
reshape
(
data
,
shape
)
data
=
np
.
flipud
(
data
)
return
data
,
scale
def
write_pfm
(
path
,
image
,
scale
=
1
):
"""Write pfm file.
Args:
path (str): pathto file
image (array): data
scale (int, optional): Scale. Defaults to 1.
"""
with
open
(
path
,
"wb"
)
as
file
:
color
=
None
if
image
.
dtype
.
name
!=
"float32"
:
raise
Exception
(
"Image dtype must be float32."
)
image
=
np
.
flipud
(
image
)
if
len
(
image
.
shape
)
==
3
and
image
.
shape
[
2
]
==
3
:
# color image
color
=
True
elif
(
len
(
image
.
shape
)
==
2
or
len
(
image
.
shape
)
==
3
and
image
.
shape
[
2
]
==
1
):
# greyscale
color
=
False
else
:
raise
Exception
(
"Image must have H x W x 3, H x W x 1 or H x W dimensions."
)
file
.
write
(
"PF
\n
"
if
color
else
"Pf
\n
"
.
encode
())
file
.
write
(
"%d %d
\n
"
.
encode
()
%
(
image
.
shape
[
1
],
image
.
shape
[
0
]))
endian
=
image
.
dtype
.
byteorder
if
endian
==
"<"
or
endian
==
"="
and
sys
.
byteorder
==
"little"
:
scale
=
-
scale
file
.
write
(
"%f
\n
"
.
encode
()
%
scale
)
image
.
tofile
(
file
)
def
read_image
(
path
):
"""Read image and output RGB image (0-1).
Args:
path (str): path to file
Returns:
array: RGB image (0-1)
"""
img
=
cv2
.
imread
(
path
)
if
img
.
ndim
==
2
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
/
255.0
return
img
def
resize_image
(
img
):
"""Resize image and make it fit for network.
Args:
img (array): image
Returns:
tensor: data ready for network
"""
height_orig
=
img
.
shape
[
0
]
width_orig
=
img
.
shape
[
1
]
if
width_orig
>
height_orig
:
scale
=
width_orig
/
384
else
:
scale
=
height_orig
/
384
height
=
(
np
.
ceil
(
height_orig
/
scale
/
32
)
*
32
).
astype
(
int
)
width
=
(
np
.
ceil
(
width_orig
/
scale
/
32
)
*
32
).
astype
(
int
)
img_resized
=
cv2
.
resize
(
img
,
(
width
,
height
),
interpolation
=
cv2
.
INTER_AREA
)
img_resized
=
(
torch
.
from_numpy
(
np
.
transpose
(
img_resized
,
(
2
,
0
,
1
))).
contiguous
().
float
()
)
img_resized
=
img_resized
.
unsqueeze
(
0
)
return
img_resized
def
resize_depth
(
depth
,
width
,
height
):
"""Resize depth map and bring to CPU (numpy).
Args:
depth (tensor): depth
width (int): image width
height (int): image height
Returns:
array: processed depth
"""
depth
=
torch
.
squeeze
(
depth
[
0
,
:,
:,
:]).
to
(
"cpu"
)
depth_resized
=
cv2
.
resize
(
depth
.
numpy
(),
(
width
,
height
),
interpolation
=
cv2
.
INTER_CUBIC
)
return
depth_resized
def
write_depth
(
path
,
depth
,
bits
=
1
):
"""Write depth map to pfm and png file.
Args:
path (str): filepath without extension
depth (array): depth
"""
write_pfm
(
path
+
".pfm"
,
depth
.
astype
(
np
.
float32
))
depth_min
=
depth
.
min
()
depth_max
=
depth
.
max
()
max_val
=
(
2
**
(
8
*
bits
))
-
1
if
depth_max
-
depth_min
>
np
.
finfo
(
"float"
).
eps
:
out
=
max_val
*
(
depth
-
depth_min
)
/
(
depth_max
-
depth_min
)
else
:
out
=
np
.
zeros
(
depth
.
shape
,
dtype
=
depth
.
type
)
if
bits
==
1
:
cv2
.
imwrite
(
path
+
".png"
,
out
.
astype
(
"uint8"
))
elif
bits
==
2
:
cv2
.
imwrite
(
path
+
".png"
,
out
.
astype
(
"uint16"
))
return
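A round-trip sketch for these I/O helpers, assuming a writable working directory; note the writers are only exercised here on single-channel float32 data, the case these utilities are used for.

    # Hypothetical round trip: write a synthetic depth map, read the PFM back.
    import numpy as np
    from ldm.modules.midas.utils import write_depth, read_pfm

    depth = np.random.rand(240, 320).astype(np.float32)
    write_depth("depth_out", depth, bits=2)    # writes depth_out.pfm and a 16-bit depth_out.png
    data, scale = read_pfm("depth_out.pfm")
    print(data.shape, scale)                   # (240, 320) and the stored scale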
examples/images/diffusion/ldm/modules/x_transformer.py (deleted, mode 100644 → 0; content as of parent 5efda697)
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
import
torch
from
torch
import
nn
,
einsum
import
torch.nn.functional
as
F
from
functools
import
partial
from
inspect
import
isfunction
from
collections
import
namedtuple
from
einops
import
rearrange
,
repeat
,
reduce
# constants
DEFAULT_DIM_HEAD
=
64
Intermediates
=
namedtuple
(
'Intermediates'
,
[
'pre_softmax_attn'
,
'post_softmax_attn'
])
LayerIntermediates
=
namedtuple
(
'Intermediates'
,
[
'hiddens'
,
'attn_intermediates'
])
class
AbsolutePositionalEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
max_seq_len
):
super
().
__init__
()
self
.
emb
=
nn
.
Embedding
(
max_seq_len
,
dim
)
self
.
init_
()
def
init_
(
self
):
nn
.
init
.
normal_
(
self
.
emb
.
weight
,
std
=
0.02
)
def
forward
(
self
,
x
):
n
=
torch
.
arange
(
x
.
shape
[
1
],
device
=
x
.
device
)
return
self
.
emb
(
n
)[
None
,
:,
:]
class
FixedPositionalEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
inv_freq
=
1.
/
(
10000
**
(
torch
.
arange
(
0
,
dim
,
2
).
float
()
/
dim
))
self
.
register_buffer
(
'inv_freq'
,
inv_freq
)
def
forward
(
self
,
x
,
seq_dim
=
1
,
offset
=
0
):
t
=
torch
.
arange
(
x
.
shape
[
seq_dim
],
device
=
x
.
device
).
type_as
(
self
.
inv_freq
)
+
offset
sinusoid_inp
=
torch
.
einsum
(
'i , j -> i j'
,
t
,
self
.
inv_freq
)
emb
=
torch
.
cat
((
sinusoid_inp
.
sin
(),
sinusoid_inp
.
cos
()),
dim
=-
1
)
return
emb
[
None
,
:,
:]
# helpers
def
exists
(
val
):
return
val
is
not
None
def
default
(
val
,
d
):
if
exists
(
val
):
return
val
return
d
()
if
isfunction
(
d
)
else
d
def
always
(
val
):
def
inner
(
*
args
,
**
kwargs
):
return
val
return
inner
def
not_equals
(
val
):
def
inner
(
x
):
return
x
!=
val
return
inner
def
equals
(
val
):
def
inner
(
x
):
return
x
==
val
return
inner
def
max_neg_value
(
tensor
):
return
-
torch
.
finfo
(
tensor
.
dtype
).
max
# keyword argument helpers
def
pick_and_pop
(
keys
,
d
):
values
=
list
(
map
(
lambda
key
:
d
.
pop
(
key
),
keys
))
return
dict
(
zip
(
keys
,
values
))
def
group_dict_by_key
(
cond
,
d
):
return_val
=
[
dict
(),
dict
()]
for
key
in
d
.
keys
():
match
=
bool
(
cond
(
key
))
ind
=
int
(
not
match
)
return_val
[
ind
][
key
]
=
d
[
key
]
return
(
*
return_val
,)
def
string_begins_with
(
prefix
,
str
):
return
str
.
startswith
(
prefix
)
def
group_by_key_prefix
(
prefix
,
d
):
return
group_dict_by_key
(
partial
(
string_begins_with
,
prefix
),
d
)
def
groupby_prefix_and_trim
(
prefix
,
d
):
kwargs_with_prefix
,
kwargs
=
group_dict_by_key
(
partial
(
string_begins_with
,
prefix
),
d
)
kwargs_without_prefix
=
dict
(
map
(
lambda
x
:
(
x
[
0
][
len
(
prefix
):],
x
[
1
]),
tuple
(
kwargs_with_prefix
.
items
())))
return
kwargs_without_prefix
,
kwargs
# classes
class
Scale
(
nn
.
Module
):
def
__init__
(
self
,
value
,
fn
):
super
().
__init__
()
self
.
value
=
value
self
.
fn
=
fn
def
forward
(
self
,
x
,
**
kwargs
):
x
,
*
rest
=
self
.
fn
(
x
,
**
kwargs
)
return
(
x
*
self
.
value
,
*
rest
)
class
Rezero
(
nn
.
Module
):
def
__init__
(
self
,
fn
):
super
().
__init__
()
self
.
fn
=
fn
self
.
g
=
nn
.
Parameter
(
torch
.
zeros
(
1
))
def
forward
(
self
,
x
,
**
kwargs
):
x
,
*
rest
=
self
.
fn
(
x
,
**
kwargs
)
return
(
x
*
self
.
g
,
*
rest
)
class
ScaleNorm
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
eps
=
1e-5
):
super
().
__init__
()
self
.
scale
=
dim
**
-
0.5
self
.
eps
=
eps
self
.
g
=
nn
.
Parameter
(
torch
.
ones
(
1
))
def
forward
(
self
,
x
):
norm
=
torch
.
norm
(
x
,
dim
=-
1
,
keepdim
=
True
)
*
self
.
scale
return
x
/
norm
.
clamp
(
min
=
self
.
eps
)
*
self
.
g
class
RMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
eps
=
1e-8
):
super
().
__init__
()
self
.
scale
=
dim
**
-
0.5
self
.
eps
=
eps
self
.
g
=
nn
.
Parameter
(
torch
.
ones
(
dim
))
def
forward
(
self
,
x
):
norm
=
torch
.
norm
(
x
,
dim
=-
1
,
keepdim
=
True
)
*
self
.
scale
return
x
/
norm
.
clamp
(
min
=
self
.
eps
)
*
self
.
g
class
Residual
(
nn
.
Module
):
def
forward
(
self
,
x
,
residual
):
return
x
+
residual
class
GRUGating
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
gru
=
nn
.
GRUCell
(
dim
,
dim
)
def
forward
(
self
,
x
,
residual
):
gated_output
=
self
.
gru
(
rearrange
(
x
,
'b n d -> (b n) d'
),
rearrange
(
residual
,
'b n d -> (b n) d'
)
)
return
gated_output
.
reshape_as
(
x
)
# feedforward
class
GEGLU
(
nn
.
Module
):
def
__init__
(
self
,
dim_in
,
dim_out
):
super
().
__init__
()
self
.
proj
=
nn
.
Linear
(
dim_in
,
dim_out
*
2
)
def
forward
(
self
,
x
):
x
,
gate
=
self
.
proj
(
x
).
chunk
(
2
,
dim
=-
1
)
return
x
*
F
.
gelu
(
gate
)
class
FeedForward
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_out
=
None
,
mult
=
4
,
glu
=
False
,
dropout
=
0.
):
super
().
__init__
()
inner_dim
=
int
(
dim
*
mult
)
dim_out
=
default
(
dim_out
,
dim
)
project_in
=
nn
.
Sequential
(
nn
.
Linear
(
dim
,
inner_dim
),
nn
.
GELU
()
)
if
not
glu
else
GEGLU
(
dim
,
inner_dim
)
self
.
net
=
nn
.
Sequential
(
project_in
,
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
inner_dim
,
dim_out
)
)
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
# attention.
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_head
=
DEFAULT_DIM_HEAD
,
heads
=
8
,
causal
=
False
,
mask
=
None
,
talking_heads
=
False
,
sparse_topk
=
None
,
use_entmax15
=
False
,
num_mem_kv
=
0
,
dropout
=
0.
,
on_attn
=
False
):
super
().
__init__
()
if
use_entmax15
:
raise
NotImplementedError
(
"Check out entmax activation instead of softmax activation!"
)
self
.
scale
=
dim_head
**
-
0.5
self
.
heads
=
heads
self
.
causal
=
causal
self
.
mask
=
mask
inner_dim
=
dim_head
*
heads
self
.
to_q
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
to_k
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
to_v
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
# talking heads
self
.
talking_heads
=
talking_heads
if
talking_heads
:
self
.
pre_softmax_proj
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
heads
))
self
.
post_softmax_proj
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
heads
))
# explicit topk sparse attention
self
.
sparse_topk
=
sparse_topk
# entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
self
.
attn_fn
=
F
.
softmax
# add memory key / values
self
.
num_mem_kv
=
num_mem_kv
if
num_mem_kv
>
0
:
self
.
mem_k
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
num_mem_kv
,
dim_head
))
self
.
mem_v
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
num_mem_kv
,
dim_head
))
# attention on attention
self
.
attn_on_attn
=
on_attn
self
.
to_out
=
nn
.
Sequential
(
nn
.
Linear
(
inner_dim
,
dim
*
2
),
nn
.
GLU
())
if
on_attn
else
nn
.
Linear
(
inner_dim
,
dim
)
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
,
context_mask
=
None
,
rel_pos
=
None
,
sinusoidal_emb
=
None
,
prev_attn
=
None
,
mem
=
None
):
b
,
n
,
_
,
h
,
talking_heads
,
device
=
*
x
.
shape
,
self
.
heads
,
self
.
talking_heads
,
x
.
device
kv_input
=
default
(
context
,
x
)
q_input
=
x
k_input
=
kv_input
v_input
=
kv_input
if
exists
(
mem
):
k_input
=
torch
.
cat
((
mem
,
k_input
),
dim
=-
2
)
v_input
=
torch
.
cat
((
mem
,
v_input
),
dim
=-
2
)
if
exists
(
sinusoidal_emb
):
# in shortformer, the query would start at a position offset depending on the past cached memory
offset
=
k_input
.
shape
[
-
2
]
-
q_input
.
shape
[
-
2
]
q_input
=
q_input
+
sinusoidal_emb
(
q_input
,
offset
=
offset
)
k_input
=
k_input
+
sinusoidal_emb
(
k_input
)
q
=
self
.
to_q
(
q_input
)
k
=
self
.
to_k
(
k_input
)
v
=
self
.
to_v
(
v_input
)
q
,
k
,
v
=
map
(
lambda
t
:
rearrange
(
t
,
'b n (h d) -> b h n d'
,
h
=
h
),
(
q
,
k
,
v
))
input_mask
=
None
if
any
(
map
(
exists
,
(
mask
,
context_mask
))):
q_mask
=
default
(
mask
,
lambda
:
torch
.
ones
((
b
,
n
),
device
=
device
).
bool
())
k_mask
=
q_mask
if
not
exists
(
context
)
else
context_mask
k_mask
=
default
(
k_mask
,
lambda
:
torch
.
ones
((
b
,
k
.
shape
[
-
2
]),
device
=
device
).
bool
())
q_mask
=
rearrange
(
q_mask
,
'b i -> b () i ()'
)
k_mask
=
rearrange
(
k_mask
,
'b j -> b () () j'
)
input_mask
=
q_mask
*
k_mask
if
self
.
num_mem_kv
>
0
:
mem_k
,
mem_v
=
map
(
lambda
t
:
repeat
(
t
,
'h n d -> b h n d'
,
b
=
b
),
(
self
.
mem_k
,
self
.
mem_v
))
k
=
torch
.
cat
((
mem_k
,
k
),
dim
=-
2
)
v
=
torch
.
cat
((
mem_v
,
v
),
dim
=-
2
)
if
exists
(
input_mask
):
input_mask
=
F
.
pad
(
input_mask
,
(
self
.
num_mem_kv
,
0
),
value
=
True
)
dots
=
einsum
(
'b h i d, b h j d -> b h i j'
,
q
,
k
)
*
self
.
scale
mask_value
=
max_neg_value
(
dots
)
if
exists
(
prev_attn
):
dots
=
dots
+
prev_attn
pre_softmax_attn
=
dots
if
talking_heads
:
dots
=
einsum
(
'b h i j, h k -> b k i j'
,
dots
,
self
.
pre_softmax_proj
).
contiguous
()
if
exists
(
rel_pos
):
dots
=
rel_pos
(
dots
)
if
exists
(
input_mask
):
dots
.
masked_fill_
(
~
input_mask
,
mask_value
)
del
input_mask
if
self
.
causal
:
i
,
j
=
dots
.
shape
[
-
2
:]
r
=
torch
.
arange
(
i
,
device
=
device
)
mask
=
rearrange
(
r
,
'i -> () () i ()'
)
<
rearrange
(
r
,
'j -> () () () j'
)
mask
=
F
.
pad
(
mask
,
(
j
-
i
,
0
),
value
=
False
)
dots
.
masked_fill_
(
mask
,
mask_value
)
del
mask
if
exists
(
self
.
sparse_topk
)
and
self
.
sparse_topk
<
dots
.
shape
[
-
1
]:
top
,
_
=
dots
.
topk
(
self
.
sparse_topk
,
dim
=-
1
)
vk
=
top
[...,
-
1
].
unsqueeze
(
-
1
).
expand_as
(
dots
)
mask
=
dots
<
vk
dots
.
masked_fill_
(
mask
,
mask_value
)
del
mask
attn
=
self
.
attn_fn
(
dots
,
dim
=-
1
)
post_softmax_attn
=
attn
attn
=
self
.
dropout
(
attn
)
if
talking_heads
:
attn
=
einsum
(
'b h i j, h k -> b k i j'
,
attn
,
self
.
post_softmax_proj
).
contiguous
()
out
=
einsum
(
'b h i j, b h j d -> b h i d'
,
attn
,
v
)
out
=
rearrange
(
out
,
'b h n d -> b n (h d)'
)
intermediates
=
Intermediates
(
pre_softmax_attn
=
pre_softmax_attn
,
post_softmax_attn
=
post_softmax_attn
)
return
self
.
to_out
(
out
),
intermediates
class
AttentionLayers
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
depth
,
heads
=
8
,
causal
=
False
,
cross_attend
=
False
,
only_cross
=
False
,
use_scalenorm
=
False
,
use_rmsnorm
=
False
,
use_rezero
=
False
,
rel_pos_num_buckets
=
32
,
rel_pos_max_distance
=
128
,
position_infused_attn
=
False
,
custom_layers
=
None
,
sandwich_coef
=
None
,
par_ratio
=
None
,
residual_attn
=
False
,
cross_residual_attn
=
False
,
macaron
=
False
,
pre_norm
=
True
,
gate_residual
=
False
,
**
kwargs
):
super
().
__init__
()
ff_kwargs
,
kwargs
=
groupby_prefix_and_trim
(
'ff_'
,
kwargs
)
attn_kwargs
,
_
=
groupby_prefix_and_trim
(
'attn_'
,
kwargs
)
dim_head
=
attn_kwargs
.
get
(
'dim_head'
,
DEFAULT_DIM_HEAD
)
self
.
dim
=
dim
self
.
depth
=
depth
self
.
layers
=
nn
.
ModuleList
([])
self
.
has_pos_emb
=
position_infused_attn
self
.
pia_pos_emb
=
FixedPositionalEmbedding
(
dim
)
if
position_infused_attn
else
None
self
.
rotary_pos_emb
=
always
(
None
)
assert
rel_pos_num_buckets
<=
rel_pos_max_distance
,
'number of relative position buckets must be less than the relative position max distance'
self
.
rel_pos
=
None
self
.
pre_norm
=
pre_norm
self
.
residual_attn
=
residual_attn
self
.
cross_residual_attn
=
cross_residual_attn
norm_class
=
ScaleNorm
if
use_scalenorm
else
nn
.
LayerNorm
norm_class
=
RMSNorm
if
use_rmsnorm
else
norm_class
norm_fn
=
partial
(
norm_class
,
dim
)
norm_fn
=
nn
.
Identity
if
use_rezero
else
norm_fn
branch_fn
=
Rezero
if
use_rezero
else
None
if
cross_attend
and
not
only_cross
:
default_block
=
(
'a'
,
'c'
,
'f'
)
elif
cross_attend
and
only_cross
:
default_block
=
(
'c'
,
'f'
)
else
:
default_block
=
(
'a'
,
'f'
)
if
macaron
:
default_block
=
(
'f'
,)
+
default_block
if
exists
(
custom_layers
):
layer_types
=
custom_layers
elif
exists
(
par_ratio
):
par_depth
=
depth
*
len
(
default_block
)
assert
1
<
par_ratio
<=
par_depth
,
'par ratio out of range'
default_block
=
tuple
(
filter
(
not_equals
(
'f'
),
default_block
))
par_attn
=
par_depth
//
par_ratio
depth_cut
=
par_depth
*
2
//
3
# 2 / 3 attention layer cutoff suggested by PAR paper
par_width
=
(
depth_cut
+
depth_cut
//
par_attn
)
//
par_attn
assert
len
(
default_block
)
<=
par_width
,
'default block is too large for par_ratio'
par_block
=
default_block
+
(
'f'
,)
*
(
par_width
-
len
(
default_block
))
par_head
=
par_block
*
par_attn
layer_types
=
par_head
+
(
'f'
,)
*
(
par_depth
-
len
(
par_head
))
elif
exists
(
sandwich_coef
):
assert
sandwich_coef
>
0
and
sandwich_coef
<=
depth
,
'sandwich coefficient should be less than the depth'
layer_types
=
(
'a'
,)
*
sandwich_coef
+
default_block
*
(
depth
-
sandwich_coef
)
+
(
'f'
,)
*
sandwich_coef
else
:
layer_types
=
default_block
*
depth
self
.
layer_types
=
layer_types
self
.
num_attn_layers
=
len
(
list
(
filter
(
equals
(
'a'
),
layer_types
)))
for
layer_type
in
self
.
layer_types
:
if
layer_type
==
'a'
:
layer
=
Attention
(
dim
,
heads
=
heads
,
causal
=
causal
,
**
attn_kwargs
)
elif
layer_type
==
'c'
:
layer
=
Attention
(
dim
,
heads
=
heads
,
**
attn_kwargs
)
elif
layer_type
==
'f'
:
layer
=
FeedForward
(
dim
,
**
ff_kwargs
)
layer
=
layer
if
not
macaron
else
Scale
(
0.5
,
layer
)
else
:
raise
Exception
(
f
'invalid layer type
{
layer_type
}
'
)
if
isinstance
(
layer
,
Attention
)
and
exists
(
branch_fn
):
layer
=
branch_fn
(
layer
)
if
gate_residual
:
residual_fn
=
GRUGating
(
dim
)
else
:
residual_fn
=
Residual
()
self
.
layers
.
append
(
nn
.
ModuleList
([
norm_fn
(),
layer
,
residual_fn
]))
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
,
context_mask
=
None
,
mems
=
None
,
return_hiddens
=
False
):
hiddens
=
[]
intermediates
=
[]
prev_attn
=
None
prev_cross_attn
=
None
mems
=
mems
.
copy
()
if
exists
(
mems
)
else
[
None
]
*
self
.
num_attn_layers
for
ind
,
(
layer_type
,
(
norm
,
block
,
residual_fn
))
in
enumerate
(
zip
(
self
.
layer_types
,
self
.
layers
)):
is_last
=
ind
==
(
len
(
self
.
layers
)
-
1
)
if
layer_type
==
'a'
:
hiddens
.
append
(
x
)
layer_mem
=
mems
.
pop
(
0
)
residual
=
x
if
self
.
pre_norm
:
x
=
norm
(
x
)
if
layer_type
==
'a'
:
out
,
inter
=
block
(
x
,
mask
=
mask
,
sinusoidal_emb
=
self
.
pia_pos_emb
,
rel_pos
=
self
.
rel_pos
,
prev_attn
=
prev_attn
,
mem
=
layer_mem
)
elif
layer_type
==
'c'
:
out
,
inter
=
block
(
x
,
context
=
context
,
mask
=
mask
,
context_mask
=
context_mask
,
prev_attn
=
prev_cross_attn
)
elif
layer_type
==
'f'
:
out
=
block
(
x
)
x
=
residual_fn
(
out
,
residual
)
if
layer_type
in
(
'a'
,
'c'
):
intermediates
.
append
(
inter
)
if
layer_type
==
'a'
and
self
.
residual_attn
:
prev_attn
=
inter
.
pre_softmax_attn
elif
layer_type
==
'c'
and
self
.
cross_residual_attn
:
prev_cross_attn
=
inter
.
pre_softmax_attn
if
not
self
.
pre_norm
and
not
is_last
:
x
=
norm
(
x
)
if
return_hiddens
:
intermediates
=
LayerIntermediates
(
hiddens
=
hiddens
,
attn_intermediates
=
intermediates
)
return
x
,
intermediates
return
x
class
Encoder
(
AttentionLayers
):
def
__init__
(
self
,
**
kwargs
):
assert
'causal'
not
in
kwargs
,
'cannot set causality on encoder'
super
().
__init__
(
causal
=
False
,
**
kwargs
)
class
TransformerWrapper
(
nn
.
Module
):
def
__init__
(
self
,
*
,
num_tokens
,
max_seq_len
,
attn_layers
,
emb_dim
=
None
,
max_mem_len
=
0.
,
emb_dropout
=
0.
,
num_memory_tokens
=
None
,
tie_embedding
=
False
,
use_pos_emb
=
True
):
super
().
__init__
()
assert
isinstance
(
attn_layers
,
AttentionLayers
),
'attention layers must be one of Encoder or Decoder'
dim
=
attn_layers
.
dim
emb_dim
=
default
(
emb_dim
,
dim
)
self
.
max_seq_len
=
max_seq_len
self
.
max_mem_len
=
max_mem_len
self
.
num_tokens
=
num_tokens
self
.
token_emb
=
nn
.
Embedding
(
num_tokens
,
emb_dim
)
self
.
pos_emb
=
AbsolutePositionalEmbedding
(
emb_dim
,
max_seq_len
)
if
(
use_pos_emb
and
not
attn_layers
.
has_pos_emb
)
else
always
(
0
)
self
.
emb_dropout
=
nn
.
Dropout
(
emb_dropout
)
self
.
project_emb
=
nn
.
Linear
(
emb_dim
,
dim
)
if
emb_dim
!=
dim
else
nn
.
Identity
()
self
.
attn_layers
=
attn_layers
self
.
norm
=
nn
.
LayerNorm
(
dim
)
self
.
init_
()
self
.
to_logits
=
nn
.
Linear
(
dim
,
num_tokens
)
if
not
tie_embedding
else
lambda
t
:
t
@
self
.
token_emb
.
weight
.
t
()
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens
=
default
(
num_memory_tokens
,
0
)
self
.
num_memory_tokens
=
num_memory_tokens
if
num_memory_tokens
>
0
:
self
.
memory_tokens
=
nn
.
Parameter
(
torch
.
randn
(
num_memory_tokens
,
dim
))
# let funnel encoder know number of memory tokens, if specified
if
hasattr
(
attn_layers
,
'num_memory_tokens'
):
attn_layers
.
num_memory_tokens
=
num_memory_tokens
def
init_
(
self
):
nn
.
init
.
normal_
(
self
.
token_emb
.
weight
,
std
=
0.02
)
def
forward
(
self
,
x
,
return_embeddings
=
False
,
mask
=
None
,
return_mems
=
False
,
return_attn
=
False
,
mems
=
None
,
**
kwargs
):
b
,
n
,
device
,
num_mem
=
*
x
.
shape
,
x
.
device
,
self
.
num_memory_tokens
x
=
self
.
token_emb
(
x
)
x
+=
self
.
pos_emb
(
x
)
x
=
self
.
emb_dropout
(
x
)
x
=
self
.
project_emb
(
x
)
if
num_mem
>
0
:
mem
=
repeat
(
self
.
memory_tokens
,
'n d -> b n d'
,
b
=
b
)
x
=
torch
.
cat
((
mem
,
x
),
dim
=
1
)
# auto-handle masking after appending memory tokens
if
exists
(
mask
):
mask
=
F
.
pad
(
mask
,
(
num_mem
,
0
),
value
=
True
)
x
,
intermediates
=
self
.
attn_layers
(
x
,
mask
=
mask
,
mems
=
mems
,
return_hiddens
=
True
,
**
kwargs
)
x
=
self
.
norm
(
x
)
mem
,
x
=
x
[:,
:
num_mem
],
x
[:,
num_mem
:]
out
=
self
.
to_logits
(
x
)
if
not
return_embeddings
else
x
if
return_mems
:
hiddens
=
intermediates
.
hiddens
new_mems
=
list
(
map
(
lambda
pair
:
torch
.
cat
(
pair
,
dim
=-
2
),
zip
(
mems
,
hiddens
)))
if
exists
(
mems
)
else
hiddens
new_mems
=
list
(
map
(
lambda
t
:
t
[...,
-
self
.
max_mem_len
:,
:].
detach
(),
new_mems
))
return
out
,
new_mems
if
return_attn
:
attn_maps
=
list
(
map
(
lambda
t
:
t
.
post_softmax_attn
,
intermediates
.
attn_intermediates
))
return
out
,
attn_maps
return
out
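For context on what is being deleted: the module's typical entry point was a small BERT-style encoder, roughly as ldm's older transformer-based text embedders drove it. A hedged sketch against the file as it stood before this commit:

    # Hypothetical usage of the removed module (import path as it was pre-deletion).
    import torch
    from ldm.modules.x_transformer import TransformerWrapper, Encoder

    model = TransformerWrapper(num_tokens=30522, max_seq_len=77,
                               attn_layers=Encoder(dim=512, depth=4, heads=8))
    tokens = torch.randint(0, 30522, (2, 77))
    z = model(tokens, return_embeddings=True)
    print(z.shape)    # torch.Size([2, 77, 512])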
examples/images/diffusion/ldm/util.py (modified, +111 -117)
import importlib

import torch
from torch import optim
import numpy as np

from collections import abc
from einops import rearrange
from functools import partial

import multiprocessing as mp
from threading import Thread
from queue import Queue

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont

...

@@ -45,7 +39,7 @@ def ismap(x):

def isimage(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)

...

@@ -71,7 +65,7 @@ def mean_flat(tensor):

def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params

...

@@ -93,111 +87,111 @@ def get_obj_from_str(string, reload=False):
    return getattr(importlib.import_module(module, package=None), cls)


def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    # create dummy dataset instance

    # run prefetching
    if idx_to_fn:
        res = func(data, worker_id=idx)
    else:
        res = func(data)
    Q.put([idx, res])
    Q.put("Done")


def parallel_data_prefetch(func: callable,
                           data,
                           n_proc,
                           target_data_type="ndarray",
                           cpu_intensive=True,
                           use_worker_id=False):
    # if target_data_type not in ["ndarray", "list"]:
    #     raise ValueError(
    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
    #     )
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.')
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.")

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread
    # spawn processes
    if target_data_type == "ndarray":
        arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
    else:
        step = (int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc))
        arguments = [[func, Q, part, i, use_worker_id]
                     for i, part in enumerate([data[i:i + step] for i in range(0, len(data), step)])]
    processes = []
    for i in range(n_proc):
        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
        processes += [p]

    # start processes
    print(f"Start prefetching...")
    import time

    start = time.time()
    gather_res = [[] for _ in range(n_proc)]
    try:
        for p in processes:
            p.start()

        k = 0
        while k < n_proc:
            # get result
            res = Q.get()
            if res == "Done":
                k += 1
            else:
                gather_res[res[0]] = res[1]

    except Exception as e:
        print("Exception: ", e)
        for p in processes:
            p.terminate()

        raise e
    finally:
        for p in processes:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == 'ndarray':
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)

        # order outputs
        return np.concatenate(gather_res, axis=0)
    elif target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res


class AdamWwithEMAandWings(optim.Optimizer):
    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298

    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,    # TODO: check hyperparameters before using
                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,    # ema decay to match previous code
                 ema_power=1., param_names=()):
        """AdamW that saves EMA versions of the parameters."""
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
                        ema_decay=ema_decay, ema_power=ema_power, param_names=param_names)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            state_sums = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            ema_decay = group['ema_decay']
            ema_power = group['ema_power']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of parameter values
                    state['param_exp_avg'] = p.detach().float().clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                ema_params_with_grad.append(state['param_exp_avg'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            optim._functional.adamw(params_with_grad,
                                    grads,
                                    exp_avgs,
                                    exp_avg_sqs,
                                    max_exp_avg_sqs,
                                    state_steps,
                                    amsgrad=amsgrad,
                                    beta1=beta1,
                                    beta2=beta2,
                                    lr=group['lr'],
                                    weight_decay=group['weight_decay'],
                                    eps=group['eps'],
                                    maximize=False)

            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

        return loss
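A small driver sketch for parallel_data_prefetch retained above, using threads (cpu_intensive=False) since the toy workload is trivial; square is a made-up stand-in for a real preprocessing function.

    # Hypothetical example: fan work out over 2 workers and reassemble in order.
    import numpy as np
    from ldm.util import parallel_data_prefetch

    def square(chunk):
        return chunk ** 2

    out = parallel_data_prefetch(square, np.arange(8), n_proc=2,
                                 target_data_type="ndarray", cpu_intensive=False)
    print(out)    # [ 0  1  4  9 16 25 36 49]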
examples/images/diffusion/main.py (modified, +121 -125)
-import argparse, os, sys, datetime, glob, importlib, csv
-import numpy as np
+import argparse
+import csv
+import datetime
+import glob
+import importlib
+import os
+import sys
+import time
+
+import numpy as np
 import torch
 import torchvision
-import lightning.pytorch as pl
-from packaging import version
-from omegaconf import OmegaConf
-from torch.utils.data import random_split, DataLoader, Dataset, Subset
+
+try:
+    import lightning.pytorch as pl
+except:
+    import pytorch_lightning as pl
+
 from functools import partial
 from omegaconf import OmegaConf
 from packaging import version
 from PIL import Image
-# from lightning.pytorch.strategies.colossalai import ColossalAIStrategy
-# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
 from prefetch_generator import BackgroundGenerator
-from lightning.pytorch import seed_everything
-from lightning.pytorch.trainer import Trainer
-from lightning.pytorch.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
-from lightning.pytorch.utilities.rank_zero import rank_zero_only
-from lightning.pytorch.utilities import rank_zero_info
-from diffusers.models.unet_2d import UNet2DModel
-from clip.model import Bottleneck
-from transformers.models.clip.modeling_clip import CLIPTextTransformer
+from torch.utils.data import DataLoader, Dataset, Subset, random_split
+
+try:
+    from lightning.pytorch import seed_everything
+    from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
+    from lightning.pytorch.trainer import Trainer
+    from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
+    LIGHTNING_PACK_NAME = "lightning.pytorch."
+except:
+    from pytorch_lightning import seed_everything
+    from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
+    from pytorch_lightning.trainer import Trainer
+    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+    LIGHTNING_PACK_NAME = "pytorch_lightning."
+
 from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config
-import clip
-from einops import rearrange, repeat
-from transformers import CLIPTokenizer, CLIPTextModel
-import kornia
-from ldm.modules.x_transformer import *
-from ldm.modules.encoders.modules import *
-from taming.modules.diffusionmodules.model import ResnetBlock
-from taming.modules.transformer.mingpt import *
-from taming.modules.transformer.permuter import *
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import AutoencoderKL
-from ldm.models.autoencoder import *
-from ldm.models.diffusion.ddim import *
-from ldm.modules.diffusionmodules.openaimodel import *
-from ldm.modules.diffusionmodules.model import *
-from ldm.modules.diffusionmodules.model import Decoder, Encoder, Up_module, Down_module, Mid_module, temb_module
-from ldm.modules.attention import enable_flash_attention
-# from ldm.modules.attention import enable_flash_attentions


 class DataLoaderX(DataLoader):
...
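The body of DataLoaderX is collapsed in this view. A minimal sketch of the usual BackgroundGenerator pattern it is built on (assuming the collapsed body does nothing more than wrap the parent iterator):

from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader

class DataLoaderX(DataLoader):

    def __iter__(self):
        # Prefetch upcoming batches on a background thread so the training
        # loop does not stall on host-side data loading.
        return BackgroundGenerator(super().__iter__())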
@@ -56,6 +50,7 @@ class DataLoaderX(DataLoader):

 def get_parser(**parser_kwargs):

     def str2bool(v):
         if isinstance(v, bool):
             return v
...
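str2bool is truncated here. A common implementation for argparse flags of this kind (a sketch, not necessarily the exact body used in this file) is:

import argparse

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected")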
@@ -91,7 +86,7 @@ def get_parser(**parser_kwargs):
         nargs="*",
         metavar="base_config.yaml",
         help="paths to base configs. Loaded from left-to-right. "
              "Parameters can be overwritten or added with command-line options of the form `--key value`.",
         default=list(),
     )
     parser.add_argument(
...
@@ -111,11 +106,7 @@ def get_parser(**parser_kwargs):
         nargs="?",
         help="disable test",
     )
     parser.add_argument("-p", "--project", help="name of new or path to existing project")
     parser.add_argument(
         "-d",
         "--debug",
...
@@ -210,8 +201,17 @@ def worker_init_fn(_):
 class DataModuleFromConfig(pl.LightningDataModule):

-    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, wrap=False,
-                 num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
+    def __init__(self,
+                 batch_size,
+                 train=None,
+                 validation=None,
+                 test=None,
+                 predict=None,
+                 wrap=False,
+                 num_workers=None,
+                 shuffle_test_loader=False,
+                 use_worker_init_fn=False,
                  shuffle_val_dataloader=False):
         super().__init__()
         self.batch_size = batch_size
...
@@ -237,9 +237,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
             instantiate_from_config(data_cfg)

     def setup(self, stage=None):
-        self.datasets = dict(
-            (k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
+        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
         if self.wrap:
             for k in self.datasets:
                 self.datasets[k] = WrappedDataset(self.datasets[k])
...
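Each entry in dataset_configs follows ldm's instantiate_from_config convention: a dict with a dotted "target" class path and optional "params" forwarded to the constructor. A simplified re-implementation to show the idea (the dataset target in the comment is hypothetical):

from importlib import import_module

def instantiate_from_config(config):
    # Simplified sketch of ldm.util.instantiate_from_config: resolve the
    # dotted "target" path and call the class with "params".
    module, cls = config["target"].rsplit(".", 1)
    return getattr(import_module(module), cls)(**config.get("params", dict()))

# A dataset entry would look the same, e.g.
# {"target": "ldm.data.teyvat.TeyvatDataset", "params": {"size": 512}}  (hypothetical)
obj = instantiate_from_config({"target": "collections.OrderedDict", "params": {}})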
@@ -250,9 +248,11 @@ class DataModuleFromConfig(pl.LightningDataModule):
             init_fn = worker_init_fn
         else:
             init_fn = None
-        return DataLoaderX(self.datasets["train"], batch_size=self.batch_size, num_workers=self.num_workers,
-                           shuffle=False if is_iterable_dataset else True, worker_init_fn=init_fn)
+        return DataLoaderX(self.datasets["train"],
+                           batch_size=self.batch_size,
+                           num_workers=self.num_workers,
+                           shuffle=False if is_iterable_dataset else True,
+                           worker_init_fn=init_fn)

     def _val_dataloader(self, shuffle=False):
         if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
...
@@ -260,10 +260,10 @@ class DataModuleFromConfig(pl.LightningDataModule):
         else:
             init_fn = None
-        return DataLoaderX(self.datasets["validation"], batch_size=self.batch_size, num_workers=self.num_workers,
-                           worker_init_fn=init_fn, shuffle=shuffle)
+        return DataLoaderX(self.datasets["validation"],
+                           batch_size=self.batch_size,
+                           num_workers=self.num_workers,
+                           worker_init_fn=init_fn,
+                           shuffle=shuffle)

     def _test_dataloader(self, shuffle=False):
         is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
...
@@ -275,19 +275,25 @@ class DataModuleFromConfig(pl.LightningDataModule):
             # do not shuffle dataloader for iterable dataset
         shuffle = shuffle and (not is_iterable_dataset)
-        return DataLoaderX(self.datasets["test"], batch_size=self.batch_size, num_workers=self.num_workers,
-                           worker_init_fn=init_fn, shuffle=shuffle)
+        return DataLoaderX(self.datasets["test"],
+                           batch_size=self.batch_size,
+                           num_workers=self.num_workers,
+                           worker_init_fn=init_fn,
+                           shuffle=shuffle)

     def _predict_dataloader(self, shuffle=False):
         if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
             init_fn = worker_init_fn
         else:
             init_fn = None
-        return DataLoaderX(self.datasets["predict"], batch_size=self.batch_size, num_workers=self.num_workers,
-                           worker_init_fn=init_fn)
+        return DataLoaderX(self.datasets["predict"],
+                           batch_size=self.batch_size,
+                           num_workers=self.num_workers,
+                           worker_init_fn=init_fn)


 class SetupCallback(Callback):

     def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
         super().__init__()
         self.resume = resume
...
@@ -317,8 +323,7 @@ class SetupCallback(Callback):
                 os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
             print("Project config")
             print(OmegaConf.to_yaml(self.config))
             OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
             print("Lightning config")
             print(OmegaConf.to_yaml(self.lightning_config))
...
@@ -338,8 +343,16 @@ class SetupCallback(Callback):
 class ImageLogger(Callback):

-    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True, rescale=True,
-                 disabled=False, log_on_batch_idx=False, log_first_step=False,
+    def __init__(self,
+                 batch_frequency,
+                 max_images,
+                 clamp=True,
+                 increase_log_steps=True,
+                 rescale=True,
+                 disabled=False,
+                 log_on_batch_idx=False,
+                 log_first_step=False,
                  log_images_kwargs=None):
         super().__init__()
         self.rescale = rescale
...
@@ -348,7 +361,7 @@ class ImageLogger(Callback):
         self.logger_log_images = {
             pl.loggers.CSVLogger: self._testtube,
         }
         self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
         if not increase_log_steps:
             self.log_steps = [self.batch_freq]
         self.clamp = clamp
...
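The log_steps schedule above logs images at exponentially spaced steps up to batch_freq. A quick worked example of what it produces:

import numpy as np

batch_freq = 750
log_steps = [2**n for n in range(int(np.log2(batch_freq)) + 1)]
print(log_steps)    # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
# With increase_log_steps=False this collapses to just [750].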
@@ -361,39 +374,30 @@ class ImageLogger(Callback):

     def _testtube(self, pl_module, images, batch_idx, split):
         for k in images:
             grid = torchvision.utils.make_grid(images[k])
             grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w

             tag = f"{split}/{k}"
             pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)

     @rank_zero_only
     def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
         root = os.path.join(save_dir, "images", split)
         for k in images:
             grid = torchvision.utils.make_grid(images[k], nrow=4)
             if self.rescale:
                 grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w
             grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
             grid = grid.numpy()
             grid = (grid * 255).astype(np.uint8)
             filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
             path = os.path.join(root, filename)
             os.makedirs(os.path.split(path)[0], exist_ok=True)
             Image.fromarray(grid).save(path)

     def log_img(self, pl_module, batch, batch_idx, split="train"):
         check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
         if (self.check_frequency(check_idx) and    # batch_idx % self.batch_freq == 0
                 hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0):
             logger = type(pl_module.logger)
             is_train = pl_module.training
...
@@ -411,8 +415,8 @@ class ImageLogger(Callback):
             if self.clamp:
                 images[k] = torch.clamp(images[k], -1., 1.)

             self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx)
             logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
             logger_log_images(pl_module, images, pl_module.global_step, split)
...
@@ -421,8 +425,8 @@ class ImageLogger(Callback):
             pl_module.train()

     def check_frequency(self, check_idx):
         if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step):
             try:
                 self.log_steps.pop(0)
             except IndexError as e:
...
@@ -461,7 +465,7 @@ class CUDACallback(Callback):

     def on_train_epoch_end(self, trainer, pl_module):
         torch.cuda.synchronize(trainer.strategy.root_device.index)
         max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20
         epoch_time = time.time() - self.start_time

         try:
...
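CUDACallback reports the epoch's peak allocation in MiB via the 2**20 divisor. A standalone sketch of the same measurement (the print format is illustrative, not the callback's exact message):

import torch

torch.cuda.reset_peak_memory_stats(0)
# ... run a training epoch on device 0 ...
torch.cuda.synchronize(0)
peak_mib = torch.cuda.max_memory_allocated(0) / 2**20
print(f"Peak memory {peak_mib:.2f} MiB")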
@@ -528,13 +532,9 @@ if __name__ == "__main__":
     opt, unknown = parser.parse_known_args()
     if opt.name and opt.resume:
         raise ValueError("-n/--name and -r/--resume cannot be specified both."
                          "If you want to resume training in a new log folder, "
                          "use -n/--name in combination with --resume_from_checkpoint")
-    if opt.flash:
-        enable_flash_attention()
     if opt.resume:
         if not os.path.exists(opt.resume):
             raise ValueError("Cannot find {}".format(opt.resume))
...
@@ -578,7 +578,7 @@ if __name__ == "__main__":
     lightning_config = config.pop("lightning", OmegaConf.create())
     # merge trainer cli with config
     trainer_config = lightning_config.get("trainer", OmegaConf.create())
     for k in nondefault_trainer_args(opt):
         trainer_config[k] = getattr(opt, k)
...
@@ -601,7 +601,7 @@ if __name__ == "__main__":
     else:
         config.model["params"].update({"use_fp16": False})
     print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
     model = instantiate_from_config(config.model)
     # trainer and callbacks
     trainer_kwargs = dict()
...
@@ -610,7 +610,7 @@ if __name__ == "__main__":
     # default logger configs
     default_logger_cfgs = {
         "wandb": {
-            "target": "lightning.pytorch.loggers.WandbLogger",
+            "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
             "params": {
                 "name": nowname,
                 "save_dir": logdir,
...
@@ -618,9 +618,9 @@ if __name__ == "__main__":
                 "id": nowname,
             }
         },
-        "tensorboard":{
-            "target": "lightning.pytorch.loggers.TensorBoardLogger",
-            "params":{
+        "tensorboard": {
+            "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
+            "params": {
                 "save_dir": logdir,
                 "name": "diff_tb",
                 "log_graph": True
...
@@ -640,9 +640,10 @@ if __name__ == "__main__":
     if "strategy" in trainer_config:
         strategy_cfg = trainer_config["strategy"]
         print("Using strategy: {}".format(strategy_cfg["target"]))
+        strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
     else:
         strategy_cfg = {
-            "target": "lightning.pytorch.strategies.DDPStrategy",
+            "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
             "params": {
                 "find_unused_parameters": False
             }
...
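The LIGHTNING_PACK_NAME prefix keeps these target strings valid whichever lightning distribution was importable. A tiny illustration:

LIGHTNING_PACK_NAME = "lightning.pytorch."    # or "pytorch_lightning."

target = LIGHTNING_PACK_NAME + "strategies.DDPStrategy"
print(target)    # lightning.pytorch.strategies.DDPStrategy
# instantiate_from_config later resolves this dotted path to the actual class.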
@@ -654,7 +655,7 @@ if __name__ == "__main__":
     # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
     # specify which metric is used to determine best models
     default_modelckpt_cfg = {
-        "target": "lightning.pytorch.callbacks.ModelCheckpoint",
+        "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
         "params": {
             "dirpath": ckptdir,
             "filename": "{epoch:06}",
...
@@ -670,7 +671,7 @@ if __name__ == "__main__":
     if "modelcheckpoint" in lightning_config:
         modelckpt_cfg = lightning_config.modelcheckpoint
     else:
         modelckpt_cfg = OmegaConf.create()
     modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
     print(f"Merged modelckpt-cfg:\n{modelckpt_cfg}")
     if version.parse(pl.__version__) < version.parse('1.4.0'):
...
@@ -702,7 +703,7 @@ if __name__ == "__main__":
             "target": "main.LearningRateMonitor",
             "params": {
                 "logging_interval": "step",
                 # "log_momentum": True
             }
         },
         "cuda_callback": {
...
@@ -721,17 +722,17 @@ if __name__ == "__main__":
         print('Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
         default_metrics_over_trainsteps_ckpt_dict = {
             'metrics_over_trainsteps_checkpoint': {
-                "target": 'lightning.pytorch.callbacks.ModelCheckpoint',
+                "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
                 'params': {
                     "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                     "filename": "{epoch:06}-{step:09}",
                     "verbose": True,
                     'save_top_k': -1,
                     'every_n_train_steps': 10000,
                     'save_weights_only': True
                 }
             }
         }
         default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
...
@@ -744,7 +745,7 @@ if __name__ == "__main__":
     trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
     trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
     trainer.logdir = logdir    ###

     # data
     data = instantiate_from_config(config.data)
...
@@ -772,14 +773,13 @@ if __name__ == "__main__":
     if opt.scale_lr:
         model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
         print(
             "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
             .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
     else:
         model.learning_rate = base_lr
         print("++++ NOT USING LR SCALING ++++")
         print(f"Setting learning rate to {model.learning_rate:.2e}")

     # allow checkpointing via USR1
     def melk(*args, **kwargs):
         # run all checkpoint hooks
...
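The scale_lr rule multiplies the configured base rate by the effective batch size. A worked example of the arithmetic (the concrete numbers are illustrative):

accumulate_grad_batches = 1
ngpu = 8
bs = 2          # per-GPU batch size
base_lr = 1.0e-6

learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(f"{learning_rate:.2e}")    # 1.60e-05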
@@ -788,13 +788,11 @@ if __name__ == "__main__":
             ckpt_path = os.path.join(ckptdir, "last.ckpt")
             trainer.save_checkpoint(ckpt_path)

     def divein(*args, **kwargs):
         if trainer.global_rank == 0:
-            import pudb;
+            import pudb
             pudb.set_trace()

     import signal
     signal.signal(signal.SIGUSR1, melk)
...
@@ -803,8 +801,6 @@ if __name__ == "__main__":
     # run
     if opt.train:
         try:
-            for name, m in model.named_parameters():
-                print(name)
             trainer.fit(model, data)
         except Exception:
             melk()
...
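Because melk is registered on SIGUSR1, a running training job can be checkpointed on demand by signalling its process. A sketch (the PID is hypothetical; look it up however suits your setup):

import os
import signal

training_pid = 12345    # hypothetical PID of the running main.py process
os.kill(training_pid, signal.SIGUSR1)    # triggers melk() -> trainer.save_checkpoint(.../last.ckpt)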
examples/images/diffusion/requirements.txt View file @ 6c4c6a04
-albumentations==0.4.3
-diffusers
+albumentations==1.3.0
-opencv-python
 pudb==2019.2
-datasets
-invisible-watermark
-prefetch_generator
 imageio==2.9.0
 imageio-ffmpeg==0.4.2
-torchmetrics==0.6
 omegaconf==2.1.1
-multiprocess
-lightning==1.8.1
 test-tube>=0.7.5
 streamlit>=0.73.1
 einops==0.3.0
 torch-fidelity==0.3.0
 transformers==4.19.2
+torchmetrics==0.6.0
 kornia==0.6
+opencv-python==4.6.0.66
+prefetch_generator
 -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
 -e git+https://github.com/openai/CLIP.git@main#egg=clip
+webdataset==0.2.5
+open-clip-torch==2.7.0
+gradio==3.11
+datasets
 -e .
examples/images/diffusion/scripts/img2img.py View file @ 6c4c6a04
"""make variations of input image"""
-import argparse, os, sys, glob
+import argparse, os
 import PIL
 import torch
 import numpy as np
...
@@ -12,12 +12,16 @@ from einops import rearrange, repeat
 from torchvision.utils import make_grid
 from torch import autocast
 from contextlib import nullcontext
 import time
-from lightning.pytorch import seed_everything
+try:
+    from lightning.pytorch import seed_everything
+except:
+    from pytorch_lightning import seed_everything
+from imwatermark import WatermarkEncoder
+
+from scripts.txt2img import put_watermark
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler


 def chunk(it, size):
...
@@ -49,12 +53,12 @@ def load_img(path):
     image = Image.open(path).convert("RGB")
     w, h = image.size
     print(f"loaded input image of size ({w}, {h}) from {path}")
-    w, h = map(lambda x: x - x % 32, (w, h))    # resize to integer multiple of 32
+    w, h = map(lambda x: x - x % 64, (w, h))    # resize to integer multiple of 64
     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
     return 2. * image - 1.


 def main():
...
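The x - x % 64 rounding guarantees dimensions the v2 UNet's downsampling stack can divide evenly (the v1 script used 32). A quick check of what it does to an arbitrary image size:

w, h = 513, 769
w, h = map(lambda x: x - x % 64, (w, h))
print(w, h)    # 512 768  -- always rounded down to the nearest multiple of 64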
@@ -83,18 +87,6 @@ def main():
         default="outputs/img2img-samples"
     )
-    parser.add_argument(
-        "--skip_grid",
-        action='store_true',
-        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
-    )
-    parser.add_argument(
-        "--skip_save",
-        action='store_true',
-        help="do not save indiviual samples. For speed measurements.",
-    )
     parser.add_argument(
         "--ddim_steps",
         type=int,
...
@@ -102,11 +94,6 @@ def main():
         help="number of ddim sampling steps",
     )
-    parser.add_argument(
-        "--plms",
-        action='store_true',
-        help="use plms sampling",
-    )
     parser.add_argument(
         "--fixed_code",
         action='store_true',
...
@@ -125,6 +112,7 @@ def main():
         default=1,
         help="sample this often",
     )
     parser.add_argument(
         "--C",
         type=int,
...
@@ -137,31 +125,35 @@ def main():
         default=8,
         help="downsampling factor, most often 8 or 16",
     )
     parser.add_argument(
         "--n_samples",
         type=int,
         default=2,
         help="how many samples to produce for each given prompt. A.k.a batch size",
     )
     parser.add_argument(
         "--n_rows",
         type=int,
         default=0,
         help="rows in the grid (default: n_samples)",
     )
     parser.add_argument(
         "--scale",
         type=float,
-        default=5.0,
+        default=9.0,
         help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
     )
     parser.add_argument(
         "--strength",
         type=float,
-        default=0.75,
+        default=0.8,
         help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
     )
     parser.add_argument(
         "--from-file",
         type=str,
...
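Strength controls how far the init image is pushed into the diffusion process before being denoised back. In these img2img scripts the encode depth is derived from it as t_enc = int(strength * ddim_steps); a quick sketch of the effect (variable names follow the script's options):

ddim_steps = 50

for strength in (0.3, 0.8, 1.0):
    t_enc = int(strength * ddim_steps)
    # 0.3 -> 15 noising steps (stays close to the input image)
    # 0.8 -> 40 steps (the new default; substantial re-imagining)
    # 1.0 -> 50 steps (input image information fully destroyed)
    print(strength, t_enc)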
@@ -170,13 +162,12 @@ def main():
     parser.add_argument(
         "--config",
         type=str,
-        default="configs/stable-diffusion/v1-inference.yaml",
+        default="configs/stable-diffusion/v2-inference.yaml",
         help="path to config which constructs model",
     )
     parser.add_argument(
         "--ckpt",
         type=str,
         default="models/ldm/stable-diffusion-v1/model.ckpt",
         help="path to checkpoint of model",
     )
     parser.add_argument(
...
@@ -202,15 +193,16 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)

-    if opt.plms:
-        raise NotImplementedError("PLMS sampler not (yet) supported")
-        sampler = PLMSSampler(model)
-    else:
-        sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model)

     os.makedirs(opt.outdir, exist_ok=True)
     outpath = opt.outdir

+    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+    wm = "SDV2"
+    wm_encoder = WatermarkEncoder()
+    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+
     batch_size = opt.n_samples
     n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
     if not opt.from_file:
...
@@ -244,7 +236,6 @@ def main():
     with torch.no_grad():
         with precision_scope("cuda"):
             with model.ema_scope():
-                tic = time.time()
                 all_samples = list()
                 for n in trange(opt.n_iter, desc="Sampling"):
                     for prompts in tqdm(data, desc="data"):
...
@@ -256,37 +247,35 @@ def main():
                         c = model.get_learned_conditioning(prompts)

                         # encode (scaled latent)
                         z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
                         # decode it
                         samples = sampler.decode(z_enc, c, t_enc,
                                                  unconditional_guidance_scale=opt.scale,
                                                  unconditional_conditioning=uc)

                         x_samples = model.decode_first_stage(samples)
                         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

-                        if not opt.skip_save:
-                            for x_sample in x_samples:
-                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                                Image.fromarray(x_sample.astype(np.uint8)).save(
-                                    os.path.join(sample_path, f"{base_count:05}.png"))
-                                base_count += 1
+                        for x_sample in x_samples:
+                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                            img = Image.fromarray(x_sample.astype(np.uint8))
+                            img = put_watermark(img, wm_encoder)
+                            img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+                            base_count += 1
                         all_samples.append(x_samples)

-                if not opt.skip_grid:
-                    # additionally, save as grid
-                    grid = torch.stack(all_samples, 0)
-                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-                    grid = make_grid(grid, nrow=n_rows)
-                    # to image
-                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
-                    grid_count += 1
-                toc = time.time()
+                # additionally, save as grid
+                grid = torch.stack(all_samples, 0)
+                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                grid = make_grid(grid, nrow=n_rows)
+                # to image
+                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                grid = Image.fromarray(grid.astype(np.uint8))
+                grid = put_watermark(grid, wm_encoder)
+                grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                grid_count += 1

-    print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
-          f" \nEnjoy.")
+    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")


 if __name__ == "__main__":
...
examples/images/diffusion/scripts/txt2img.py View file @ 6c4c6a04
-import argparse, os, sys, glob
+import argparse, os
 import cv2
 import torch
 import numpy as np
 from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
-from imwatermark import WatermarkEncoder
 from itertools import islice
 from einops import rearrange
 from torchvision.utils import make_grid
 import time
-from lightning.pytorch import seed_everything
+try:
+    from lightning.pytorch import seed_everything
+except:
+    from pytorch_lightning import seed_everything
 from torch import autocast
-from contextlib import contextmanager, nullcontext
+from contextlib import nullcontext
+from imwatermark import WatermarkEncoder

 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.dpm_solver import DPMSolverSampler

-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-
-# load safety model
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
-safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+torch.set_grad_enabled(False)


 def chunk(it, size):
     it = iter(it)
     return iter(lambda: tuple(islice(it, size)), ())


-def numpy_to_pil(images):
-    """
-    Convert a numpy image or a batch of images to a PIL image.
-    """
-    if images.ndim == 3:
-        images = images[None, ...]
-    images = (images * 255).round().astype("uint8")
-    pil_images = [Image.fromarray(image) for image in images]
-    return pil_images


 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
...
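chunk slices an iterable of prompts into fixed-size batches, with a final short tuple if the total does not divide evenly. A quick demonstration:

prompts = ["a", "b", "c", "d", "e"]
batches = list(chunk(prompts, 2))
print(batches)    # [('a', 'b'), ('c', 'd'), ('e',)]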
@@ -65,43 +48,13 @@ def load_model_from_config(config, ckpt, verbose=False):
     return model


-def put_watermark(img, wm_encoder=None):
-    if wm_encoder is not None:
-        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-        img = wm_encoder.encode(img, 'dwtDct')
-        img = Image.fromarray(img[:, :, ::-1])
-    return img
-
-
-def load_replacement(x):
-    try:
-        hwc = x.shape
-        y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
-        y = (np.array(y) / 255.0).astype(x.dtype)
-        assert y.shape == x.shape
-        return y
-    except Exception:
-        return x
-
-
-def check_safety(x_image):
-    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
-    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-    assert x_checked_image.shape[0] == len(has_nsfw_concept)
-    for i in range(len(has_nsfw_concept)):
-        if has_nsfw_concept[i]:
-            x_checked_image[i] = load_replacement(x_checked_image[i])
-    return x_checked_image, has_nsfw_concept


-def main():
+def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--prompt",
         type=str,
         nargs="?",
-        default="a painting of a virus monster playing guitar",
+        default="a professional photograph of an astronaut riding a triceratops",
         help="the prompt to render"
     )
     parser.add_argument(
...
@@ -112,17 +65,7 @@ def main():
         default="outputs/txt2img-samples"
     )
-    parser.add_argument(
-        "--skip_grid",
-        action='store_true',
-        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
-    )
-    parser.add_argument(
-        "--skip_save",
-        action='store_true',
-        help="do not save individual samples. For speed measurements.",
-    )
     parser.add_argument(
-        "--ddim_steps",
+        "--steps",
         type=int,
         default=50,
         help="number of ddim sampling steps",
...
@@ -133,14 +76,14 @@ def main():
         help="use plms sampling",
     )
     parser.add_argument(
-        "--laion400m",
+        "--dpm",
         action='store_true',
-        help="uses the LAION400M model",
+        help="use DPM (2) sampler",
     )
     parser.add_argument(
         "--fixed_code",
         action='store_true',
-        help="if enabled, uses the same starting code across samples ",
+        help="if enabled, uses the same starting code across all samples ",
     )
     parser.add_argument(
         "--ddim_eta",
...
@@ -151,7 +94,7 @@ def main():
     parser.add_argument(
         "--n_iter",
         type=int,
-        default=2,
+        default=3,
         help="sample this often",
     )
     parser.add_argument(
...
@@ -176,13 +119,13 @@ def main():
         "--f",
         type=int,
         default=8,
-        help="downsampling factor",
+        help="downsampling factor, most often 8 or 16",
     )
     parser.add_argument(
         "--n_samples",
         type=int,
         default=3,
-        help="how many samples to produce for each given prompt. A.k.a. batch size",
+        help="how many samples to produce for each given prompt. A.k.a batch size",
     )
     parser.add_argument(
         "--n_rows",
...
@@ -193,24 +136,23 @@ def main():
     parser.add_argument(
         "--scale",
         type=float,
-        default=7.5,
+        default=9.0,
         help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
     )
     parser.add_argument(
         "--from-file",
         type=str,
-        help="if specified, load prompts from this file",
+        help="if specified, load prompts from this file, separated by newlines",
     )
     parser.add_argument(
         "--config",
         type=str,
-        default="configs/stable-diffusion/v1-inference.yaml",
+        default="configs/stable-diffusion/v2-inference.yaml",
         help="path to config which constructs model",
     )
     parser.add_argument(
         "--ckpt",
         type=str,
         default="models/ldm/stable-diffusion-v1/model.ckpt",
         help="path to checkpoint of model",
     )
     parser.add_argument(
...
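The scale option implements classifier-free guidance exactly as its help string states: the final noise prediction extrapolates from the unconditional prediction toward the conditional one. A sketch of the combination step (tensor shapes and names are illustrative):

import torch

eps_uncond = torch.randn(1, 4, 64, 64)    # eps(x, empty prompt)
eps_cond = torch.randn(1, 4, 64, 64)      # eps(x, cond)
scale = 9.0

eps = eps_uncond + scale * (eps_cond - eps_uncond)
# scale = 1.0 reduces to the purely conditional prediction (which is why the
# script skips computing uc in that case); larger values push harder toward the prompt.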
@@ -226,14 +168,25 @@ def main():
         choices=["full", "autocast"],
         default="autocast"
     )
+    parser.add_argument(
+        "--repeat",
+        type=int,
+        default=1,
+        help="repeat each prompt in file this often",
+    )
     opt = parser.parse_args()
+    return opt

-    if opt.laion400m:
-        print("Falling back to LAION 400M model...")
-        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
-        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
-        opt.outdir = "outputs/txt2img-samples-laion400m"

+def put_watermark(img, wm_encoder=None):
+    if wm_encoder is not None:
+        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+        img = wm_encoder.encode(img, 'dwtDct')
+        img = Image.fromarray(img[:, :, ::-1])
+    return img

+
+def main(opt):
     seed_everything(opt.seed)

     config = OmegaConf.load(f"{opt.config}")
...
@@ -244,6 +197,8 @@ def main():
     if opt.plms:
         sampler = PLMSSampler(model)
+    elif opt.dpm:
+        sampler = DPMSolverSampler(model)
     else:
         sampler = DDIMSampler(model)
...
@@ -251,7 +206,7 @@ def main():
     outpath = opt.outdir

     print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
-    wm = "StableDiffusionV1"
+    wm = "SDV2"
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
...
@@ -266,10 +221,12 @@ def main():
         print(f"reading prompts from {opt.from_file}")
         with open(opt.from_file, "r") as f:
             data = f.read().splitlines()
+            data = [p for p in data for i in range(opt.repeat)]
             data = list(chunk(data, batch_size))

     sample_path = os.path.join(outpath, "samples")
     os.makedirs(sample_path, exist_ok=True)
+    sample_count = 0
     base_count = len(os.listdir(sample_path))
     grid_count = len(os.listdir(outpath)) - 1
...
@@ -277,68 +234,59 @@ def main():
     if opt.fixed_code:
         start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

     precision_scope = autocast if opt.precision == "autocast" else nullcontext
-    with torch.no_grad():
-        with precision_scope("cuda"):
-            with model.ema_scope():
-                tic = time.time()
-                all_samples = list()
-                for n in trange(opt.n_iter, desc="Sampling"):
-                    for prompts in tqdm(data, desc="data"):
-                        uc = None
-                        if opt.scale != 1.0:
-                            uc = model.get_learned_conditioning(batch_size * [""])
-                        if isinstance(prompts, tuple):
-                            prompts = list(prompts)
-                        c = model.get_learned_conditioning(prompts)
-                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
-                                                         conditioning=c,
-                                                         batch_size=opt.n_samples,
-                                                         shape=shape,
-                                                         verbose=False,
-                                                         unconditional_guidance_scale=opt.scale,
-                                                         unconditional_conditioning=uc,
-                                                         eta=opt.ddim_eta,
-                                                         x_T=start_code)
-
-                        x_samples_ddim = model.decode_first_stage(samples_ddim)
-                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                        x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
-
-                        x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
-                        x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
-                        if not opt.skip_save:
-                            for x_sample in x_checked_image_torch:
-                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                                img = Image.fromarray(x_sample.astype(np.uint8))
-                                img = put_watermark(img, wm_encoder)
-                                img.save(os.path.join(sample_path, f"{base_count:05}.png"))
-                                base_count += 1
-
-                        if not opt.skip_grid:
-                            all_samples.append(x_checked_image_torch)
-
-                if not opt.skip_grid:
-                    # additionally, save as grid
-                    grid = torch.stack(all_samples, 0)
-                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-                    grid = make_grid(grid, nrow=n_rows)
-                    # to image
-                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-                    img = Image.fromarray(grid.astype(np.uint8))
-                    img = put_watermark(img, wm_encoder)
-                    img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
-                    grid_count += 1
-
-                toc = time.time()
+    with torch.no_grad(), \
+        precision_scope("cuda"), \
+        model.ema_scope():
+            all_samples = list()
+            for n in trange(opt.n_iter, desc="Sampling"):
+                for prompts in tqdm(data, desc="data"):
+                    uc = None
+                    if opt.scale != 1.0:
+                        uc = model.get_learned_conditioning(batch_size * [""])
+                    if isinstance(prompts, tuple):
+                        prompts = list(prompts)
+                    c = model.get_learned_conditioning(prompts)
+                    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                    samples, _ = sampler.sample(S=opt.steps,
+                                                conditioning=c,
+                                                batch_size=opt.n_samples,
+                                                shape=shape,
+                                                verbose=False,
+                                                unconditional_guidance_scale=opt.scale,
+                                                unconditional_conditioning=uc,
+                                                eta=opt.ddim_eta,
+                                                x_T=start_code)

+                    x_samples = model.decode_first_stage(samples)
+                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

+                    for x_sample in x_samples:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        img = Image.fromarray(x_sample.astype(np.uint8))
+                        img = put_watermark(img, wm_encoder)
+                        img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+                        base_count += 1
+                        sample_count += 1

+                    all_samples.append(x_samples)

+            # additionally, save as grid
+            grid = torch.stack(all_samples, 0)
+            grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+            grid = make_grid(grid, nrow=n_rows)

+            # to image
+            grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+            grid = Image.fromarray(grid.astype(np.uint8))
+            grid = put_watermark(grid, wm_encoder)
+            grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+            grid_count += 1

     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
           f" \nEnjoy.")


 if __name__ == "__main__":
-    main()
+    opt = parse_args()
+    main(opt)
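Every emitted image is tagged with the invisible "SDV2" watermark via put_watermark. To verify it later, invisible-watermark's decoder can be used; a sketch (the 32-bit length corresponds to the four ASCII bytes of "SDV2", and the output path is illustrative):

import cv2
from imwatermark import WatermarkDecoder

bgr = cv2.imread("outputs/txt2img-samples/samples/00000.png")
decoder = WatermarkDecoder('bytes', 32)
watermark = decoder.decode(bgr, 'dwtDct')
print(watermark.decode('utf-8'))    # expected: SDV2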
examples/images/diffusion/train.sh View file @ 6c4c6a04
-HF_DATASETS_OFFLINE=1
-TRANSFORMERS_OFFLINE=1
+# HF_DATASETS_OFFLINE=1
+# TRANSFORMERS_OFFLINE=1
+# DIFFUSERS_OFFLINE=1

-python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
+python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml