Unverified Commit d914488a authored by PengGao's avatar PengGao Committed by GitHub
Browse files

Add color space settings in VARecorder (#438)

Add color space settings in VARecorder and implement load_image function
in WanAudioRunner
parent 34c1b7b1
......@@ -170,6 +170,14 @@ class VARecorder:
"rawvideo",
"-pix_fmt",
"rgb24",
"-color_range",
"pc",
"-colorspace",
"rgb",
"-color_primaries",
"bt709",
"-color_trc",
"iec61966-2-1",
"-r",
str(self.fps),
"-s",
......
import gc
import io
import json
import os
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -11,7 +12,7 @@ import torch.distributed as dist
import torch.nn.functional as F
import torchaudio as ta
import torchvision.transforms.functional as TF
from PIL import Image
from PIL import Image, ImageCms, ImageOps
from einops import rearrange
from loguru import logger
from torchvision.transforms import InterpolationMode
......@@ -328,6 +329,28 @@ class AudioProcessor:
return start_end_list
def load_image(image: Union[str, os.PathLike, Image.Image], to_rgb: bool = True) -> Image.Image:
    """Load an image, apply its EXIF orientation and convert its colors to sRGB.

    Args:
        image: Path to an image file (``str`` or ``os.PathLike``), or an
            already opened PIL image.
        to_rgb: When True, the returned image is converted to mode "RGB".

    Returns:
        The orientation-corrected image. If the image carries an embedded ICC
        profile, its pixels have been transformed to the sRGB profile.

    Raises:
        ValueError: If ``image`` is a path that does not point to an existing file.
    """
    _image = image
    # Accept pathlib-style paths as well as plain strings.
    if isinstance(image, (str, os.PathLike)):
        if not os.path.isfile(image):
            raise ValueError(f"Incorrect path. {image} is not a valid path.")
        _image = Image.open(image)

    # orientation transpose
    _image = ImageOps.exif_transpose(_image)

    # convert color space to sRGB
    icc_profile = _image.info.get("icc_profile")
    if icc_profile:
        srgb_profile = ImageCms.createProfile("sRGB")
        input_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc_profile))
        # profileToProfile defaults the output mode to the input mode. An sRGB
        # target needs an RGB-family output mode, so non-RGB sources (e.g. CMYK
        # or grayscale with an embedded profile) must be given one explicitly;
        # otherwise the transform cannot be built.
        output_mode = _image.mode if _image.mode in ("RGB", "RGBA") else "RGB"
        _image = ImageCms.profileToProfile(_image, input_profile, srgb_profile, outputMode=output_mode)

    # convert to "RGB"
    if to_rgb:
        _image = _image.convert("RGB")
    return _image
@RUNNER_REGISTER("seko_talk")
class WanAudioRunner(WanRunner): # type:ignore
def __init__(self, config):
......@@ -399,7 +422,7 @@ class WanAudioRunner(WanRunner): # type:ignore
return audio_files, mask_files
def process_single_mask(self, mask_file):
mask_img = Image.open(mask_file).convert("RGB")
mask_img = load_image(mask_file)
mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
if mask_img.shape[1] == 3: # If it is an RGB three-channel image
......@@ -426,7 +449,7 @@ class WanAudioRunner(WanRunner): # type:ignore
if isinstance(img_path, Image.Image):
ref_img = img_path
else:
ref_img = Image.open(img_path).convert("RGB")
ref_img = load_image(img_path)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
ref_img, h, w = resize_image(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment