Commit 30355a12 authored by 0x3f3f3f3fun

add support for cpu

parent 7a5f8d70
.gitignore
@@ -7,4 +7,4 @@ __pycache__
 !install_env.sh
 /weights
 /temp
-results/
+/results
README.md
@@ -95,7 +95,8 @@ python gradio_diffbir.py \
 --ckpt weights/general_full_v1.ckpt \
 --config configs/model/cldm.yaml \
 --reload_swinir \
---swinir_ckpt weights/general_swinir_v1.ckpt
+--swinir_ckpt weights/general_swinir_v1.ckpt \
+--device cuda
 ```
 
 <div align="center">
@@ -120,7 +121,8 @@ python inference.py \
 --sr_scale 4 \
 --image_size 512 \
 --color_fix_type wavelet --resize_back \
---output results/general
+--output results/general \
+--device cuda
 ```
 
 If you are confused about where the `reload_swinir` option came from, please refer to the [degradation details](#degradation-details).
@@ -139,7 +141,8 @@ python inference_face.py \
 --image_size 512 \
 --color_fix_type wavelet \
 --output results/face/aligned --resize_back \
---has_aligned
+--has_aligned \
+--device cuda
 
 # for unaligned face inputs
 python inference_face.py \
@@ -150,7 +153,8 @@ python inference_face.py \
 --sr_scale 1 \
 --image_size 512 \
 --color_fix_type wavelet \
---output results/face/whole_img --resize_back
+--output results/face/whole_img --resize_back \
+--device cuda
 ```
 
 ### Only Stage1 Model (Remove Degradations)
@@ -181,7 +185,8 @@ python inference.py \
 --input [img_dir_path] \
 --color_fix_type wavelet --resize_back \
 --output [output_dir_path] \
---disable_preprocess_model
+--disable_preprocess_model \
+--device cuda
 ```
 
 ## <a name="train"></a>:stars:Train
...
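Since the new `--device` flag accepts `cpu` as well as `cuda`, the commands above should also run on a GPU-less machine by swapping the final flag. A sketch reusing only the flags already shown above (the bracketed path is a placeholder, as in the README); expect this to be much slower than running on CUDA:

```
python inference.py \
--input [img_dir_path] \
--config configs/model/cldm.yaml \
--ckpt weights/general_full_v1.ckpt \
--reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt \
--sr_scale 4 \
--image_size 512 \
--color_fix_type wavelet --resize_back \
--output results/general \
--device cpu
```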
gradio_diffbir.py
@@ -10,6 +10,7 @@ import gradio as gr
 from PIL import Image
 from omegaconf import OmegaConf
+from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.cldm import ControlLDM
 from utils.image import (
@@ -23,10 +24,12 @@ parser.add_argument("--config", required=True, type=str)
 parser.add_argument("--ckpt", type=str, required=True)
 parser.add_argument("--reload_swinir", action="store_true")
 parser.add_argument("--swinir_ckpt", type=str, default="")
+parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
 args = parser.parse_args()
 
 # load model
-device = "cuda" if torch.cuda.is_available() else "cpu"
+if args.device == "cpu":
+    disable_xformers()
 model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
 load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
 # reload preprocess model if specified
@@ -34,7 +37,7 @@ if args.reload_swinir:
     print(f"reload swinir model from {args.swinir_ckpt}")
     load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
 model.freeze()
-model.to(device)
+model.to(args.device)
 
 # load sampler
 sampler = SpacedSampler(model, var_type="fixed_small")
...
inference.py
@@ -10,6 +10,7 @@ import pytorch_lightning as pl
 from PIL import Image
 from omegaconf import OmegaConf
+from ldm.xformers_state import disable_xformers
 from model.spaced_sampler import SpacedSampler
 from model.ddim_sampler import DDIMSampler
 from model.cldm import ControlLDM
@@ -127,6 +128,7 @@ def parse_args() -> Namespace:
     parser.add_argument("--skip_if_exist", action="store_true")
     parser.add_argument("--seed", type=int, default=231)
+    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
     return parser.parse_args()
@@ -134,7 +136,9 @@ def parse_args() -> Namespace:
 def main() -> None:
     args = parse_args()
     pl.seed_everything(args.seed)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if args.device == "cpu":
+        disable_xformers()
     model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
     load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
@@ -145,68 +149,68 @@ def main() -> None:
         print(f"reload swinir model from {args.swinir_ckpt}")
         load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
     model.freeze()
-    model.to(device)
+    model.to(args.device)
 
     assert os.path.isdir(args.input)
 
     print(f"sampling {args.steps} steps using {args.sampler} sampler")
-    with torch.autocast(device):
-        for file_path in list_image_files(args.input, follow_links=True):
+    # with torch.autocast(device, dtype=torch.bfloat16):
+    for file_path in list_image_files(args.input, follow_links=True):
         lq = Image.open(file_path).convert("RGB")
         if args.sr_scale != 1:
             lq = lq.resize(
                 tuple(math.ceil(x * args.sr_scale) for x in lq.size),
                 Image.BICUBIC
             )
         lq_resized = auto_resize(lq, args.image_size)
         x = pad(np.array(lq_resized), scale=64)
 
         for i in range(args.repeat_times):
             save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
             parent_path, stem, _ = get_file_name_parts(save_path)
             save_path = os.path.join(parent_path, f"{stem}_{i}.png")
             if os.path.exists(save_path):
                 if args.skip_if_exist:
                     print(f"skip {save_path}")
                     continue
                 else:
                     raise RuntimeError(f"{save_path} already exist")
             os.makedirs(parent_path, exist_ok=True)
 
-            try:
-                preds, stage1_preds = process(
-                    model, [x], steps=args.steps, sampler=args.sampler,
-                    strength=1,
-                    color_fix_type=args.color_fix_type,
-                    disable_preprocess_model=args.disable_preprocess_model
-                )
-            except RuntimeError as e:
-                # Avoid cuda_out_of_memory error.
-                print(f"{file_path}, error: {e}")
-                continue
+            # try:
+            preds, stage1_preds = process(
+                model, [x], steps=args.steps, sampler=args.sampler,
+                strength=1,
+                color_fix_type=args.color_fix_type,
+                disable_preprocess_model=args.disable_preprocess_model
+            )
+            # except RuntimeError as e:
+            #     # Avoid cuda_out_of_memory error.
+            #     print(f"{file_path}, error: {e}")
+            #     continue
             pred, stage1_pred = preds[0], stage1_preds[0]
 
             # remove padding
             pred = pred[:lq_resized.height, :lq_resized.width, :]
             stage1_pred = stage1_pred[:lq_resized.height, :lq_resized.width, :]
 
             if args.show_lq:
                 if args.resize_back:
                     if lq_resized.size != lq.size:
                         pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
                         stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
                     lq = np.array(lq)
                 else:
                     lq = np.array(lq_resized)
                 images = [lq, pred] if args.disable_preprocess_model else [lq, stage1_pred, pred]
                 Image.fromarray(np.concatenate(images, axis=1)).save(save_path)
             else:
                 if args.resize_back and lq_resized.size != lq.size:
                     Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
                 else:
                     Image.fromarray(pred).save(save_path)
             print(f"save to {save_path}")
 
 
 if __name__ == "__main__":
     main()
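All three entry scripts (gradio_diffbir.py, inference.py, inference_face.py) now share the same device-handling pattern. A condensed, runnable sketch of that pattern, with a toy `nn.Linear` standing in for the real ControlLDM instantiation:

```python
from argparse import ArgumentParser

from torch import nn

from ldm.xformers_state import disable_xformers

parser = ArgumentParser()
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
args = parser.parse_args()

# xformers attention kernels are CUDA-only, so on CPU the shared flag is
# cleared *before* the model (and its attention blocks) is instantiated.
if args.device == "cpu":
    disable_xformers()

model = nn.Linear(4, 4)  # stand-in for the real ControlLDM instantiation
model.to(args.device)    # explicit flag replaces torch.cuda.is_available()
```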
inference_face.py
@@ -10,6 +10,7 @@ from argparse import ArgumentParser, Namespace
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+from ldm.xformers_state import disable_xformers
 from model.cldm import ControlLDM
 from model.ddim_sampler import DDIMSampler
 from model.spaced_sampler import SpacedSampler
@@ -56,6 +57,7 @@ def parse_args() -> Namespace:
     parser.add_argument("--skip_if_exist", action="store_true")
     parser.add_argument("--seed", type=int, default=231)
+    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
     return parser.parse_args()
@@ -64,7 +66,9 @@ def main() -> None:
     args = parse_args()
     img_save_ext = 'png'
     pl.seed_everything(args.seed)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if args.device == "cpu":
+        disable_xformers()
     model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
     load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
@@ -75,13 +79,13 @@ def main() -> None:
         print(f"reload swinir model from {args.swinir_ckpt}")
         load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
     model.freeze()
-    model.to(device)
+    model.to(args.device)
 
     assert os.path.isdir(args.input)
 
     # ------------------ set up FaceRestoreHelper -------------------
     face_helper = FaceRestoreHelper(
-        device=device,
+        device=args.device,
         upscale_factor=1,
         face_size=args.image_size,
         use_parse=True,
@@ -186,4 +190,4 @@ def main() -> None:
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
ldm/modules/attention.py
@@ -7,14 +7,14 @@ from einops import rearrange, repeat
 from typing import Optional, Any
 
 from ldm.modules.diffusionmodules.util import checkpoint
+from ldm import xformers_state
 
-try:
-    import xformers
-    import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
+# try:
+#     import xformers
+#     import xformers.ops
+#     XFORMERS_IS_AVAILBLE = True
+# except:
+#     XFORMERS_IS_AVAILBLE = False
 
 # CrossAttn precision handling
 import os
@@ -172,7 +172,8 @@ class CrossAttention(nn.Module):
         # force cast to fp32 to avoid overflowing
         if _ATTN_PRECISION =="fp32":
-            with torch.autocast(enabled=False, device_type = 'cuda'):
+            # with torch.autocast(enabled=False, device_type = 'cuda'):
+            with torch.autocast(enabled=False, device_type=str(x.device)):
                 q, k = q.float(), k.float()
                 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
         else:
@@ -230,7 +231,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         )
 
         # actually compute the attention, what we cannot get enough of
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+        out = xformers_state.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
 
         if exists(mask):
             raise NotImplementedError
@@ -251,7 +252,8 @@ class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                  disable_self_attn=False):
         super().__init__()
-        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        attn_mode = "softmax-xformers" if xformers_state.is_xformers_available() else "softmax"
         assert attn_mode in self.ATTENTION_MODES
         attn_cls = self.ATTENTION_MODES[attn_mode]
         self.disable_self_attn = disable_self_attn
...
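A condensed illustration of the construction-time check that `BasicTransformerBlock` now performs; this toy block is hypothetical, but the `attn_mode` line mirrors the one in the diff above:

```python
import torch
from torch import nn

from ldm import xformers_state


class ToyAttnBlock(nn.Module):
    """Minimal stand-in for BasicTransformerBlock's backend selection."""

    def __init__(self, dim: int):
        super().__init__()
        # The real block reads the shared flag when it is *constructed*, so
        # calling disable_xformers() before model instantiation is enough
        # to force the plain softmax path everywhere.
        self.mode = "softmax-xformers" if xformers_state.is_xformers_available() else "softmax"
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The real classes dispatch to different attention implementations
        # here; the toy version only projects.
        return self.proj(x)


xformers_state.disable_xformers()
print(ToyAttnBlock(8).mode)  # -> "softmax"
```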
ldm/modules/diffusionmodules/model.py
@@ -7,14 +7,16 @@ from einops import rearrange
 from typing import Optional, Any
 
 from ldm.modules.attention import MemoryEfficientCrossAttention
+from ldm import xformers_state
 
-try:
-    import xformers
-    import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
-    print("No module 'xformers'. Proceeding without it.")
+# try:
+#     import xformers
+#     import xformers.ops
+#     XFORMERS_IS_AVAILBLE = True
+# except:
+#     XFORMERS_IS_AVAILBLE = False
+#     print("No module 'xformers'. Proceeding without it.")
 
 def get_timestep_embedding(timesteps, embedding_dim):
@@ -255,7 +257,7 @@ class MemoryEfficientAttnBlock(nn.Module):
             .contiguous(),
             (q, k, v),
         )
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+        out = xformers_state.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
 
         out = (
             out.unsqueeze(0)
@@ -279,7 +281,8 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+    # if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+    if xformers_state.is_xformers_available() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
...
ldm/modules/encoders/modules.py
@@ -140,7 +140,8 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         "last",
         "penultimate"
     ]
-    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+    # def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", max_length=77,
                  freeze=True, layer="last"):
         super().__init__()
         assert layer in self.LAYERS
@@ -148,7 +149,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         del model.visual
         self.model = model
 
-        self.device = device
+        # self.device = device
         self.max_length = max_length
         if freeze:
             self.freeze()
@@ -167,7 +168,8 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
     def forward(self, text):
         tokens = open_clip.tokenize(text)
-        z = self.encode_with_transformer(tokens.to(self.device))
+        # z = self.encode_with_transformer(tokens.to(self.device))
+        z = self.encode_with_transformer(tokens.to(next(self.model.parameters()).device))
         return z
 
     def encode_with_transformer(self, text):
...
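The encoder change swaps a stored device string for the device of the module's own weights, so tokens follow wherever `model.to(...)` last moved the model. The lookup is a standard PyTorch idiom; a minimal sketch with a stand-in linear layer:

```python
import torch
from torch import nn

encoder = nn.Linear(16, 16)  # stand-in for the frozen OpenCLIP transformer

# next(module.parameters()).device reports where the weights currently live,
# so inputs can be moved to match without tracking a separate device field.
tokens = torch.zeros(1, 16)
tokens = tokens.to(next(encoder.parameters()).device)
print(tokens.device)  # cpu by default; cuda:0 after encoder.to("cuda")
```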
ldm/xformers_state.py (new file)
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except:
+    XFORMERS_IS_AVAILBLE = False
+    print("No module 'xformers'. Proceeding without it.")
+
+
+def is_xformers_available() -> bool:
+    global XFORMERS_IS_AVAILBLE
+    return XFORMERS_IS_AVAILBLE
+
+
+def disable_xformers() -> None:
+    print("DISABLE XFORMERS!")
+    global XFORMERS_IS_AVAILBLE
+    XFORMERS_IS_AVAILBLE = False
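One subtlety of this module: callers must read the flag through the module at call time, as the edited files above now do, because `from ldm.xformers_state import XFORMERS_IS_AVAILBLE` would only snapshot the value at import. A quick check, runnable inside the repo:

```python
import ldm.xformers_state as xs
from ldm.xformers_state import XFORMERS_IS_AVAILBLE

snapshot = XFORMERS_IS_AVAILBLE    # value copied at import time
xs.disable_xformers()

print(snapshot)                    # unchanged: the snapshot is stale
print(xs.is_xformers_available())  # False: the accessor reads the live flag
print(xs.XFORMERS_IS_AVAILBLE)     # False: attribute lookup sees it too
```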