Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
DiffBIR_pytorch
Commits
30355a12
Commit
30355a12
authored
Sep 11, 2023
by
0x3f3f3f3fun
Browse files
add support for cpu
parent
7a5f8d70
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
131 additions
and
91 deletions
+131
-91
.gitignore
.gitignore
+1
-1
README.md
README.md
+10
-5
gradio_diffbir.py
gradio_diffbir.py
+5
-2
inference.py
inference.py
+61
-57
inference_face.py
inference_face.py
+8
-4
ldm/modules/attention.py
ldm/modules/attention.py
+12
-10
ldm/modules/diffusionmodules/model.py
ldm/modules/diffusionmodules/model.py
+12
-9
ldm/modules/encoders/modules.py
ldm/modules/encoders/modules.py
+5
-3
ldm/xformers_state.py
ldm/xformers_state.py
+17
-0
No files found.
.gitignore
View file @
30355a12
...
@@ -7,4 +7,4 @@ __pycache__
...
@@ -7,4 +7,4 @@ __pycache__
!install_env.sh
!install_env.sh
/weights
/weights
/temp
/temp
results
/
/
results
README.md
View file @
30355a12
...
@@ -95,7 +95,8 @@ python gradio_diffbir.py \
...
@@ -95,7 +95,8 @@ python gradio_diffbir.py \
--ckpt weights/general_full_v1.ckpt \
--ckpt weights/general_full_v1.ckpt \
--config configs/model/cldm.yaml \
--config configs/model/cldm.yaml \
--reload_swinir \
--reload_swinir \
--swinir_ckpt weights/general_swinir_v1.ckpt
--swinir_ckpt weights/general_swinir_v1.ckpt \
--device cuda
```
```
<div align="center">
<div align="center">
...
@@ -120,7 +121,8 @@ python inference.py \
...
@@ -120,7 +121,8 @@ python inference.py \
--sr_scale
4
\
--sr_scale
4
\
--image_size
512
\
--image_size
512
\
--color_fix_type
wavelet
--resize_back
\
--color_fix_type
wavelet
--resize_back
\
--output
results/general
--output
results/general
\
--device
cuda
```
```
If you are confused about where the
`reload_swinir`
option came from, please refer to the
[
degradation details
](
#degradation-details
)
.
If you are confused about where the
`reload_swinir`
option came from, please refer to the
[
degradation details
](
#degradation-details
)
.
...
@@ -139,7 +141,8 @@ python inference_face.py \
...
@@ -139,7 +141,8 @@ python inference_face.py \
--image_size
512
\
--image_size
512
\
--color_fix_type
wavelet
\
--color_fix_type
wavelet
\
--output
results/face/aligned
--resize_back
\
--output
results/face/aligned
--resize_back
\
--has_aligned
--has_aligned
\
--device
cuda
# for unaligned face inputs
# for unaligned face inputs
python inference_face.py
\
python inference_face.py
\
...
@@ -150,7 +153,8 @@ python inference_face.py \
...
@@ -150,7 +153,8 @@ python inference_face.py \
--sr_scale
1
\
--sr_scale
1
\
--image_size
512
\
--image_size
512
\
--color_fix_type
wavelet
\
--color_fix_type
wavelet
\
--output
results/face/whole_img
--resize_back
--output
results/face/whole_img
--resize_back
\
--device
cuda
```
```
### Only Stage1 Model (Remove Degradations)
### Only Stage1 Model (Remove Degradations)
...
@@ -181,7 +185,8 @@ python inference.py \
...
@@ -181,7 +185,8 @@ python inference.py \
--input
[
img_dir_path]
\
--input
[
img_dir_path]
\
--color_fix_type
wavelet
--resize_back
\
--color_fix_type
wavelet
--resize_back
\
--output
[
output_dir_path]
\
--output
[
output_dir_path]
\
--disable_preprocess_model
--disable_preprocess_model
\
--device
cuda
```
```
## <a name="train"></a>:stars:Train
## <a name="train"></a>:stars:Train
...
...
gradio_diffbir.py
View file @
30355a12
...
@@ -10,6 +10,7 @@ import gradio as gr
...
@@ -10,6 +10,7 @@ import gradio as gr
from
PIL
import
Image
from
PIL
import
Image
from
omegaconf
import
OmegaConf
from
omegaconf
import
OmegaConf
from
ldm.xformers_state
import
disable_xformers
from
model.spaced_sampler
import
SpacedSampler
from
model.spaced_sampler
import
SpacedSampler
from
model.cldm
import
ControlLDM
from
model.cldm
import
ControlLDM
from
utils.image
import
(
from
utils.image
import
(
...
@@ -23,10 +24,12 @@ parser.add_argument("--config", required=True, type=str)
...
@@ -23,10 +24,12 @@ parser.add_argument("--config", required=True, type=str)
parser
.
add_argument
(
"--ckpt"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--ckpt"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--reload_swinir"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--reload_swinir"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--swinir_ckpt"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--swinir_ckpt"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"cuda"
,
choices
=
[
"cpu"
,
"cuda"
])
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# load model
# load model
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
if
args
.
device
==
"cpu"
:
disable_xformers
()
model
:
ControlLDM
=
instantiate_from_config
(
OmegaConf
.
load
(
args
.
config
))
model
:
ControlLDM
=
instantiate_from_config
(
OmegaConf
.
load
(
args
.
config
))
load_state_dict
(
model
,
torch
.
load
(
args
.
ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
load_state_dict
(
model
,
torch
.
load
(
args
.
ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
# reload preprocess model if specified
# reload preprocess model if specified
...
@@ -34,7 +37,7 @@ if args.reload_swinir:
...
@@ -34,7 +37,7 @@ if args.reload_swinir:
print
(
f
"reload swinir model from
{
args
.
swinir_ckpt
}
"
)
print
(
f
"reload swinir model from
{
args
.
swinir_ckpt
}
"
)
load_state_dict
(
model
.
preprocess_model
,
torch
.
load
(
args
.
swinir_ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
load_state_dict
(
model
.
preprocess_model
,
torch
.
load
(
args
.
swinir_ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
model
.
freeze
()
model
.
freeze
()
model
.
to
(
device
)
model
.
to
(
args
.
device
)
# load sampler
# load sampler
sampler
=
SpacedSampler
(
model
,
var_type
=
"fixed_small"
)
sampler
=
SpacedSampler
(
model
,
var_type
=
"fixed_small"
)
...
...
inference.py
View file @
30355a12
...
@@ -10,6 +10,7 @@ import pytorch_lightning as pl
...
@@ -10,6 +10,7 @@ import pytorch_lightning as pl
from
PIL
import
Image
from
PIL
import
Image
from
omegaconf
import
OmegaConf
from
omegaconf
import
OmegaConf
from
ldm.xformers_state
import
disable_xformers
from
model.spaced_sampler
import
SpacedSampler
from
model.spaced_sampler
import
SpacedSampler
from
model.ddim_sampler
import
DDIMSampler
from
model.ddim_sampler
import
DDIMSampler
from
model.cldm
import
ControlLDM
from
model.cldm
import
ControlLDM
...
@@ -127,6 +128,7 @@ def parse_args() -> Namespace:
...
@@ -127,6 +128,7 @@ def parse_args() -> Namespace:
parser
.
add_argument
(
"--skip_if_exist"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--skip_if_exist"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
231
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
231
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"cuda"
,
choices
=
[
"cpu"
,
"cuda"
])
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -134,7 +136,9 @@ def parse_args() -> Namespace:
...
@@ -134,7 +136,9 @@ def parse_args() -> Namespace:
def
main
()
->
None
:
def
main
()
->
None
:
args
=
parse_args
()
args
=
parse_args
()
pl
.
seed_everything
(
args
.
seed
)
pl
.
seed_everything
(
args
.
seed
)
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
if
args
.
device
==
"cpu"
:
disable_xformers
()
model
:
ControlLDM
=
instantiate_from_config
(
OmegaConf
.
load
(
args
.
config
))
model
:
ControlLDM
=
instantiate_from_config
(
OmegaConf
.
load
(
args
.
config
))
load_state_dict
(
model
,
torch
.
load
(
args
.
ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
load_state_dict
(
model
,
torch
.
load
(
args
.
ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
...
@@ -145,68 +149,68 @@ def main() -> None:
...
@@ -145,68 +149,68 @@ def main() -> None:
print
(
f
"reload swinir model from
{
args
.
swinir_ckpt
}
"
)
print
(
f
"reload swinir model from
{
args
.
swinir_ckpt
}
"
)
load_state_dict
(
model
.
preprocess_model
,
torch
.
load
(
args
.
swinir_ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
load_state_dict
(
model
.
preprocess_model
,
torch
.
load
(
args
.
swinir_ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
model
.
freeze
()
model
.
freeze
()
model
.
to
(
device
)
model
.
to
(
args
.
device
)
assert
os
.
path
.
isdir
(
args
.
input
)
assert
os
.
path
.
isdir
(
args
.
input
)
print
(
f
"sampling
{
args
.
steps
}
steps using
{
args
.
sampler
}
sampler"
)
print
(
f
"sampling
{
args
.
steps
}
steps using
{
args
.
sampler
}
sampler"
)
with
torch
.
autocast
(
device
):
# with torch.autocast(device, dtype=torch.bfloat16):
for
file_path
in
list_image_files
(
args
.
input
,
follow_links
=
True
):
for
file_path
in
list_image_files
(
args
.
input
,
follow_links
=
True
):
lq
=
Image
.
open
(
file_path
).
convert
(
"RGB"
)
lq
=
Image
.
open
(
file_path
).
convert
(
"RGB"
)
if
args
.
sr_scale
!=
1
:
if
args
.
sr_scale
!=
1
:
lq
=
lq
.
resize
(
lq
=
lq
.
resize
(
tuple
(
math
.
ceil
(
x
*
args
.
sr_scale
)
for
x
in
lq
.
size
),
tuple
(
math
.
ceil
(
x
*
args
.
sr_scale
)
for
x
in
lq
.
size
),
Image
.
BICUBIC
Image
.
BICUBIC
)
)
lq_resized
=
auto_resize
(
lq
,
args
.
image_size
)
lq_resized
=
auto_resize
(
lq
,
args
.
image_size
)
x
=
pad
(
np
.
array
(
lq_resized
),
scale
=
64
)
x
=
pad
(
np
.
array
(
lq_resized
),
scale
=
64
)
for
i
in
range
(
args
.
repeat_times
):
for
i
in
range
(
args
.
repeat_times
):
save_path
=
os
.
path
.
join
(
args
.
output
,
os
.
path
.
relpath
(
file_path
,
args
.
input
))
save_path
=
os
.
path
.
join
(
args
.
output
,
os
.
path
.
relpath
(
file_path
,
args
.
input
))
parent_path
,
stem
,
_
=
get_file_name_parts
(
save_path
)
parent_path
,
stem
,
_
=
get_file_name_parts
(
save_path
)
save_path
=
os
.
path
.
join
(
parent_path
,
f
"
{
stem
}
_
{
i
}
.png"
)
save_path
=
os
.
path
.
join
(
parent_path
,
f
"
{
stem
}
_
{
i
}
.png"
)
if
os
.
path
.
exists
(
save_path
):
if
os
.
path
.
exists
(
save_path
):
if
args
.
skip_if_exist
:
if
args
.
skip_if_exist
:
print
(
f
"skip
{
save_path
}
"
)
print
(
f
"skip
{
save_path
}
"
)
continue
else
:
raise
RuntimeError
(
f
"
{
save_path
}
already exist"
)
os
.
makedirs
(
parent_path
,
exist_ok
=
True
)
try
:
preds
,
stage1_preds
=
process
(
model
,
[
x
],
steps
=
args
.
steps
,
sampler
=
args
.
sampler
,
strength
=
1
,
color_fix_type
=
args
.
color_fix_type
,
disable_preprocess_model
=
args
.
disable_preprocess_model
)
except
RuntimeError
as
e
:
# Avoid cuda_out_of_memory error.
print
(
f
"
{
file_path
}
, error:
{
e
}
"
)
continue
continue
pred
,
stage1_pred
=
preds
[
0
],
stage1_preds
[
0
]
# remove padding
pred
=
pred
[:
lq_resized
.
height
,
:
lq_resized
.
width
,
:]
stage1_pred
=
stage1_pred
[:
lq_resized
.
height
,
:
lq_resized
.
width
,
:]
if
args
.
show_lq
:
if
args
.
resize_back
:
if
lq_resized
.
size
!=
lq
.
size
:
pred
=
np
.
array
(
Image
.
fromarray
(
pred
).
resize
(
lq
.
size
,
Image
.
LANCZOS
))
stage1_pred
=
np
.
array
(
Image
.
fromarray
(
stage1_pred
).
resize
(
lq
.
size
,
Image
.
LANCZOS
))
lq
=
np
.
array
(
lq
)
else
:
lq
=
np
.
array
(
lq_resized
)
images
=
[
lq
,
pred
]
if
args
.
disable_preprocess_model
else
[
lq
,
stage1_pred
,
pred
]
Image
.
fromarray
(
np
.
concatenate
(
images
,
axis
=
1
)).
save
(
save_path
)
else
:
else
:
if
args
.
resize_back
and
lq_resized
.
size
!=
lq
.
size
:
raise
RuntimeError
(
f
"
{
save_path
}
already exist"
)
Image
.
fromarray
(
pred
).
resize
(
lq
.
size
,
Image
.
LANCZOS
).
save
(
save_path
)
os
.
makedirs
(
parent_path
,
exist_ok
=
True
)
else
:
Image
.
fromarray
(
pred
).
save
(
save_path
)
# try:
print
(
f
"save to
{
save_path
}
"
)
preds
,
stage1_preds
=
process
(
model
,
[
x
],
steps
=
args
.
steps
,
sampler
=
args
.
sampler
,
strength
=
1
,
color_fix_type
=
args
.
color_fix_type
,
disable_preprocess_model
=
args
.
disable_preprocess_model
)
# except RuntimeError as e:
# # Avoid cuda_out_of_memory error.
# print(f"{file_path}, error: {e}")
# continue
pred
,
stage1_pred
=
preds
[
0
],
stage1_preds
[
0
]
# remove padding
pred
=
pred
[:
lq_resized
.
height
,
:
lq_resized
.
width
,
:]
stage1_pred
=
stage1_pred
[:
lq_resized
.
height
,
:
lq_resized
.
width
,
:]
if
args
.
show_lq
:
if
args
.
resize_back
:
if
lq_resized
.
size
!=
lq
.
size
:
pred
=
np
.
array
(
Image
.
fromarray
(
pred
).
resize
(
lq
.
size
,
Image
.
LANCZOS
))
stage1_pred
=
np
.
array
(
Image
.
fromarray
(
stage1_pred
).
resize
(
lq
.
size
,
Image
.
LANCZOS
))
lq
=
np
.
array
(
lq
)
else
:
lq
=
np
.
array
(
lq_resized
)
images
=
[
lq
,
pred
]
if
args
.
disable_preprocess_model
else
[
lq
,
stage1_pred
,
pred
]
Image
.
fromarray
(
np
.
concatenate
(
images
,
axis
=
1
)).
save
(
save_path
)
else
:
if
args
.
resize_back
and
lq_resized
.
size
!=
lq
.
size
:
Image
.
fromarray
(
pred
).
resize
(
lq
.
size
,
Image
.
LANCZOS
).
save
(
save_path
)
else
:
Image
.
fromarray
(
pred
).
save
(
save_path
)
print
(
f
"save to
{
save_path
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
()
main
()
inference_face.py
View file @
30355a12
...
@@ -10,6 +10,7 @@ from argparse import ArgumentParser, Namespace
...
@@ -10,6 +10,7 @@ from argparse import ArgumentParser, Namespace
from
facexlib.utils.face_restoration_helper
import
FaceRestoreHelper
from
facexlib.utils.face_restoration_helper
import
FaceRestoreHelper
from
ldm.xformers_state
import
disable_xformers
from
model.cldm
import
ControlLDM
from
model.cldm
import
ControlLDM
from
model.ddim_sampler
import
DDIMSampler
from
model.ddim_sampler
import
DDIMSampler
from
model.spaced_sampler
import
SpacedSampler
from
model.spaced_sampler
import
SpacedSampler
...
@@ -56,6 +57,7 @@ def parse_args() -> Namespace:
...
@@ -56,6 +57,7 @@ def parse_args() -> Namespace:
parser
.
add_argument
(
"--skip_if_exist"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--skip_if_exist"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
231
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
231
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"cuda"
,
choices
=
[
"cpu"
,
"cuda"
])
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -64,7 +66,9 @@ def main() -> None:
...
@@ -64,7 +66,9 @@ def main() -> None:
args
=
parse_args
()
args
=
parse_args
()
img_save_ext
=
'png'
img_save_ext
=
'png'
pl
.
seed_everything
(
args
.
seed
)
pl
.
seed_everything
(
args
.
seed
)
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
if
args
.
device
==
"cpu"
:
disable_xformers
()
model
:
ControlLDM
=
instantiate_from_config
(
OmegaConf
.
load
(
args
.
config
))
model
:
ControlLDM
=
instantiate_from_config
(
OmegaConf
.
load
(
args
.
config
))
load_state_dict
(
model
,
torch
.
load
(
args
.
ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
load_state_dict
(
model
,
torch
.
load
(
args
.
ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
...
@@ -75,13 +79,13 @@ def main() -> None:
...
@@ -75,13 +79,13 @@ def main() -> None:
print
(
f
"reload swinir model from
{
args
.
swinir_ckpt
}
"
)
print
(
f
"reload swinir model from
{
args
.
swinir_ckpt
}
"
)
load_state_dict
(
model
.
preprocess_model
,
torch
.
load
(
args
.
swinir_ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
load_state_dict
(
model
.
preprocess_model
,
torch
.
load
(
args
.
swinir_ckpt
,
map_location
=
"cpu"
),
strict
=
True
)
model
.
freeze
()
model
.
freeze
()
model
.
to
(
device
)
model
.
to
(
args
.
device
)
assert
os
.
path
.
isdir
(
args
.
input
)
assert
os
.
path
.
isdir
(
args
.
input
)
# ------------------ set up FaceRestoreHelper -------------------
# ------------------ set up FaceRestoreHelper -------------------
face_helper
=
FaceRestoreHelper
(
face_helper
=
FaceRestoreHelper
(
device
=
device
,
device
=
args
.
device
,
upscale_factor
=
1
,
upscale_factor
=
1
,
face_size
=
args
.
image_size
,
face_size
=
args
.
image_size
,
use_parse
=
True
,
use_parse
=
True
,
...
@@ -186,4 +190,4 @@ def main() -> None:
...
@@ -186,4 +190,4 @@ def main() -> None:
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
()
main
()
\ No newline at end of file
ldm/modules/attention.py
View file @
30355a12
...
@@ -7,14 +7,14 @@ from einops import rearrange, repeat
...
@@ -7,14 +7,14 @@ from einops import rearrange, repeat
from
typing
import
Optional
,
Any
from
typing
import
Optional
,
Any
from
ldm.modules.diffusionmodules.util
import
checkpoint
from
ldm.modules.diffusionmodules.util
import
checkpoint
from
ldm
import
xformers_state
# try:
try
:
# import xformers
import
xformers
# import xformers.ops
import
xformers.ops
# XFORMERS_IS_AVAILBLE = True
XFORMERS_IS_AVAILBLE
=
True
# except:
except
:
# XFORMERS_IS_AVAILBLE = False
XFORMERS_IS_AVAILBLE
=
False
# CrossAttn precision handling
# CrossAttn precision handling
import
os
import
os
...
@@ -172,7 +172,8 @@ class CrossAttention(nn.Module):
...
@@ -172,7 +172,8 @@ class CrossAttention(nn.Module):
# force cast to fp32 to avoid overflowing
# force cast to fp32 to avoid overflowing
if
_ATTN_PRECISION
==
"fp32"
:
if
_ATTN_PRECISION
==
"fp32"
:
with
torch
.
autocast
(
enabled
=
False
,
device_type
=
'cuda'
):
# with torch.autocast(enabled=False, device_type = 'cuda'):
with
torch
.
autocast
(
enabled
=
False
,
device_type
=
str
(
x
.
device
)):
q
,
k
=
q
.
float
(),
k
.
float
()
q
,
k
=
q
.
float
(),
k
.
float
()
sim
=
einsum
(
'b i d, b j d -> b i j'
,
q
,
k
)
*
self
.
scale
sim
=
einsum
(
'b i d, b j d -> b i j'
,
q
,
k
)
*
self
.
scale
else
:
else
:
...
@@ -230,7 +231,7 @@ class MemoryEfficientCrossAttention(nn.Module):
...
@@ -230,7 +231,7 @@ class MemoryEfficientCrossAttention(nn.Module):
)
)
# actually compute the attention, what we cannot get enough of
# actually compute the attention, what we cannot get enough of
out
=
xformers
.
ops
.
memory_efficient_attention
(
q
,
k
,
v
,
attn_bias
=
None
,
op
=
self
.
attention_op
)
out
=
xformers_state
.
xformers
.
ops
.
memory_efficient_attention
(
q
,
k
,
v
,
attn_bias
=
None
,
op
=
self
.
attention_op
)
if
exists
(
mask
):
if
exists
(
mask
):
raise
NotImplementedError
raise
NotImplementedError
...
@@ -251,7 +252,8 @@ class BasicTransformerBlock(nn.Module):
...
@@ -251,7 +252,8 @@ class BasicTransformerBlock(nn.Module):
def
__init__
(
self
,
dim
,
n_heads
,
d_head
,
dropout
=
0.
,
context_dim
=
None
,
gated_ff
=
True
,
checkpoint
=
True
,
def
__init__
(
self
,
dim
,
n_heads
,
d_head
,
dropout
=
0.
,
context_dim
=
None
,
gated_ff
=
True
,
checkpoint
=
True
,
disable_self_attn
=
False
):
disable_self_attn
=
False
):
super
().
__init__
()
super
().
__init__
()
attn_mode
=
"softmax-xformers"
if
XFORMERS_IS_AVAILBLE
else
"softmax"
# attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
attn_mode
=
"softmax-xformers"
if
xformers_state
.
is_xformers_available
()
else
"softmax"
assert
attn_mode
in
self
.
ATTENTION_MODES
assert
attn_mode
in
self
.
ATTENTION_MODES
attn_cls
=
self
.
ATTENTION_MODES
[
attn_mode
]
attn_cls
=
self
.
ATTENTION_MODES
[
attn_mode
]
self
.
disable_self_attn
=
disable_self_attn
self
.
disable_self_attn
=
disable_self_attn
...
...
ldm/modules/diffusionmodules/model.py
View file @
30355a12
...
@@ -7,14 +7,16 @@ from einops import rearrange
...
@@ -7,14 +7,16 @@ from einops import rearrange
from
typing
import
Optional
,
Any
from
typing
import
Optional
,
Any
from
ldm.modules.attention
import
MemoryEfficientCrossAttention
from
ldm.modules.attention
import
MemoryEfficientCrossAttention
from
ldm
import
xformers_state
try
:
import
xformers
# try:
import
xformers.ops
# import xformers
XFORMERS_IS_AVAILBLE
=
True
# import xformers.ops
except
:
# XFORMERS_IS_AVAILBLE = True
XFORMERS_IS_AVAILBLE
=
False
# except:
print
(
"No module 'xformers'. Proceeding without it."
)
# XFORMERS_IS_AVAILBLE = False
# print("No module 'xformers'. Proceeding without it.")
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
...
@@ -255,7 +257,7 @@ class MemoryEfficientAttnBlock(nn.Module):
...
@@ -255,7 +257,7 @@ class MemoryEfficientAttnBlock(nn.Module):
.
contiguous
(),
.
contiguous
(),
(
q
,
k
,
v
),
(
q
,
k
,
v
),
)
)
out
=
xformers
.
ops
.
memory_efficient_attention
(
q
,
k
,
v
,
attn_bias
=
None
,
op
=
self
.
attention_op
)
out
=
xformers_state
.
xformers
.
ops
.
memory_efficient_attention
(
q
,
k
,
v
,
attn_bias
=
None
,
op
=
self
.
attention_op
)
out
=
(
out
=
(
out
.
unsqueeze
(
0
)
out
.
unsqueeze
(
0
)
...
@@ -279,7 +281,8 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
...
@@ -279,7 +281,8 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def
make_attn
(
in_channels
,
attn_type
=
"vanilla"
,
attn_kwargs
=
None
):
def
make_attn
(
in_channels
,
attn_type
=
"vanilla"
,
attn_kwargs
=
None
):
assert
attn_type
in
[
"vanilla"
,
"vanilla-xformers"
,
"memory-efficient-cross-attn"
,
"linear"
,
"none"
],
f
'attn_type
{
attn_type
}
unknown'
assert
attn_type
in
[
"vanilla"
,
"vanilla-xformers"
,
"memory-efficient-cross-attn"
,
"linear"
,
"none"
],
f
'attn_type
{
attn_type
}
unknown'
if
XFORMERS_IS_AVAILBLE
and
attn_type
==
"vanilla"
:
# if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
if
xformers_state
.
is_xformers_available
()
and
attn_type
==
"vanilla"
:
attn_type
=
"vanilla-xformers"
attn_type
=
"vanilla-xformers"
print
(
f
"making attention of type '
{
attn_type
}
' with
{
in_channels
}
in_channels"
)
print
(
f
"making attention of type '
{
attn_type
}
' with
{
in_channels
}
in_channels"
)
if
attn_type
==
"vanilla"
:
if
attn_type
==
"vanilla"
:
...
...
ldm/modules/encoders/modules.py
View file @
30355a12
...
@@ -140,7 +140,8 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
...
@@ -140,7 +140,8 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
"last"
,
"last"
,
"penultimate"
"penultimate"
]
]
def
__init__
(
self
,
arch
=
"ViT-H-14"
,
version
=
"laion2b_s32b_b79k"
,
device
=
"cuda"
,
max_length
=
77
,
# def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
def
__init__
(
self
,
arch
=
"ViT-H-14"
,
version
=
"laion2b_s32b_b79k"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"last"
):
freeze
=
True
,
layer
=
"last"
):
super
().
__init__
()
super
().
__init__
()
assert
layer
in
self
.
LAYERS
assert
layer
in
self
.
LAYERS
...
@@ -148,7 +149,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
...
@@ -148,7 +149,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
del
model
.
visual
del
model
.
visual
self
.
model
=
model
self
.
model
=
model
self
.
device
=
device
#
self.device = device
self
.
max_length
=
max_length
self
.
max_length
=
max_length
if
freeze
:
if
freeze
:
self
.
freeze
()
self
.
freeze
()
...
@@ -167,7 +168,8 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
...
@@ -167,7 +168,8 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
def
forward
(
self
,
text
):
def
forward
(
self
,
text
):
tokens
=
open_clip
.
tokenize
(
text
)
tokens
=
open_clip
.
tokenize
(
text
)
z
=
self
.
encode_with_transformer
(
tokens
.
to
(
self
.
device
))
# z = self.encode_with_transformer(tokens.to(self.device))
z
=
self
.
encode_with_transformer
(
tokens
.
to
(
next
(
self
.
model
.
parameters
()).
device
))
return
z
return
z
def
encode_with_transformer
(
self
,
text
):
def
encode_with_transformer
(
self
,
text
):
...
...
ldm/xformers_state.py
0 → 100644
View file @
30355a12
# Shared, process-wide switch recording whether xformers attention may be used.
#
# The flag starts out reflecting whether ``xformers`` could be imported and can
# be forced off at runtime with ``disable_xformers()``. Other modules are
# expected to query it through ``is_xformers_available()`` rather than keeping
# their own copy, so that a later ``disable_xformers()`` call is seen by all of
# them.
#
# NOTE: the misspelling ``AVAILBLE`` is kept on purpose; the name is part of
# this module's public surface and renaming it could break importers.
try:
    # Both imports are kept so that ``xformers`` and ``xformers.ops`` remain
    # reachable as attributes of this module when the package is installed.
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Best-effort optional dependency: any ordinary import failure means
    # "run without xformers". ``Exception`` (instead of a bare ``except``) is
    # used so KeyboardInterrupt / SystemExit are not silently swallowed here.
    XFORMERS_IS_AVAILBLE = False
    print("No module 'xformers'. Proceeding without it.")


def is_xformers_available() -> bool:
    """Return the current value of the module-level xformers flag.

    Returns:
        ``True`` if xformers was imported successfully and has not been
        disabled via ``disable_xformers()``; ``False`` otherwise.
    """
    # Reading a module global needs no ``global`` declaration.
    return XFORMERS_IS_AVAILBLE


def disable_xformers() -> None:
    """Force the xformers flag to ``False`` for the rest of the process.

    Prints a notice and rebinds the module-level flag, so subsequent calls to
    ``is_xformers_available()`` return ``False``. Calling it more than once is
    harmless. There is no corresponding "enable" function in this module.
    """
    global XFORMERS_IS_AVAILBLE
    print("DISABLE XFORMERS!")
    XFORMERS_IS_AVAILBLE = False
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment