ModelZoo / DiffBIR_pytorch / Commits

Commit 30355a12, authored Sep 11, 2023 by 0x3f3f3f3fun (parent 7a5f8d70)

    add support for cpu
Showing 9 changed files with 131 additions and 91 deletions (+131 −91).
.gitignore                              +1  −1
README.md                               +10 −5
gradio_diffbir.py                       +5  −2
inference.py                            +61 −57
inference_face.py                       +8  −4
ldm/modules/attention.py                +12 −10
ldm/modules/diffusionmodules/model.py   +12 −9
ldm/modules/encoders/modules.py         +5  −3
ldm/xformers_state.py                   +17 −0
.gitignore

@@ -7,4 +7,4 @@ __pycache__
 !install_env.sh
 /weights
 /temp
-results/
+/results
README.md

@@ -95,7 +95,8 @@ python gradio_diffbir.py \
 --ckpt weights/general_full_v1.ckpt \
 --config configs/model/cldm.yaml \
 --reload_swinir \
---swinir_ckpt weights/general_swinir_v1.ckpt
+--swinir_ckpt weights/general_swinir_v1.ckpt \
+--device cuda
 ```
 <div align="center">

@@ -120,7 +121,8 @@ python inference.py \
 --sr_scale 4 \
 --image_size 512 \
 --color_fix_type wavelet --resize_back \
---output results/general
+--output results/general \
+--device cuda
 ```
 If you are confused about where the `reload_swinir` option came from, please refer to the [degradation details](#degradation-details).

@@ -139,7 +141,8 @@ python inference_face.py \
 --image_size 512 \
 --color_fix_type wavelet \
 --output results/face/aligned --resize_back \
---has_aligned
+--has_aligned \
+--device cuda
 # for unaligned face inputs
 python inference_face.py \

@@ -150,7 +153,8 @@ python inference_face.py \
 --sr_scale 1 \
 --image_size 512 \
 --color_fix_type wavelet \
---output results/face/whole_img --resize_back
+--output results/face/whole_img --resize_back \
+--device cuda
 ```
 ### Only Stage1 Model (Remove Degradations)

@@ -181,7 +185,8 @@ python inference.py \
 --input [img_dir_path] \
 --color_fix_type wavelet --resize_back \
 --output [output_dir_path] \
---disable_preprocess_model
+--disable_preprocess_model \
+--device cuda
 ```
 ## <a name="train"></a>:stars:Train
gradio_diffbir.py

@@ -10,6 +10,7 @@ import gradio as gr
 from PIL import Image
 from omegaconf import OmegaConf
+from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.cldm import ControlLDM
 from utils.image import (

@@ -23,10 +24,12 @@ parser.add_argument("--config", required=True, type=str)
 parser.add_argument("--ckpt", type=str, required=True)
 parser.add_argument("--reload_swinir", action="store_true")
 parser.add_argument("--swinir_ckpt", type=str, default="")
+parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
 args = parser.parse_args()
 # load model
-device = "cuda" if torch.cuda.is_available() else "cpu"
+if args.device == "cpu":
+    disable_xformers()
 model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
 load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
 # reload preprocess model if specified

@@ -34,7 +37,7 @@ if args.reload_swinir:
     print(f"reload swinir model from {args.swinir_ckpt}")
     load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
 model.freeze()
-model.to(device)
+model.to(args.device)
 # load sampler
 sampler = SpacedSampler(model, var_type="fixed_small")
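The same pattern recurs in every entry script touched by this commit: an explicit `--device` flag replaces the implicit `torch.cuda.is_available()` check, xformers is disabled up front when the flag is `cpu`, and the model is moved with `model.to(args.device)`. A minimal, self-contained sketch of that pattern, assuming it runs inside the repository so `ldm.xformers_state` is importable; the `torch.nn.Linear` stands in for the real `ControlLDM`:

```python
from argparse import ArgumentParser

import torch

from ldm.xformers_state import disable_xformers  # helper module added in this commit

parser = ArgumentParser()
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
args = parser.parse_args()

# xformers has no CPU kernels, so the shared flag is switched off before any
# attention module is constructed.
if args.device == "cpu":
    disable_xformers()

model = torch.nn.Linear(4, 4)  # stand-in for the ControlLDM loaded in the real script
model.to(args.device)
print(f"model device: {next(model.parameters()).device}")
```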
inference.py

@@ -10,6 +10,7 @@ import pytorch_lightning as pl
 from PIL import Image
 from omegaconf import OmegaConf
+from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.ddim_sampler import DDIMSampler
 from model.cldm import ControlLDM

@@ -127,6 +128,7 @@ def parse_args() -> Namespace:
     parser.add_argument("--skip_if_exist", action="store_true")
     parser.add_argument("--seed", type=int, default=231)
+    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
     return parser.parse_args()

@@ -134,7 +136,9 @@ def parse_args() -> Namespace:
 def main() -> None:
     args = parse_args()
     pl.seed_everything(args.seed)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if args.device == "cpu":
+        disable_xformers()
     model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
     load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)

@@ -145,12 +149,12 @@ def main() -> None:
         print(f"reload swinir model from {args.swinir_ckpt}")
         load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
     model.freeze()
-    model.to(device)
+    model.to(args.device)
     assert os.path.isdir(args.input)
     print(f"sampling {args.steps} steps using {args.sampler} sampler")
-    with torch.autocast(device):
+    # with torch.autocast(device, dtype=torch.bfloat16):
     for file_path in list_image_files(args.input, follow_links=True):
         lq = Image.open(file_path).convert("RGB")
         if args.sr_scale != 1:

@@ -173,17 +177,17 @@ def main() -> None:
             raise RuntimeError(f"{save_path} already exist")
         os.makedirs(parent_path, exist_ok=True)
-        try:
+        # try:
         preds, stage1_preds = process(
             model, [x], steps=args.steps,
             sampler=args.sampler, strength=1,
             color_fix_type=args.color_fix_type,
             disable_preprocess_model=args.disable_preprocess_model
         )
-        except RuntimeError as e:
-            # Avoid cuda_out_of_memory error.
-            print(f"{file_path}, error: {e}")
-            continue
+        # except RuntimeError as e:
+        #     # Avoid cuda_out_of_memory error.
+        #     print(f"{file_path}, error: {e}")
+        #     continue
         pred, stage1_pred = preds[0], stage1_preds[0]
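The commit handles the autocast context by commenting it out rather than making it device-dependent. For readers who want to keep mixed precision on GPU while still supporting CPU, a device-aware wrapper along these lines would work; this is an alternative sketch, not part of the commit, and `maybe_autocast` is a hypothetical helper name:

```python
import contextlib

import torch


def maybe_autocast(device: str):
    """Return an autocast context on CUDA, a no-op context on CPU."""
    if device == "cuda" and torch.cuda.is_available():
        return torch.autocast(device_type="cuda", dtype=torch.float16)
    return contextlib.nullcontext()

# Hypothetical usage inside the per-image loop of inference.py:
#     with maybe_autocast(args.device):
#         preds, stage1_preds = process(model, [x], ...)
```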
inference_face.py

@@ -10,6 +10,7 @@ from argparse import ArgumentParser, Namespace
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+from ldm.xformers_state import disable_xformers
 from model.cldm import ControlLDM
 from model.ddim_sampler import DDIMSampler
 from model.spaced_sampler import SpacedSampler

@@ -56,6 +57,7 @@ def parse_args() -> Namespace:
     parser.add_argument("--skip_if_exist", action="store_true")
     parser.add_argument("--seed", type=int, default=231)
+    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
     return parser.parse_args()

@@ -64,7 +66,9 @@ def main() -> None:
     args = parse_args()
     img_save_ext = 'png'
     pl.seed_everything(args.seed)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if args.device == "cpu":
+        disable_xformers()
     model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
     load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)

@@ -75,13 +79,13 @@ def main() -> None:
         print(f"reload swinir model from {args.swinir_ckpt}")
         load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
     model.freeze()
-    model.to(device)
+    model.to(args.device)
     assert os.path.isdir(args.input)
     # ------------------ set up FaceRestoreHelper -------------------
     face_helper = FaceRestoreHelper(
-        device=device,
+        device=args.device,
         upscale_factor=1,
         face_size=args.image_size,
         use_parse=True,
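For the face pipeline the only extra consideration is that `FaceRestoreHelper` also receives the device explicitly. A minimal CPU construction using just the keyword arguments visible in the diff above; the values are illustrative (the real script passes `args.image_size`), and the remaining constructor defaults are assumed:

```python
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

face_helper = FaceRestoreHelper(
    device="cpu",        # previously derived from torch.cuda.is_available()
    upscale_factor=1,
    face_size=512,       # args.image_size in inference_face.py
    use_parse=True,
)
```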
ldm/modules/attention.py

@@ -7,14 +7,14 @@ from einops import rearrange, repeat
 from typing import Optional, Any
 from ldm.modules.diffusionmodules.util import checkpoint
+from ldm import xformers_state
-try:
-    import xformers
-    import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
+# try:
+# import xformers
+# import xformers.ops
+# XFORMERS_IS_AVAILBLE = True
+# except:
+# XFORMERS_IS_AVAILBLE = False
 # CrossAttn precision handling
 import os

@@ -172,7 +172,8 @@ class CrossAttention(nn.Module):
         # force cast to fp32 to avoid overflowing
         if _ATTN_PRECISION == "fp32":
-            with torch.autocast(enabled=False, device_type = 'cuda'):
+            # with torch.autocast(enabled=False, device_type = 'cuda'):
+            with torch.autocast(enabled=False, device_type=str(x.device)):
                 q, k = q.float(), k.float()
                 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
         else:

@@ -230,7 +231,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         )
         # actually compute the attention, what we cannot get enough of
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+        out = xformers_state.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
         if exists(mask):
             raise NotImplementedError

@@ -251,7 +252,8 @@ class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                  disable_self_attn=False):
         super().__init__()
-        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        attn_mode = "softmax-xformers" if xformers_state.is_xformers_available() else "softmax"
         assert attn_mode in self.ATTENTION_MODES
         attn_cls = self.ATTENTION_MODES[attn_mode]
         self.disable_self_attn = disable_self_attn
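One detail worth noting about the autocast change: `str(x.device)` can include a device index (for example `"cuda:0"`), which `torch.autocast` may reject, whereas `x.device.type` always yields the bare `"cuda"` or `"cpu"`. A small sketch of that variant, offered as an observation rather than something this commit does:

```python
import torch

x = torch.randn(2, 3)  # plays the role of the attention input in CrossAttention.forward

# x.device.type is "cpu" or "cuda" with no index, matching what
# torch.autocast(device_type=...) expects; str(x.device) may carry ":0".
with torch.autocast(enabled=False, device_type=x.device.type):
    q, k = x.float(), x.float()
    print(q.dtype, k.dtype)
```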
ldm/modules/diffusionmodules/model.py

@@ -7,14 +7,16 @@ from einops import rearrange
 from typing import Optional, Any
 from ldm.modules.attention import MemoryEfficientCrossAttention
+from ldm import xformers_state
-try:
-    import xformers
-    import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
-    print("No module 'xformers'. Proceeding without it.")
+# try:
+# import xformers
+# import xformers.ops
+# XFORMERS_IS_AVAILBLE = True
+# except:
+# XFORMERS_IS_AVAILBLE = False
+# print("No module 'xformers'. Proceeding without it.")
 def get_timestep_embedding(timesteps, embedding_dim):

@@ -255,7 +257,7 @@ class MemoryEfficientAttnBlock(nn.Module):
             .contiguous(),
             (q, k, v),
         )
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+        out = xformers_state.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
         out = (
             out.unsqueeze(0)

@@ -279,7 +281,8 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+    # if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+    if xformers_state.is_xformers_available() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
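Both attention files now reach xformers through the state module, e.g. `xformers_state.xformers.ops.memory_efficient_attention(...)`. That resolves because `import xformers` inside `ldm/xformers_state.py` binds the package as an attribute of that module, and the call sites are only reached when xformers-backed attention was selected in the first place. A short sketch of the lookup, assuming the repository root is on `sys.path`:

```python
from ldm import xformers_state

if xformers_state.is_xformers_available():
    # `xformers` was imported inside ldm/xformers_state.py, so it is reachable
    # as an attribute of that module.
    attn_fn = xformers_state.xformers.ops.memory_efficient_attention
    print(attn_fn.__name__)
else:
    print("xformers unavailable or disabled; vanilla attention will be used")
```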
ldm/modules/encoders/modules.py

@@ -140,7 +140,8 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         "last",
         "penultimate"
     ]
-    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+    # def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", max_length=77,
                  freeze=True, layer="last"):
         super().__init__()
         assert layer in self.LAYERS

@@ -148,7 +149,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         del model.visual
         self.model = model
-        self.device = device
+        # self.device = device
         self.max_length = max_length
         if freeze:
             self.freeze()

@@ -167,7 +168,8 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
     def forward(self, text):
         tokens = open_clip.tokenize(text)
-        z = self.encode_with_transformer(tokens.to(self.device))
+        # z = self.encode_with_transformer(tokens.to(self.device))
+        z = self.encode_with_transformer(tokens.to(next(self.model.parameters()).device))
         return z
     def encode_with_transformer(self, text):
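The `FrozenOpenCLIPEmbedder` change drops the stored `device` argument and instead follows the module's own parameters at call time, so a single `model.to(args.device)` on the outer model is enough. A minimal sketch of that idiom using a toy module (not the repo's class):

```python
import torch


class TinyEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Same idiom as the diff: take the device from the parameters rather
        # than from a self.device attribute fixed at construction time.
        device = next(self.parameters()).device
        return self.proj(tokens.to(device))


enc = TinyEncoder()  # a later enc.to("cuda") would be followed automatically
print(enc(torch.randn(2, 8)).shape)
```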
ldm/xformers_state.py (new file, 0 → 100644)

try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except:
    XFORMERS_IS_AVAILBLE = False
    print("No module 'xformers'. Proceeding without it.")

def is_xformers_available() -> bool:
    global XFORMERS_IS_AVAILBLE
    return XFORMERS_IS_AVAILBLE

def disable_xformers() -> None:
    print("DISABLE XFORMERS!")
    global XFORMERS_IS_AVAILBLE
    XFORMERS_IS_AVAILBLE = False
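The reason the attention modules import the module itself (`from ldm import xformers_state`) rather than copying `XFORMERS_IS_AVAILBLE` is late binding: a `from`-import snapshots the value at import time, while attribute access through the module sees later calls to `disable_xformers()`. A small demonstration, assuming the repository root is importable:

```python
from ldm import xformers_state
from ldm.xformers_state import XFORMERS_IS_AVAILBLE  # snapshot taken at import time

xformers_state.disable_xformers()

print(XFORMERS_IS_AVAILBLE)                    # stale copy; unchanged by the call above
print(xformers_state.is_xformers_available())  # False: reads the live module attribute
```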