Commit 638a3fff authored by celine

add support for mps device

parent 8aa292c6
@@ -63,13 +63,14 @@
## <a name="update"></a>:new:Update
-- **2023.09.14**: Integrate a patch-based sampling strategy ([mixture-of-diffusers](https://github.com/albarji/mixture-of-diffusers)). [**Try it!**](#general_image_inference) Here is an [example](https://imgsli.com/MjA2MDA1) with a resolution of 2396 x 1596. GPU memory usage will continue to be optimized in the future and we are looking forward to your pull requests!
-- **2023.09.14**: Add support for background upsampler(DiffBIR/[RealESRGAN](https://github.com/xinntao/Real-ESRGAN)) in face enhancement! :rocket: [**Try it!** >](#unaligned_face_inference)
-- **2023.09.13**: Provide online demo (DiffBIR-official) in [OpenXLab](https://openxlab.org.cn/apps/detail/linxinqi/DiffBIR-official), which integrates both general model and face model. Please have a try! [camenduru](https://github.com/camenduru) also implements an online demo, thanks for his work.:hugs:
-- **2023.09.12**: Upload inference code of latent image guidance and release [real47](inputs/real47) testset.
-- **2023.09.08**: Add support for restoring unaligned faces.
-- **2023.09.06**: Update [colab demo](https://colab.research.google.com/github/camenduru/DiffBIR-colab/blob/main/DiffBIR_colab.ipynb). Thanks to [camenduru](https://github.com/camenduru)!:hugs:
-- **2023.08.30**: Repo is released.
+- **2023.09.19**: ✅ Add support for inference on **MPS/CPU** devices for Apple Silicon! Check [installation_xOS.md](assets/docs/installation_xOS.md).
+- **2023.09.14**: ✅ Integrate a patch-based sampling strategy ([mixture-of-diffusers](https://github.com/albarji/mixture-of-diffusers)). [**Try it!**](#general_image_inference) Here is an [example](https://imgsli.com/MjA2MDA1) with a resolution of 2396 x 1596. GPU memory usage will continue to be optimized in the future and we are looking forward to your pull requests!
+- **2023.09.14**: ✅ Add support for background upsampler(DiffBIR/[RealESRGAN](https://github.com/xinntao/Real-ESRGAN)) in face enhancement! :rocket: [**Try it!** >](#unaligned_face_inference)
+- **2023.09.13**: :rocket: Provide online demo (DiffBIR-official) in [OpenXLab](https://openxlab.org.cn/apps/detail/linxinqi/DiffBIR-official), which integrates both general model and face model. Please have a try! [camenduru](https://github.com/camenduru) also implements an online demo, thanks for his work.:hugs:
+- **2023.09.12**: ✅ Upload inference code of latent image guidance and release [real47](inputs/real47) testset.
+- **2023.09.08**: ✅ Add support for restoring unaligned faces.
+- **2023.09.06**: :rocket: Update [colab demo](https://colab.research.google.com/github/camenduru/DiffBIR-colab/blob/main/DiffBIR_colab.ipynb). Thanks to [camenduru](https://github.com/camenduru)!:hugs:
+- **2023.08.30**: This repo is released.
<!-- - [**History Updates** >]() -->
@@ -83,7 +84,7 @@
- [x] Add a patch-based sampling schedule:mag:.
- [x] Upload inference code of latent image guidance:page_facing_up:.
- [ ] Improve the performance:superhero:.
-- [ ] Support MPS acceleration for MacOS users.
+- [x] Support MPS acceleration for MacOS users.
## <a name="installation"></a>:gear:Installation
<!-- - **Python** >= 3.9
@@ -171,7 +172,8 @@ python inference.py \
Remove the brackets to enable tiled sampling. If you are confused about where the `reload_swinir` option came from, please refer to the [degradation details](#degradation-details).
#### Face Image
-Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt) to `weights/` and run the following command.
+<!-- Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt) to `weights/` and run the following command. -->
+<!-- The model can be downloaded from the internet automatically. -->
```shell
# for aligned face inputs
......
@@ -8,53 +8,44 @@
You can choose to run on **CPU** without `xformers` and `triton` installed.
To use **CUDA**, please refer to [issue#24](https://github.com/XPixelGroup/DiffBIR/issues/24) to try to solve the problem of `triton` installation.

# MacOS
-Currenly only CPU device is supported to run DiffBIR on Apple Silicon since most GPU acceleration packages are compatiable with CUDA only.
-We are still trying to support MPS device. Stay tuned for our progress!
+<!-- Currently only the CPU device is supported to run DiffBIR on Apple Silicon, since most GPU acceleration packages are compatible with CUDA only.
+We are still trying to support the MPS device. Stay tuned for our progress! -->

-You can try to set up according to the following steps.
+You can try to set up according to the following steps to use the CPU or MPS device.

-1. Install **torch** according to the [official document](https://pytorch.org/get-started/locally/).
+1. Install **torch (Preview/Nightly version)**.

```bash
-pip install torch torchvision
+# MPS acceleration is available on MacOS 12.3+
+pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
```

+Check more details in the [official document](https://pytorch.org/get-started/locally/).

-2. Package `triton` and `xformers` is not needed since they work with CUDA.
-Remove torch & cuda related packages. Your requirements.txt looks like:
+2. Packages `triton` and `xformers` are not needed since they only work with CUDA. Remove the related packages.
+Your requirements.txt should look like:

```bash
# requirements.txt
pytorch_lightning==1.4.2
einops
open-clip-torch
omegaconf
torchmetrics==0.6.0
opencv-python-headless
scipy
matplotlib
lpips
gradio
chardet
transformers
facexlib
```

```bash
pip install -r requirements.txt
```

-3. Run the inference script using CPU. Ensure you've downloaded the model weights.
-```bash
-python inference.py \
---input inputs/demo/general \
---config configs/model/cldm.yaml \
---ckpt weights/general_full_v1.ckpt \
---reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt \
---steps 50 \
---sr_scale 4 \
---image_size 512 \
---color_fix_type wavelet --resize_back \
---output results/demo/general \
---device cpu
-```
+3. [Run the inference script](https://github.com/XPixelGroup/DiffBIR#general_image_inference) and specify `--device cpu` or `--device mps`. Using MPS can accelerate your inference.
+You can specify `--tiled` and related arguments to avoid OOM.
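After step 1 of the updated instructions above, it can be worth confirming that the nightly build actually exposes the MPS backend before going any further. A minimal verification sketch (not part of the repository):

```bash
# Should print "True True" on an Apple Silicon machine running MacOS 12.3+ with the nightly build:
# the first value is runtime availability, the second whether PyTorch was built with MPS support.
python -c "import torch; print(torch.backends.mps.is_available(), torch.backends.mps.is_built())"
```

If this prints `False`, the inference script will fall back to CPU (see the `check_device` helper added further down in this commit).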
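For step 3, the updated document no longer spells out a concrete invocation. The sketch below simply reuses the command that was removed above and swaps in `--device mps`; the `--tiled` flag is mentioned in the updated text, but its companion tile-size arguments are not shown in this diff, so they are omitted here:

```bash
# Same arguments as the removed CPU example, with the device switched to MPS.
python inference.py \
--input inputs/demo/general \
--config configs/model/cldm.yaml \
--ckpt weights/general_full_v1.ckpt \
--reload_swinir --swinir_ckpt weights/general_swinir_v1.ckpt \
--steps 50 \
--sr_scale 4 \
--image_size 512 \
--color_fix_type wavelet --resize_back \
--output results/demo/general \
--device mps
```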
@@ -127,17 +127,39 @@ def parse_args() -> Namespace:
    parser.add_argument("--skip_if_exist", action="store_true")
    parser.add_argument("--seed", type=int, default=231)
-    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
+    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])
    return parser.parse_args()

+def check_device(device):
+    if device == "cuda":
+        # check if CUDA is available
+        if not torch.cuda.is_available():
+            print("CUDA not available because the current PyTorch install was not "
+                  "built with CUDA enabled.")
+            device = "cpu"
+    else:
+        # xformers only supports CUDA. Disable xformers when using cpu or mps.
+        disable_xformers()
+        if device == "mps":
+            # check if MPS is available
+            if not torch.backends.mps.is_available():
+                if not torch.backends.mps.is_built():
+                    print("MPS not available because the current PyTorch install was not "
+                          "built with MPS enabled.")
+                    device = "cpu"
+                else:
+                    print("MPS not available because the current MacOS version is not 12.3+ "
+                          "and/or you do not have an MPS-enabled device on this machine.")
+                    device = "cpu"
+    print(f'using device {device}')
+    return device

def main() -> None:
    args = parse_args()
    pl.seed_everything(args.seed)
-    if args.device == "cpu":
-        disable_xformers()
+    args.device = check_device(args.device)
    model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
    load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
......
@@ -15,7 +15,7 @@ from utils.image import auto_resize, pad
from utils.file import load_file_from_url
from utils.face_restoration_helper import FaceRestoreHelper
-from inference import process
+from inference import process, check_device

pretrained_models = {
    'general_v1': {
@@ -54,7 +54,9 @@ def parse_args() -> Namespace:
    # Loading two DiffBIR models requires huge GPU memory capacity. Choose RealESRGAN as an alternative.
    parser.add_argument('--bg_upsampler', type=str, default='RealESRGAN', choices=['DiffBIR', 'RealESRGAN'], help='Background upsampler.')
+    # TODO: support tiled sampling for the DiffBIR background upsampler
    parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler.')
+    parser.add_argument('--bg_tile_stride', type=int, default=200, help='Tile stride for background sampler.')
    # postprocessing and saving
    parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
@@ -64,8 +66,7 @@ def parse_args() -> Namespace:
    # change seed to fine-tune your restored images! just specify another random number.
    parser.add_argument("--seed", type=int, default=231)
-    # TODO: support mps device for MacOS devices
-    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
+    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])
    return parser.parse_args()
@@ -115,7 +116,7 @@ def main() -> None:
    assert os.path.isdir(args.input)
-    auto_xformers_status(args.device)
+    args.device = check_device(args.device)
    model = build_diffbir_model(args.config, args.ckpt, args.swinir_ckpt).to(args.device)
    # ------------------ set up FaceRestoreHelper -------------------
@@ -131,15 +132,12 @@ def main() -> None:
    if args.bg_upsampler == 'DiffBIR':
        # Loading two DiffBIR models consumes huge GPU memory capacity.
        bg_upsampler = build_diffbir_model(args.config, 'weights/general_full_v1.pth')
-        # try:
        bg_upsampler = bg_upsampler.to(args.device)
-        # except:
-        #     # put the bg_upsampler on cpu to avoid OOM
-        #     gpu_alternate = True
    elif args.bg_upsampler == 'RealESRGAN':
        from utils.realesrgan.realesrganer import set_realesrgan
-        # support official RealESRGAN x2 & x4 upsample model
-        bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 4
+        # support official RealESRGAN x2 & x4 upsample models.
+        # Use the x2 upsampler by default if the scale is not specified as 4.
+        bg_upscale = int(args.sr_scale) if int(args.sr_scale) in [2, 4] else 2
        print(f'Loading RealESRGAN_x{bg_upscale}plus.pth for background upsampling...')
        bg_upsampler = set_realesrgan(args.bg_tile, args.device, bg_upscale)
    else:
......
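Since this change both adds `--bg_tile_stride` and opens `--device mps` for the face pipeline, a hypothetical invocation tying the new flags together might look like the sketch below. Only the flags visible in this diff are confirmed; the script name `inference_face.py`, the paths, and the `--output` flag are assumptions, and other required arguments of the script may be missing here:

```bash
# Hypothetical example: unaligned face restoration on Apple Silicon with a RealESRGAN background upsampler.
# --bg_upsampler, --bg_tile, --bg_tile_stride, --sr_scale, --input and --device appear in this commit;
# the remaining values are placeholders.
python inference_face.py \
--input inputs/demo/face \
--sr_scale 2 \
--bg_upsampler RealESRGAN \
--bg_tile 400 --bg_tile_stride 200 \
--output results/demo/face \
--device mps
```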
@@ -74,7 +74,13 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
                      dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
-    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    try:
+        # float64 by default. float64 is not supported by the mps device.
+        res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    except:
+        # to be compatible with mps
+        res = torch.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
......
@@ -6,9 +6,9 @@ from torchvision.transforms.functional import normalize
from facexlib.detection import init_detection_model
from facexlib.parsing import init_parsing_model
-from facexlib.utils.misc import img2tensor, imwrite # , adain_npy, isgray, bgr2gray,
-from basicsr.utils.download_util import load_file_from_url
-# from basicsr.utils.misc import get_device
+from facexlib.utils.misc import img2tensor, imwrite
+from .file import load_file_from_url

def get_largest_face(det_faces, h, w):
......