Commit 638a3fff authored by celine's avatar celine

add support for mps device

parent 8aa292c6
......@@ -63,13 +63,14 @@
## <a name="update"></a>:new:Update
- **2023.09.14**: Integrate a patch-based sampling strategy ([mixture-of-diffusers](https://github.com/albarji/mixture-of-diffusers)). [**Try it!**](#general_image_inference) Here is an [example](https://imgsli.com/MjA2MDA1) with a resolution of 2396 x 1596. GPU memory usage will continue to be optimized in the future and we are looking forward to your pull requests!
- **2023.09.14**: Add support for background upsampler(DiffBIR/[RealESRGAN](https://github.com/xinntao/Real-ESRGAN)) in face enhancement! :rocket: [**Try it!** >](#unaligned_face_inference)
- **2023.09.13**: Provide online demo (DiffBIR-official) in [OpenXLab](https://openxlab.org.cn/apps/detail/linxinqi/DiffBIR-official), which integrates both general model and face model. Please have a try! [camenduru](https://github.com/camenduru) also implements an online demo, thanks for his work.:hugs:
- **2023.09.12**: Upload inference code of latent image guidance and release [real47](inputs/real47) testset.
- **2023.09.08**: Add support for restoring unaligned faces.
- **2023.09.06**: Update [colab demo](https://colab.research.google.com/github/camenduru/DiffBIR-colab/blob/main/DiffBIR_colab.ipynb). Thanks to [camenduru](https://github.com/camenduru)!:hugs:
- **2023.08.30**: Repo is released.
- **2023.09.19**: ✅ Add support for inference on **MPS/CPU** devices for Apple Silicon! Check [installation_xOS.md](assets/docs/installation_xOS.md).
- **2023.09.14**: ✅ Integrate a patch-based sampling strategy ([mixture-of-diffusers](https://github.com/albarji/mixture-of-diffusers)). [**Try it!**](#general_image_inference) Here is an [example](https://imgsli.com/MjA2MDA1) with a resolution of 2396 x 1596. GPU memory usage will continue to be optimized in the future and we are looking forward to your pull requests!
- **2023.09.14**: ✅ Add support for background upsampler (DiffBIR/[RealESRGAN](https://github.com/xinntao/Real-ESRGAN)) in face enhancement! :rocket: [**Try it!** >](#unaligned_face_inference)
- **2023.09.13**: :rocket: Provide online demo (DiffBIR-official) in [OpenXLab](https://openxlab.org.cn/apps/detail/linxinqi/DiffBIR-official), which integrates both the general model and the face model. Please have a try! [camenduru](https://github.com/camenduru) also implements an online demo. Thanks for his work! :hugs:
- **2023.09.12**: ✅ Upload inference code of latent image guidance and release [real47](inputs/real47) testset.
- **2023.09.08**: ✅ Add support for restoring unaligned faces.
- **2023.09.06**: :rocket: Update [colab demo](https://colab.research.google.com/github/camenduru/DiffBIR-colab/blob/main/DiffBIR_colab.ipynb). Thanks to [camenduru](https://github.com/camenduru)!:hugs:
- **2023.08.30**: This repo is released.
<!-- - [**History Updates** >]() -->
......@@ -83,7 +84,7 @@
- [x] Add a patch-based sampling schedule:mag:.
- [x] Upload inference code of latent image guidance:page_facing_up:.
- [ ] Improve the performance:superhero:.
- [ ] Support MPS acceleration for MacOS users.
- [x] Support MPS acceleration for macOS users.
## <a name="installation"></a>:gear:Installation
<!-- - **Python** >= 3.9
......@@ -171,7 +172,8 @@ python inference.py \
Remove the brackets to enable tiled sampling. If you are confused about where the `reload_swinir` option came from, please refer to the [degradation details](#degradation-details).
#### Face Image
Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt) to `weights/` and run the following command.
<!-- Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt) to `weights/` and run the following command. -->
<!-- The model can be downloaded from the internet automatically. -->
```shell
# for aligned face inputs
......
......@@ -8,53 +8,44 @@ You can choose to run on **CPU** without `xformers` and `triton` installed.
To use **CUDA**, please refer to [issue#24](https://github.com/XPixelGroup/DiffBIR/issues/24) for help solving the problem of installing `triton`.
# macOS
Currenly only CPU device is supported to run DiffBIR on Apple Silicon since most GPU acceleration packages are compatiable with CUDA only.
We are still trying to support MPS device. Stay tuned for our progress!
You can try to set up according to the following steps.
1. Install **torch** according to the [official document](https://pytorch.org/get-started/locally/).
```bash
pip install torch torchvision
```
2. Package `triton` and `xformers` is not needed since they work with CUDA.
Remove torch & cuda related packages. Your requirements.txt looks like:
```bash
# requirements.txt
pytorch_lightning==1.4.2
einops
open-clip-torch
omegaconf
torchmetrics==0.6.0
opencv-python-headless
scipy
matplotlib
lpips
gradio
chardet
transformers
facexlib
```
```bash
pip install -r requirements.txt
```
3. Run the inference script using CPU. Ensure you've downloaded the model weights.
```bash
python inference.py \
--input inputs/demo/general \
--config configs/model/cldm.yaml \
--ckpt weights/general_full_v1.ckpt \
--reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt \
--steps 50 \
--sr_scale 4 \
--image_size 512 \
--color_fix_type wavelet --resize_back \
--output results/demo/general \
--device cpu
```
\ No newline at end of file
<!-- Currently only the CPU device is supported to run DiffBIR on Apple Silicon since most GPU acceleration packages are compatible with CUDA only.
We are still trying to support the MPS device. Stay tuned for our progress! -->
You can set up an environment that uses the CPU or MPS device by following the steps below.
1. Install **torch (Preview/Nightly version)**.
```bash
# MPS acceleration is available on macOS 12.3+
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
```
Check more details in [official document](https://pytorch.org/get-started/locally/).
2. The `triton` and `xformers` packages are not needed since they only work with CUDA. Remove them and the other CUDA-related packages.
Your requirements.txt should look like:
```bash
# requirements.txt
pytorch_lightning==1.4.2
einops
open-clip-torch
omegaconf
torchmetrics==0.6.0
opencv-python-headless
scipy
matplotlib
lpips
gradio
chardet
transformers
facexlib
```
```bash
pip install -r requirements.txt
```
3. [Run the inference script](https://github.com/XPixelGroup/DiffBIR#general_image_inference) and specify `--device cpu` or `--device mps`. Using MPS can accelerate your inference.
You can specify `--tiled` and related arguments to avoid OOM.
\ No newline at end of file
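The macOS 12.3+ requirement from step 1 can also be checked programmatically before installing the nightly wheel. Below is a torch-free sketch; `mps_supported` is a hypothetical helper for illustration, not part of DiffBIR:

```python
import platform

def mps_supported(mac_release: str) -> bool:
    # MPS acceleration requires macOS 12.3 or newer. mac_release is a
    # version string such as the one returned by platform.mac_ver()[0],
    # e.g. "13.4.1" (assumes at least a major.minor form).
    major, minor = (int(x) for x in mac_release.split(".")[:2])
    return (major, minor) >= (12, 3)

# On a Mac you could call: mps_supported(platform.mac_ver()[0])
print(mps_supported("12.3"))  # True
```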
......@@ -127,17 +127,39 @@ def parse_args() -> Namespace:
parser.add_argument("--skip_if_exist", action="store_true")
parser.add_argument("--seed", type=int, default=231)
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])
return parser.parse_args()
def check_device(device):
if device == "cuda":
# check if CUDA is available
if not torch.cuda.is_available():
print("CUDA not available because the current PyTorch install was not "
"built with CUDA enabled.")
device = "cpu"
else:
# xformers only supports CUDA. Disable xformers when using cpu or mps.
disable_xformers()
if device == "mps":
# check if MPS is available
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
print("MPS not available because the current PyTorch install was not "
"built with MPS enabled.")
device = "cpu"
else:
print("MPS not available because the current macOS version is not 12.3+ "
"and/or you do not have an MPS-enabled device on this machine.")
device = "cpu"
print(f'using device {device}')
return device
def main() -> None:
args = parse_args()
pl.seed_everything(args.seed)
if args.device == "cpu":
disable_xformers()
args.device = check_device(args.device)
model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
......
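The fallback behaviour of `check_device` above can be exercised without torch installed. The sketch below mirrors its logic with a hypothetical stand-in, `pick_device`, that takes the availability flags as arguments instead of querying `torch.cuda` / `torch.backends.mps`:

```python
def pick_device(requested: str, cuda_available: bool, mps_available: bool) -> str:
    # Mirror of check_device's fallback logic: drop to CPU whenever the
    # requested accelerator is not usable on this machine.
    if requested == "cuda" and not cuda_available:
        return "cpu"
    if requested == "mps" and not mps_available:
        return "cpu"
    return requested

print(pick_device("mps", cuda_available=False, mps_available=True))   # mps
print(pick_device("cuda", cuda_available=False, mps_available=True))  # cpu
```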
......@@ -15,7 +15,7 @@ from utils.image import auto_resize, pad
from utils.file import load_file_from_url
from utils.face_restoration_helper import FaceRestoreHelper
from inference import process
from inference import process, check_device
pretrained_models = {
'general_v1': {
......@@ -54,7 +54,9 @@ def parse_args() -> Namespace:
# Loading two DiffBIR models requires huge GPU memory capacity. Choose RealESRGAN as an alternative.
parser.add_argument('--bg_upsampler', type=str, default='RealESRGAN', choices=['DiffBIR', 'RealESRGAN'], help='Background upsampler.')
# TODO: support tiled for DiffBIR background upsampler
parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler.')
parser.add_argument('--bg_tile_stride', type=int, default=200, help='Tile stride for background sampler.')
# postprocessing and saving
parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
......@@ -64,8 +66,7 @@ def parse_args() -> Namespace:
# change seed to fine-tune your restored images! Just specify another random number.
parser.add_argument("--seed", type=int, default=231)
# TODO: support mps device for MacOS devices
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])
return parser.parse_args()
......@@ -115,7 +116,7 @@ def main() -> None:
assert os.path.isdir(args.input)
auto_xformers_status(args.device)
args.device = check_device(args.device)
model = build_diffbir_model(args.config, args.ckpt, args.swinir_ckpt).to(args.device)
# ------------------ set up FaceRestoreHelper -------------------
......@@ -131,15 +132,12 @@ def main() -> None:
if args.bg_upsampler == 'DiffBIR':
# Loading two DiffBIR models consumes huge GPU memory capacity.
bg_upsampler = build_diffbir_model(args.config, 'weights/general_full_v1.pth')
# try:
bg_upsampler = bg_upsampler.to(args.device)
# except:
# # put the bg_upsampler on cpu to avoid OOM
# gpu_alternate = True
elif args.bg_upsampler == 'RealESRGAN':
from utils.realesrgan.realesrganer import set_realesrgan
# support official RealESRGAN x2 & x4 upsample model
bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 4
# support official RealESRGAN x2 & x4 upsample models.
# Fall back to the x2 upsampler if sr_scale is neither 2 nor 4.
bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 2
print(f'Loading RealESRGAN_x{bg_upscale}plus.pth for background upsampling...')
bg_upsampler = set_realesrgan(args.bg_tile, args.device, bg_upscale)
else:
......
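The scale-selection expression above can be factored out and tested on its own. `choose_bg_scale` below is a hypothetical helper, not a function in this repo:

```python
def choose_bg_scale(sr_scale: float) -> int:
    # Only the official RealESRGAN x2 and x4 models are supported;
    # fall back to x2 for any other requested scale.
    s = int(sr_scale)
    return s if s in (2, 4) else 2

print(choose_bg_scale(4))  # 4
print(choose_bg_scale(3))  # 2
```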
......@@ -74,7 +74,13 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
try:
    # NumPy arrays default to float64, which the MPS device does not support.
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
except TypeError:
    # cast to float32 first to stay compatible with MPS
    res = torch.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
......
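The `try`/`except` above works around the MPS backend's missing float64 support by re-casting on the NumPy side. The same idea as an up-front check, in a torch-free sketch (`to_float32_for_mps` is hypothetical, not part of the repo):

```python
import numpy as np

def to_float32_for_mps(arr: np.ndarray) -> np.ndarray:
    # torch.from_numpy keeps the NumPy dtype, and moving a float64
    # tensor to an MPS device fails, so downcast eagerly instead.
    return arr.astype(np.float32) if arr.dtype == np.float64 else arr

print(to_float32_for_mps(np.linspace(0, 1, 4)).dtype)  # float32
```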
......@@ -6,9 +6,9 @@ from torchvision.transforms.functional import normalize
from facexlib.detection import init_detection_model
from facexlib.parsing import init_parsing_model
from facexlib.utils.misc import img2tensor, imwrite # , adain_npy, isgray, bgr2gray,
from basicsr.utils.download_util import load_file_from_url
# from basicsr.utils.misc import get_device
from facexlib.utils.misc import img2tensor, imwrite
from .file import load_file_from_url
def get_largest_face(det_faces, h, w):
......