Commit 30355a12 authored by 0x3f3f3f3fun

add support for cpu

parent 7a5f8d70
@@ -7,4 +7,4 @@ __pycache__
!install_env.sh
/weights
/temp
results/
/results
@@ -95,7 +95,8 @@ python gradio_diffbir.py \
--ckpt weights/general_full_v1.ckpt \
--config configs/model/cldm.yaml \
--reload_swinir \
--swinir_ckpt weights/general_swinir_v1.ckpt
--swinir_ckpt weights/general_swinir_v1.ckpt \
--device cuda
```
<div align="center">
@@ -120,7 +121,8 @@ python inference.py \
--sr_scale 4 \
--image_size 512 \
--color_fix_type wavelet --resize_back \
--output results/general
--output results/general \
--device cuda
```
If you are confused about where the `reload_swinir` option came from, please refer to the [degradation details](#degradation-details).
@@ -139,7 +141,8 @@ python inference_face.py \
--image_size 512 \
--color_fix_type wavelet \
--output results/face/aligned --resize_back \
--has_aligned
--has_aligned \
--device cuda
# for unaligned face inputs
python inference_face.py \
@@ -150,7 +153,8 @@ python inference_face.py \
--sr_scale 1 \
--image_size 512 \
--color_fix_type wavelet \
--output results/face/whole_img --resize_back
--output results/face/whole_img --resize_back \
--device cuda
```
### Only Stage1 Model (Remove Degradations)
@@ -181,7 +185,8 @@ python inference.py \
--input [img_dir_path] \
--color_fix_type wavelet --resize_back \
--output [output_dir_path] \
--disable_preprocess_model
--disable_preprocess_model \
--device cuda
```
## <a name="train"></a>:stars:Train
@@ -10,6 +10,7 @@ import gradio as gr
from PIL import Image
from omegaconf import OmegaConf
from ldm.xformers_state import disable_xformers
from model.spaced_sampler import SpacedSampler
from model.cldm import ControlLDM
from utils.image import (
@@ -23,10 +24,12 @@ parser.add_argument("--config", required=True, type=str)
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--reload_swinir", action="store_true")
parser.add_argument("--swinir_ckpt", type=str, default="")
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
args = parser.parse_args()
# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
if args.device == "cpu":
disable_xformers()
model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
# reload preprocess model if specified
@@ -34,7 +37,7 @@ if args.reload_swinir:
print(f"reload swinir model from {args.swinir_ckpt}")
load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
model.freeze()
model.to(device)
model.to(args.device)
# load sampler
sampler = SpacedSampler(model, var_type="fixed_small")
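The scripts now take the target device verbatim from `--device` and drop the old `torch.cuda.is_available()` check. A minimal sketch of how the explicit flag and the old auto-fallback could be combined; the `resolve_device` helper below is hypothetical, not part of this commit:

```python
import torch

from ldm.xformers_state import disable_xformers


def resolve_device(requested: str) -> str:
    """Honor an explicit --device choice, but fall back to CPU if CUDA is missing."""
    if requested == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        requested = "cpu"
    if requested == "cpu":
        # xformers only provides CUDA kernels, so switch it off for CPU runs
        disable_xformers()
    return requested
```

Used as `model.to(resolve_device(args.device))`, this behaves like the removed auto-detection while still allowing a forced CPU run.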
@@ -10,6 +10,7 @@ import pytorch_lightning as pl
from PIL import Image
from omegaconf import OmegaConf
from ldm.xformers_state import disable_xformers
from model.spaced_sampler import SpacedSampler
from model.ddim_sampler import DDIMSampler
from model.cldm import ControlLDM
@@ -127,6 +128,7 @@ def parse_args() -> Namespace:
parser.add_argument("--skip_if_exist", action="store_true")
parser.add_argument("--seed", type=int, default=231)
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
return parser.parse_args()
@@ -134,7 +136,9 @@ def parse_args() -> Namespace:
def main() -> None:
args = parse_args()
pl.seed_everything(args.seed)
device = "cuda" if torch.cuda.is_available() else "cpu"
if args.device == "cpu":
disable_xformers()
model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
@@ -145,12 +149,12 @@ def main() -> None:
print(f"reload swinir model from {args.swinir_ckpt}")
load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
model.freeze()
model.to(device)
model.to(args.device)
assert os.path.isdir(args.input)
print(f"sampling {args.steps} steps using {args.sampler} sampler")
with torch.autocast(device):
# with torch.autocast(device, dtype=torch.bfloat16):
for file_path in list_image_files(args.input, follow_links=True):
lq = Image.open(file_path).convert("RGB")
if args.sr_scale != 1:
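Note that `torch.autocast` accepts both `"cuda"` and `"cpu"` as device types; CUDA autocast defaults to float16, while CPU autocast runs in bfloat16. A sketch of a device-aware context helper (the `make_autocast` name is hypothetical, not part of this commit):

```python
import torch


def make_autocast(device: str) -> torch.autocast:
    """Return an autocast context whose dtype suits the given device."""
    if device == "cuda":
        return torch.autocast("cuda")  # float16 by default on CUDA
    # CPU autocast uses bfloat16
    return torch.autocast("cpu", dtype=torch.bfloat16)
```

It would be entered as `with make_autocast(args.device):` around the sampling loop.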
@@ -173,17 +177,17 @@ def main() -> None:
raise RuntimeError(f"{save_path} already exist")
os.makedirs(parent_path, exist_ok=True)
try:
# try:
preds, stage1_preds = process(
model, [x], steps=args.steps, sampler=args.sampler,
strength=1,
color_fix_type=args.color_fix_type,
disable_preprocess_model=args.disable_preprocess_model
)
except RuntimeError as e:
# Avoid cuda_out_of_memory error.
print(f"{file_path}, error: {e}")
continue
# except RuntimeError as e:
# # Avoid cuda_out_of_memory error.
# print(f"{file_path}, error: {e}")
# continue
pred, stage1_pred = preds[0], stage1_preds[0]
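The `try`/`except RuntimeError` guard that skipped images hitting CUDA out-of-memory errors is commented out above. If skipping is still wanted on GPU runs, a small wrapper could catch only genuine OOM failures; this is a sketch (`run_skipping_cuda_oom` is hypothetical, and `torch.cuda.OutOfMemoryError` needs PyTorch 1.13+; on older versions catch `RuntimeError` and inspect the message):

```python
import torch


def run_skipping_cuda_oom(fn, *args, **kwargs):
    """Call fn(*args, **kwargs), returning None instead of raising on CUDA OOM."""
    try:
        return fn(*args, **kwargs)
    except torch.cuda.OutOfMemoryError as e:
        # only swallow out-of-memory errors; anything else should surface
        print(f"skipped due to CUDA OOM: {e}")
        return None
```

In the loop above it would wrap the `process(...)` call, e.g. `result = run_skipping_cuda_oom(process, model, [x], steps=args.steps, ...)`, with a `continue` when `None` comes back.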
@@ -10,6 +10,7 @@ from argparse import ArgumentParser, Namespace
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from ldm.xformers_state import disable_xformers
from model.cldm import ControlLDM
from model.ddim_sampler import DDIMSampler
from model.spaced_sampler import SpacedSampler
@@ -56,6 +57,7 @@ def parse_args() -> Namespace:
parser.add_argument("--skip_if_exist", action="store_true")
parser.add_argument("--seed", type=int, default=231)
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
return parser.parse_args()
@@ -64,7 +66,9 @@ def main() -> None:
args = parse_args()
img_save_ext = 'png'
pl.seed_everything(args.seed)
device = "cuda" if torch.cuda.is_available() else "cpu"
if args.device == "cpu":
disable_xformers()
model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
@@ -75,13 +79,13 @@ def main() -> None:
print(f"reload swinir model from {args.swinir_ckpt}")
load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
model.freeze()
model.to(device)
model.to(args.device)
assert os.path.isdir(args.input)
# ------------------ set up FaceRestoreHelper -------------------
face_helper = FaceRestoreHelper(
device=device,
device=args.device,
upscale_factor=1,
face_size=args.image_size,
use_parse=True,
@@ -7,14 +7,14 @@ from einops import rearrange, repeat
from typing import Optional, Any
from ldm.modules.diffusionmodules.util import checkpoint
from ldm import xformers_state
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
# try:
# import xformers
# import xformers.ops
# XFORMERS_IS_AVAILBLE = True
# except:
# XFORMERS_IS_AVAILBLE = False
# CrossAttn precision handling
import os
@@ -172,7 +172,8 @@ class CrossAttention(nn.Module):
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
# with torch.autocast(enabled=False, device_type = 'cuda'):
with torch.autocast(enabled=False, device_type=str(x.device)):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
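For context on the `device_type` argument: `torch.autocast` expects a device *type* such as `"cuda"` or `"cpu"`, which is exactly what `tensor.device.type` returns (whereas `str(tensor.device)` may include an index, e.g. `"cuda:0"`). A device-agnostic sketch of the fp32 attention fallback shown above, not the commit's code:

```python
import torch
from torch import einsum


def attention_scores_fp32(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    """Compute the q·k^T similarity in float32 regardless of the autocast dtype."""
    # q.device.type is "cuda" or "cpu", matching what autocast expects
    with torch.autocast(device_type=q.device.type, enabled=False):
        q, k = q.float(), k.float()
        return einsum('b i d, b j d -> b i j', q, k) * scale
```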
@@ -230,7 +231,7 @@ class MemoryEfficientCrossAttention(nn.Module):
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = xformers_state.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
if exists(mask):
raise NotImplementedError
@@ -251,7 +252,8 @@ class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False):
super().__init__()
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
# attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
attn_mode = "softmax-xformers" if xformers_state.is_xformers_available() else "softmax"
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.disable_self_attn = disable_self_attn
@@ -7,14 +7,16 @@ from einops import rearrange
from typing import Optional, Any
from ldm.modules.attention import MemoryEfficientCrossAttention
from ldm import xformers_state
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
print("No module 'xformers'. Proceeding without it.")
# try:
# import xformers
# import xformers.ops
# XFORMERS_IS_AVAILBLE = True
# except:
# XFORMERS_IS_AVAILBLE = False
# print("No module 'xformers'. Proceeding without it.")
def get_timestep_embedding(timesteps, embedding_dim):
@@ -255,7 +257,7 @@ class MemoryEfficientAttnBlock(nn.Module):
.contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = xformers_state.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = (
out.unsqueeze(0)
@@ -279,7 +281,8 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
# if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
if xformers_state.is_xformers_available() and attn_type == "vanilla":
attn_type = "vanilla-xformers"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
@@ -140,7 +140,8 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
"last",
"penultimate"
]
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
# def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", max_length=77,
freeze=True, layer="last"):
super().__init__()
assert layer in self.LAYERS
@@ -148,7 +149,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
del model.visual
self.model = model
self.device = device
# self.device = device
self.max_length = max_length
if freeze:
self.freeze()
@@ -167,7 +168,8 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
def forward(self, text):
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
# z = self.encode_with_transformer(tokens.to(self.device))
z = self.encode_with_transformer(tokens.to(next(self.model.parameters()).device))
return z
def encode_with_transformer(self, text):
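With the `device` argument gone, the embedder reads its device from the model's own parameters at call time, so moving the wrapper with `.to(...)` keeps the tokenized input on the right device automatically. A generic sketch of this pattern (the class below is illustrative, not from the repo):

```python
import torch
import torch.nn as nn


class DeviceAware(nn.Module):
    """Minimal module whose device is derived from its parameters."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @property
    def device(self) -> torch.device:
        # stays correct after any .to()/.cuda()/.cpu() call
        return next(self.parameters()).device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x.to(self.device))
```

`DeviceAware().to("cpu")(torch.randn(2, 8))` then works without any stored device string, which is what `next(self.model.parameters()).device` achieves in `forward` above.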
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
print("No module 'xformers'. Proceeding without it.")
def is_xformers_available() -> bool:
global XFORMERS_IS_AVAILBLE
return XFORMERS_IS_AVAILBLE
def disable_xformers() -> None:
print("DISABLE XFORMERS!")
global XFORMERS_IS_AVAILBLE
XFORMERS_IS_AVAILBLE = False
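This new module (imported elsewhere as `ldm.xformers_state`) centralizes the xformers availability flag so call sites query it at run time and CPU runs can switch it off via `disable_xformers()`. A small usage sketch; `pick_attn_mode` is a hypothetical stand-in for the real call sites (`attn_mode` in `BasicTransformerBlock` and `make_attn` above):

```python
from ldm import xformers_state


def pick_attn_mode() -> str:
    """Mirror the backend selection done in the attention modules."""
    # reflects both the import-time check and any later disable_xformers() call
    return "softmax-xformers" if xformers_state.is_xformers_available() else "softmax"


if __name__ == "__main__":
    print(pick_attn_mode())          # "softmax-xformers" if xformers is importable
    xformers_state.disable_xformers()
    print(pick_attn_mode())          # always "softmax" after disabling
```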