"scripts/deprecated/test_httpserver_classify.py" did not exist on "1fa15099d85087deeaa5090c76361e53abf9d4a6"
Unverified Commit d914488a authored by PengGao's avatar PengGao Committed by GitHub
Browse files

Add color space settings in VARecorder (#438)

Add color space settings in VARecorder and implement load_image function
in WanAudioRunner
parent 34c1b7b1
...@@ -170,6 +170,14 @@ class VARecorder: ...@@ -170,6 +170,14 @@ class VARecorder:
"rawvideo", "rawvideo",
"-pix_fmt", "-pix_fmt",
"rgb24", "rgb24",
"-color_range",
"pc",
"-colorspace",
"rgb",
"-color_primaries",
"bt709",
"-color_trc",
"iec61966-2-1",
"-r", "-r",
str(self.fps), str(self.fps),
"-s", "-s",
......
import gc import gc
import io
import json import json
import os import os
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -11,7 +12,7 @@ import torch.distributed as dist ...@@ -11,7 +12,7 @@ import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio as ta import torchaudio as ta
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from PIL import Image from PIL import Image, ImageCms, ImageOps
from einops import rearrange from einops import rearrange
from loguru import logger from loguru import logger
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
...@@ -328,6 +329,28 @@ class AudioProcessor: ...@@ -328,6 +329,28 @@ class AudioProcessor:
return start_end_list return start_end_list
def load_image(image: Union[str, os.PathLike, Image.Image], to_rgb: bool = True) -> Image.Image:
    """Load an image, apply its EXIF orientation and normalise its colours to sRGB.

    Args:
        image: Path to an image file, or an already opened ``PIL.Image.Image``.
        to_rgb: When true (the default), the result is converted to mode ``"RGB"``.

    Returns:
        The orientation-corrected image. If the image carries an embedded ICC
        profile, its pixels are converted from that profile to sRGB.

    Raises:
        ValueError: If ``image`` is a path that does not point to an existing file.
    """
    _image = image
    if isinstance(image, (str, os.PathLike)):
        if not os.path.isfile(image):
            raise ValueError(f"Incorrect path. {image} is not a valid path.")
        _image = Image.open(image)

    # orientation transpose
    _image = ImageOps.exif_transpose(_image)

    # convert color space to sRGB
    icc_profile = _image.info.get("icc_profile")
    if icc_profile:
        # sRGB is an RGB profile, so the transform's output mode must be RGB-based.
        # Leaving outputMode unset makes Pillow reuse the input mode, which fails
        # for e.g. CMYK or greyscale sources that carry their own profile.
        output_mode = _image.mode if _image.mode in ("RGB", "RGBA") else "RGB"
        try:
            srgb_profile = ImageCms.createProfile("sRGB")
            input_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc_profile))
            _image = ImageCms.profileToProfile(_image, input_profile, srgb_profile, outputMode=output_mode)
        except (ImageCms.PyCMSError, OSError) as exc:
            # A broken or unsupported profile should not make the image unusable:
            # keep the unconverted pixels, as a plain Image.open(...).convert("RGB") would.
            logger.warning(f"Failed to convert embedded ICC profile to sRGB, using image as-is: {exc}")

    # convert to "RGB"
    if to_rgb:
        _image = _image.convert("RGB")
    return _image
@RUNNER_REGISTER("seko_talk") @RUNNER_REGISTER("seko_talk")
class WanAudioRunner(WanRunner): # type:ignore class WanAudioRunner(WanRunner): # type:ignore
def __init__(self, config): def __init__(self, config):
...@@ -399,7 +422,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -399,7 +422,7 @@ class WanAudioRunner(WanRunner): # type:ignore
return audio_files, mask_files return audio_files, mask_files
def process_single_mask(self, mask_file): def process_single_mask(self, mask_file):
mask_img = Image.open(mask_file).convert("RGB") mask_img = load_image(mask_file)
mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda() mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
if mask_img.shape[1] == 3: # If it is an RGB three-channel image if mask_img.shape[1] == 3: # If it is an RGB three-channel image
...@@ -426,7 +449,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -426,7 +449,7 @@ class WanAudioRunner(WanRunner): # type:ignore
if isinstance(img_path, Image.Image): if isinstance(img_path, Image.Image):
ref_img = img_path ref_img = img_path
else: else:
ref_img = Image.open(img_path).convert("RGB") ref_img = load_image(img_path)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda() ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
ref_img, h, w = resize_image( ref_img, h, w = resize_image(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment