Commit 1f5da520 authored by yangzhong

git init

import csv
csv.field_size_limit(5000000)
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

import ast
import pandas as pd
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
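# A minimal usage sketch (assumed input: a uint8 clip in T C H W layout, as returned by
# torchvision.io.read_video(..., output_format="TCHW")):
#   video = torch.randint(0, 256, (16, 3, 480, 640), dtype=torch.uint8)
#   out = get_transforms_video(resolution=256)(video)  # float, (16, 3, 256, 256), in [-1, 1]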
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
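# A minimal usage sketch (assumed input: an RGB PIL image; center_crop_arr center-crops
# it to image_size x image_size before the tensor conversion):
#   from PIL import Image
#   img = Image.open("frame.png").convert("RGB")
#   out = get_transforms_image(256)(img)  # float, (3, 256, 256), in [-1, 1]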
# open-sora-plan+magictime dataset
class DatasetFromCSV(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
csv_path,
num_frames=16,
frame_interval=1,
transform=None,
root=None,
):
# video_samples = []
# with open(csv_path, "r") as f:
# reader = csv.reader(f)
# #csv_list = list(reader)
# for idx, v_s in enumerate(reader):
# vid_path = v_s[0]
# vid_caption = v_s[1]
# if os.path.exists(vid_path):
# video_samples.append([vid_path, vid_caption])
# if idx % 1000 == 0:
# print(idx)
video_samples = pd.read_csv(csv_path)
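# Assumed CSV layout (header row included; column 0 is the video path, column 1 the caption):
#   path,text
#   /data/videos/clip_0001.mp4,a dog running on the beach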
self.samples = video_samples
print('video num:', self.samples.shape[0])
self.is_video = True
self.transform = transform
self.num_frames = num_frames
self.frame_interval = frame_interval
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
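# e.g. num_frames=16, frame_interval=3: TemporalRandomCrop draws a random window of
# 16 * 3 = 48 consecutive frames, and getitem spreads 16 indices evenly across it.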
self.root = root
def getitem(self, index):
sample = self.samples.iloc[index].values
path = sample[0]
text = sample[1]
if self.is_video:
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
loop_index = index
while total_frames < self.num_frames or not is_exit:
# print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exist!")
loop_index += 1
if loop_index >= self.samples.shape[0]:
loop_index = 0
sample = self.samples.iloc[loop_index].values
path = sample[0]
text = sample[1]
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
# video exists and total_frames >= self.num_frames
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= self.num_frames
), f"{path} with index {index} has not enough frames."
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
#print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
video = vframes[frame_indice]
video = self.transform(video) # T C H W
else:
image = pil_loader(path)
image = self.transform(image)
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad samples.")
def __len__(self):
return self.samples.shape[0]
if __name__ == '__main__':
data_path = '/mnt/bn/videodataset-uswest/VDiT/dataset/panda50m/panda70m_training_full.csv'
root='/mnt/bn/videodataset-uswest/panda70m'
dataset = DatasetFromCSV(
data_path,
transform=get_transforms_video(),
num_frames=16,
frame_interval=3,
root=root,
)
sampler = DistributedSampler(
dataset,
num_replicas=1,
rank=0,
shuffle=True,
seed=1
)
loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=sampler,
num_workers=0,
pin_memory=True,
drop_last=True
)
for video_data in loader:
print(video_data)
import csv
csv.field_size_limit(8000000)
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import ast
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
# open-sora-plan+magictime dataset
class DatasetFromCSV(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
csv_path,
num_frames=16,
frame_interval=1,
transform=None,
root=None,
):
video_samples = []
with open(csv_path, "r") as f:
reader = csv.reader(f)
#csv_list = list(reader)
print('csv read end')
parts_list = os.listdir(root)
for idx, v_s in enumerate(reader):
if idx > 0: # skip the CSV header row
vid_name = v_s[0]
vid_captions_str = v_s[3]
vid_captions = ast.literal_eval(vid_captions_str)
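# Assumed caption-field format: a Python list literal with one caption per clip, e.g.
#   "['a person opens a door', 'a man walks into a room']"
# ast.literal_eval parses it into a list of strings; clip ic is stored as <vid_name>_<ic>.mp4.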
#vid_captions = vid_captions_str.split('\'')[1::2]
for part in parts_list:
vids_path = os.path.join(root, part, vid_name)
if os.path.isdir(vids_path):
for ic, cap in enumerate(vid_captions):
vid_path = os.path.join(root, part, vid_name, vid_name+"_"+str(ic)+".mp4")
if os.path.exists(vid_path):
video_samples.append([vid_path, cap])
break
if idx % 1000 == 0:
print('read video', idx)
self.samples = video_samples
print('video num:', len(self.samples))
self.is_video = True
self.transform = transform
self.num_frames = num_frames
self.frame_interval = frame_interval
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
self.root = root
def getitem(self, index):
sample = self.samples[index]
path = sample[0]
text = sample[1]
if self.is_video:
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
loop_index = index
while total_frames < self.num_frames or not is_exit:
# print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exist!")
loop_index += 1
if loop_index >= len(self.samples):
loop_index = 0
sample = self.samples[loop_index]
path = sample[0]
text = sample[1]
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
# video exists and total_frames >= self.num_frames
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= self.num_frames
), f"{path} with index {index} has not enough frames."
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
#print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
video = vframes[frame_indice]
video = self.transform(video) # T C H W
else:
image = pil_loader(path)
image = self.transform(image)
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad samples.")
def __len__(self):
return len(self.samples)
if __name__ == '__main__':
data_path = '/mnt/bn/videodataset-uswest/VDiT/dataset/panda50m/panda70m_training_full.csv'
root='/mnt/bn/videodataset-uswest/panda70m'
dataset = DatasetFromCSV(
data_path,
transform=get_transforms_video(),
num_frames=16,
frame_interval=3,
root=root,
)
sampler = DistributedSampler(
dataset,
num_replicas=1,
rank=0,
shuffle=True,
seed=1
)
loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=sampler,
num_workers=0,
pin_memory=True,
drop_last=True
)
for video_data in loader:
print(video_data)
import csv
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
# open-sora-plan+magictime dataset
class DatasetFromCSV(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
csv_path,
num_frames=16,
frame_interval=1,
transform=None,
root=None,
):
video_samples = []
with open(csv_path, "r") as f:
reader = csv.reader(f)
csv_list = list(reader)
for v_s in csv_list[1:]: # skip the CSV header row
vid_path = v_s[0]
vid_caption = v_s[1]
if os.path.exists(vid_path):
video_samples.append([vid_path, vid_caption])
self.samples = video_samples
self.is_video = True
self.transform = transform
self.num_frames = num_frames
self.frame_interval = frame_interval
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
self.root = root
def getitem(self, index):
sample = self.samples[index]
path = sample[0]
text = sample[1]
if self.is_video:
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
loop_index = index
while total_frames < self.num_frames or not is_exit:
# print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exist!")
loop_index += 1
if loop_index >= len(self.samples):
loop_index = 0
sample = self.samples[loop_index]
path = sample[0]
text = sample[1]
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
# video exists and total_frames >= self.num_frames
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= self.num_frames
), f"{path} with index {index} has not enough frames."
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
#print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
video = vframes[frame_indice]
video = self.transform(video) # T C H W
else:
image = pil_loader(path)
image = self.transform(image)
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad samples.")
def __len__(self):
return len(self.samples)
if __name__ == '__main__':
data_path = '/mnt/bn/yh-volume0/dataset/CelebvHQ/CelebvHQ_caption_llava-34B.csv'
root='/mnt/bn/yh-volume0/dataset/CelebvHQ/35666'
dataset = DatasetFromCSV(
data_path,
transform=get_transforms_video(),
num_frames=16,
frame_interval=3,
root=root,
)
sampler = DistributedSampler(
dataset,
num_replicas=1,
rank=0,
shuffle=True,
seed=1
)
loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=sampler,
num_workers=0,
pin_memory=True,
drop_last=True
)
for video_data in loader:
print(video_data)
import csv
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
class DatasetFromCSV(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
csv_path,
num_frames=16,
frame_interval=1,
transform=None,
root=None,
):
self.csv_path = csv_path
with open(csv_path, "r") as f:
reader = csv.reader(f)
csv_list = list(reader)
all_samples = csv_list[1:] # skip the header row; 10727607 rows in total
# sample_samples = random.sample(all_samples, 400000) # 400k = 366k + 20k + 20k
sample_samples = []
for i_s, sample in enumerate(all_samples):
if i_s % 25 == 0:
if sample[2] != '0':
sample_samples.append(sample)
self.samples = sample_samples # 429105
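# Keeping every 25th row whose third column is not '0' (assumed to flag unusable clips)
# subsamples the ~10.7M rows down to the ~429k samples noted above.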
ext = self.samples[0][0].split(".")[-1]
if ext.lower() in ("mp4", "avi", "mov", "mkv"):
self.is_video = True
else:
assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
self.is_video = False
self.transform = transform
self.num_frames = num_frames
self.frame_interval = frame_interval
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
self.root = root
def getitem(self, index):
sample = self.samples[index]
path = sample[0]
if self.root:
path = os.path.join(self.root, path)
text = sample[-1]
if self.is_video:
# old
# vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
# total_frames = len(vframes)
# # Sampling video frames
# start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
# assert (
# end_frame_ind - start_frame_ind >= self.num_frames
# ), f"{path} with index {index} has not enough frames."
# frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
# video = vframes[frame_indice]
# video = self.transform(video) # T C H W
# new
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
loop_index = index
while total_frames < self.num_frames or not is_exit:
print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exist!")
loop_index += 1
if loop_index >= len(self.samples):
loop_index = 0
sample = self.samples[loop_index]
path = sample[0]
if self.root:
path = os.path.join(self.root, path)
text = sample[-1]
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
# video exists and total_frames >= self.num_frames
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= self.num_frames
), f"{path} with index {index} has not enough frames."
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
video = vframes[frame_indice]
video = self.transform(video) # T C H W
else:
image = pil_loader(path)
image = self.transform(image)
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad samples.")
def __len__(self):
return len(self.samples)
import csv
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
class DatasetFromCSV(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
csv_path,
num_frames=16,
frame_interval=1,
transform=None,
root=None,
):
self.csv_path = csv_path
with open(csv_path, "r") as f:
reader = csv.reader(f)
csv_list = list(reader)
all_samples = csv_list[1:] # skip the header row; 10727607 rows in total
sample_samples = []
for i_s, sample in enumerate(all_samples):
if sample[2] != '0':
sample_samples.append(sample)
print('samples num:', len(sample_samples))
self.samples = sample_samples # 10727337
self.is_video = True
self.transform = transform
self.num_frames = num_frames
self.frame_interval = frame_interval
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
self.root = root
def getitem(self, index):
sample = self.samples[index]
path = sample[0]
if self.root:
path = os.path.join(self.root, path)
text = sample[-1]
#path = "/mnt/bn/yh-volume0/dataset/webvid/raw/videos/train/videos_new/013501_013550-33142969.mp4"
if self.is_video:
# old
# vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
# total_frames = len(vframes)
# # Sampling video frames
# start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
# assert (
# end_frame_ind - start_frame_ind >= self.num_frames
# ), f"{path} with index {index} has not enough frames."
# frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
# video = vframes[frame_indice]
# video = self.transform(video) # T C H W
# new
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
loop_index = index
while total_frames < self.num_frames or not is_exit:
# print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exist!")
loop_index += 1
if loop_index >= len(self.samples):
loop_index = 0
sample = self.samples[loop_index]
path = sample[0]
if self.root:
path = os.path.join(self.root, path)
text = sample[-1]
is_exit = os.path.exists(path)
if is_exit:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
total_frames = len(vframes)
else:
total_frames = 0
# video exists and total_frames >= self.num_frames
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= self.num_frames
), f"{path} with index {index} has not enough frames."
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
#print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
video = vframes[frame_indice]
video = self.transform(video) # T C H W
else:
image = pil_loader(path)
image = self.transform(image)
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
#print('video shape:', video.shape,'text:', text, 'video path:', path)
return {"video": video, "text": text}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad samples.")
def __len__(self):
return len(self.samples)
if __name__ == '__main__':
data_path = '/mnt/bn/yh-volume0/dataset/webvid/raw/webvid_csv/train.csv'
root='/mnt/bn/yh-volume0/dataset/webvid/raw/videos/train/videos_new'
dataset = DatasetFromCSV(
data_path,
transform=get_transforms_video(),
num_frames=16,
frame_interval=3,
root=root,
)
sampler = DistributedSampler(
dataset,
num_replicas=1,
rank=0,
shuffle=True,
seed=1
)
loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=sampler,
num_workers=0,
pin_memory=True,
drop_last=True
)
for video_data in loader:
print(video_data)
Real-ESRGAN degradation dataset pipeline. One can generate custom degraded datasets using this pipeline.
Note: this project is derived from https://github.com/xinntao/Real-ESRGAN
import argparse
import cv2
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import math
import os
import torch.nn as nn
from PIL import Image
from utils_ import filter2D, USMSharp
from utils_blur import circular_lowpass_kernel, random_mixed_kernels
from utils_resize import random_resizing
from utils_noise import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from utils_jpeg import DiffJPEG
# from matlab_functions import imresize
# from torchvision import transforms
from torch.nn import functional as F
class Degradation(nn.Module):
def __init__(self, scale, gt_size):
super(Degradation, self).__init__()
### initialize the JPEG compressor and USM sharpener
self.jpeger = DiffJPEG(differentiable=False)#.cuda()
self.usm_sharpener = USMSharp()#.cuda()
# self.queue_size = 180 #opt.get('queue_size', 180)
### global settings
self.scale = scale
self.gt_size = gt_size
### the first-stage degradation hyperparameters ###
# 1. blur
self.blur_kernel_size = 21
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
self.kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
self.kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
self.sinc_prob = 0.1
self.blur_sigma = [0.2, 3] # blur_x / y_sigma
self.betag_range = [0.5, 4]
self.betap_range = [1, 2]
# 2. resize
self.updown_type = ["up", "down", "keep"]
self.mode_list = ["area", "bilinear", "bicubic"] # flags:[3,1,2]
self.resize_prob = [0.2, 0.7, 0.1] # up, down, keep
self.resize_range = [0.15, 1.5]
# 3. noise
self.gaussian_noise_prob = 0.5
self.noise_range = [1, 30]
self.poisson_scale_range = [0.05, 3]
self.gray_noise_prob = 0.4
# 4. jpeg
self.jpeg_range = [30, 95]
### the second-stage degradation hyperparameters ###
# 1. blur
self.second_blur_prob = 0.8
self.blur_kernel_size2 = 21
self.kernel_range2 = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
self.kernel_list2 = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
self.kernel_prob2 = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
self.sinc_prob2 = 0.1
self.blur_sigma2 = [0.2, 1.5]
self.betag_range2 = [0.5, 4]
self.betap_range2 = [1, 2]
# 2. resize
self.updown_type2 = ["up", "down", "keep"]
self.mode_list2 = ["area", "bilinear", "bicubic"] # flags:[3,1,2]
self.resize_prob2 = [0.3, 0.4, 0.3] # up, down, keep
self.resize_range2 = [0.3, 1.2]
# 3. noise
self.gaussian_noise_prob2 = 0.5
self.noise_range2 = [1, 25]
self.poisson_scale_range2 = [0.05, 2.5]
self.gray_noise_prob2 = 0.4
# 4. jpeg
self.jpeg_range2 = [30, 95]
self.final_sinc_prob = 0.8
# TODO: the kernel range is hard-coded for now; it should come from the config file
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
self.pulse_tensor[10, 10] = 1
@torch.no_grad()
def forward(self, gt):
ori_h, ori_w = gt.size()[2:4]
gt_usm = self.usm_sharpener(gt)
gt_usm_copy = gt_usm.clone()
# generate kernel
kernel1 = self.generate_first_kernel()
kernel2 = self.generate_second_kernel()
sinc_kernel = self.generate_sinc_kernel()
# first degradation
lq = self.jpeg_1(self.noise_1(self.resize_1(self.blur_1(gt_usm_copy, kernel1))))
# second degradation
lq = self.jpeg_2(self.noise_2(self.resize_2(self.blur_2(lq, kernel2), ori_h,ori_w)), ori_h,ori_w, sinc_kernel)
return lq, gt_usm #, kernel1, kernel2, sinc_kernel
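# A minimal usage sketch (assumed shapes; forward expects an N C H W float tensor in [0, 1]):
#   deg = Degradation(scale=4, gt_size=256)
#   gt = torch.rand(4, 3, 256, 256)
#   lq, gt_usm = deg(gt)  # lq: (4, 3, 64, 64), gt_usm: (4, 3, 256, 256)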
# @torch.no_grad()
# def forward(self, gt_path, uint8=False):
# # read hwc 0-1 numpy
# img_gt = cv2.imread(gt_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
# # augment
# img_gt = self.augment(img_gt, True, True)
# # numpy 0-1 hwc -> tensor 0-1 chw
# img_gt = self.np2tensor([img_gt], bgr2rgb=True, float32=True)[0]
# # add batch
# img_gt = img_gt.unsqueeze(0)
# img_gt_copy = img_gt.clone()
# # degradation_piepline
# lq, gt_usm, kernel1, kernel2, sinc_kernel = self.forward_deg(img_gt)
# # clamp and round
# lq = torch.clamp((lq * 255.0).round(), 0, 255) / 255.
# print(f'before crop: gt:{img_gt_copy.shape}, lq:{lq.shape}')
# # random crop
# (gt, gt_usm), lq = self.paired_random_crop([img_gt_copy, gt_usm], lq, self.gt_size, self.scale)
# print(f'after crop: gt:{gt.shape}, lq:{lq.shape}')
#
# if uint8:
# gt, gt_usm, lq = self.tensor2np([gt, gt_usm, lq])
# return gt, gt_usm, lq, kernel1, kernel2, sinc_kernel
#
# return gt, gt_usm, lq, kernel1, kernel2, sinc_kernel
def blur_1(self, img, kernel1):
img = filter2D(img, kernel1)
return img
def blur_2(self, img, kernel2):
if np.random.uniform() < self.second_blur_prob:
img = filter2D(img, kernel2)
return img
def resize_1(self, img):
updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob)[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.resize_range[1])
elif updown_type == 'down':
scale = np.random.uniform(self.resize_range[0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
img = F.interpolate(img, scale_factor=scale, mode=mode)
return img
def resize_2(self, img, ori_h, ori_w):
updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob2)[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.resize_range2[1])
elif updown_type == 'down':
scale = np.random.uniform(self.resize_range2[0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
img = F.interpolate(
img, size=(int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), mode=mode)
return img
def noise_1(self, img):
gray_noise_prob = self.gray_noise_prob
if np.random.uniform() < self.gaussian_noise_prob:
img = random_add_gaussian_noise_pt(img, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
img = random_add_poisson_noise_pt(
img,
scale_range=self.poisson_scale_range,
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
return img
def noise_2(self, img):
gray_noise_prob = self.gray_noise_prob2
if np.random.uniform() < self.gaussian_noise_prob2:
img = random_add_gaussian_noise_pt(
img, sigma_range=self.noise_range2, clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
img = random_add_poisson_noise_pt(
img,
scale_range=self.poisson_scale_range2,
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
return img
def jpeg_1(self, img):
jpeg_p = img.new_zeros(img.size(0)).uniform_(*self.jpeg_range)
img = torch.clamp(img, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
img = self.jpeger(img, quality=jpeg_p)
return img
def jpeg_2(self, out, ori_h, ori_w, sinc_kernel):
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if np.random.uniform() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
out = filter2D(out, sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
out = filter2D(out, sinc_kernel)
return out
def generate_first_kernel(self):
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.sinc_prob:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
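# e.g. kernel_size = 7 -> pad_size = (21 - 7) // 2 = 7, so every kernel is padded up to 21 x 21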
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
return torch.FloatTensor(kernel)
def generate_second_kernel(self):
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.sinc_prob2:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
return torch.FloatTensor(kernel2)
def generate_sinc_kernel(self):
if np.random.uniform() < self.final_sinc_prob:
kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
return sinc_kernel
def np2tensor(self, imgs, bgr2rgb=False, float32=True):
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def tensor2np(self, imgs):
def _tonumpy(img):
img = img.data.cpu().numpy().squeeze(0).transpose(1,2,0) #.astype(np.float32)
img = np.uint8((img.clip(0,1) * 255.).round())
return img
if isinstance(imgs, list):
return [_tonumpy(img) for img in imgs]
else:
return _tonumpy(imgs)
def augment(self, imgs, hflip=True, rotation=True, flows=None, return_status=False):
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
return imgs
def paired_random_crop(self, img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
"""Paired random crop. Support Numpy array and Tensor inputs.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth. Default: None.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
# determine input type: Numpy array or Tensor
input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
if input_type == 'Tensor':
h_lq, w_lq = img_lqs[0].size()[-2:]
h_gt, w_gt = img_gts[0].size()[-2:]
else:
h_lq, w_lq = img_lqs[0].shape[0:2]
h_gt, w_gt = img_gts[0].shape[0:2]
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
f'multiplication of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
if input_type == 'Tensor':
img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
else:
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
if input_type == 'Tensor':
img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
else:
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
if __name__ == '__main__':
# print(os.path.abspath(os.path.join(__file__, os.path.pardir)))
deg_pipeline = Degradation(scale=4, gt_size=256)
# gt_path = r'J:\Dataset\SR\Real_ESRGAN\DF2K_multiscale_sub\0052T0_s024.png'
# gt_path should point to a single HR image file inside this folder
gt_path = './Dataset/train/DIV2K/HR/'
# read HWC BGR uint8 -> float32 in [0, 1], then convert to a 1 C H W RGB tensor;
# forward returns (lq, gt_usm)
img_gt = cv2.imread(gt_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
img_gt = deg_pipeline.np2tensor(img_gt, bgr2rgb=True, float32=True).unsqueeze(0)
lq, gt_usm = deg_pipeline(img_gt)
lq, gt_usm = deg_pipeline.tensor2np([lq, gt_usm])
cv2.imwrite('lq.png', cv2.cvtColor(lq, cv2.COLOR_RGB2BGR))
cv2.imwrite('gt_usm.png', cv2.cvtColor(gt_usm, cv2.COLOR_RGB2BGR))
import cv2
import numpy as np
import random
import torch
import math
import torch.nn as nn
import sys
sys.path.append('./opensora/datasets/high_order')
from utils_ import filter2D, USMSharp
from utils_blur import circular_lowpass_kernel, random_mixed_kernels
from utils_noise import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from utils_jpeg import DiffJPEG
from torch.nn import functional as F
from einops import rearrange
import av
import io
class ImageCompressor:
def __init__(self):
self.params = {
'codec': ['libx264', 'h264', 'mpeg4'],
'codec_prob': [1 / 3., 1 / 3., 1 / 3.],
'bitrate': [1e4, 1e5]
}
def _ensure_even_dimensions(self, img):
# Ensure width and height are even
h, w = img.shape[:2]
if h % 2 != 0:
img = img[:-1, :]
if w % 2 != 0:
img = img[:, :-1]
return img
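# e.g. a 481 x 720 frame becomes 480 x 720: yuv420p subsamples chroma by 2 in each
# dimension, so the encoder requires even width and height.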
def _apply_random_compression(self, imgs):
# Convert PyTorch tensor to NumPy array
imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
# Ensure width and height are even
imgs = [self._ensure_even_dimensions(img) for img in imgs]
codec = random.choices(self.params['codec'], self.params['codec_prob'])[0]
bitrate = self.params['bitrate']
bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)
buf = io.BytesIO()
with av.open(buf, 'w', 'mp4') as container:
stream = container.add_stream(codec, rate=1)
stream.height = imgs[0].shape[0]
stream.width = imgs[0].shape[1]
stream.pix_fmt = 'yuv420p'
stream.bit_rate = bitrate
for img in imgs:
img = (img * 255).clip(0, 255) # Convert to [0, 255] range
img = img.astype(np.uint8)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame.pict_type = 'NONE'
for packet in stream.encode(frame):
container.mux(packet)
# Flush stream
for packet in stream.encode():
container.mux(packet)
outputs = []
with av.open(buf, 'r', 'mp4') as container:
if container.streams.video:
for frame in container.decode(**{'video': 0}):
outputs.append(frame.to_rgb().to_ndarray().astype(np.float32) / 255) # Convert back to [0, 1] range
# Convert NumPy array back to PyTorch tensor
outputs = torch.from_numpy(np.stack(outputs)).permute(0, 3, 1, 2)
return outputs
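# A minimal usage sketch (assumed shapes): round-trip frames through a random codec
# to simulate video compression artifacts.
#   compressor = ImageCompressor()
#   frames = torch.rand(8, 3, 256, 256)                        # N C H W, values in [0, 1]
#   compressed = compressor._apply_random_compression(frames)  # N C H W, values in [0, 1]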
class Degradation(nn.Module):
def __init__(self, scale, gt_size):
super(Degradation, self).__init__()
### initialize the JPEG compressor and USM sharpener
self.jpeger = DiffJPEG(differentiable=False)#.cuda()
self.usm_sharpener = USMSharp()#.cuda()
# self.queue_size = 180 #opt.get('queue_size', 180)
### global settings
self.scale = scale
self.gt_size = gt_size
### the first-stage degradation hyperparameters ###
# 1. blur
self.blur_kernel_size = 21
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
self.kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
self.kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
self.sinc_prob = 0.1
self.blur_sigma = [0.2, 3] # blur_x / y_sigma
self.betag_range = [0.5, 4]
self.betap_range = [1, 2]
# 2. resize
self.updown_type = ["up", "down", "keep"]
self.mode_list = ["area", "bilinear", "bicubic"] # flags:[3,1,2]
self.resize_prob = [0.2, 0.7, 0.1] # up, down, keep
self.resize_range = [0.15, 1.5]
# 3. noise
self.gaussian_noise_prob = 0.5
self.noise_range = [1, 30]
self.poisson_scale_range = [0.05, 3]
self.gray_noise_prob = 0.4
# 4. jpeg
self.jpeg_range = [30, 95]
### the second-stage degradation hyperparameters ###
# 1. blur
self.second_blur_prob = 0.8
self.blur_kernel_size2 = 21
self.kernel_range2 = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
self.kernel_list2 = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
self.kernel_prob2 = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
self.sinc_prob2 = 0.1
self.blur_sigma2 = [0.2, 1.5]
self.betag_range2 = [0.5, 4]
self.betap_range2 = [1, 2]
# 2. resize
self.updown_type2 = ["up", "down", "keep"]
self.mode_list2 = ["area", "bilinear", "bicubic"] # flags:[3,1,2]
self.resize_prob2 = [0.3, 0.4, 0.3] # up, down, keep
self.resize_range2 = [0.3, 1.2]
# 3. noise
self.gaussian_noise_prob2 = 0.5
self.noise_range2 = [1, 25]
self.poisson_scale_range2 = [0.05, 2.5]
self.gray_noise_prob2 = 0.4
# 4. jpeg
self.jpeg_range2 = [30, 95]
self.final_sinc_prob = 0.8
# TODO: the kernel range is hard-coded for now; it should come from the config file
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
self.pulse_tensor[10, 10] = 1
# video compression
self.compressor = ImageCompressor()
@torch.no_grad()
def forward_deg(self, gt):
ori_h, ori_w = gt.size()[2:4]
gt_usm = self.usm_sharpener(gt)
gt_usm_copy = gt_usm.clone()
# generate kernel
kernel1 = self.generate_first_kernel()
kernel2 = self.generate_second_kernel()
sinc_kernel = self.generate_sinc_kernel()
# first degradation
lq = self.compressor._apply_random_compression(self.jpeg_1(self.noise_1(self.resize_1(self.blur_1(gt_usm_copy, kernel1)))))
# second degradation
lq = self.compressor._apply_random_compression(self.jpeg_2(self.noise_2(self.resize_2(self.blur_2(lq, kernel2), ori_h,ori_w)), ori_h,ori_w, sinc_kernel))
return lq, gt_usm, kernel1, kernel2, sinc_kernel
@torch.no_grad()
def forward(self, img_gt, uint8=False):
# read hwc 0-1 numpy
# img_gt = cv2.imread(gt_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
# augment
# img_gt = self.augment(img_gt, True, True)
# numpy 0-1 hwc -> tensor 0-1 chw
# img_gt = self.np2tensor([img_gt], bgr2rgb=True, float32=True)[0]
# add batch
img_gt = img_gt.unsqueeze(0)
img_gt_copy = img_gt.clone()
# degradation pipeline
lq, gt_usm, kernel1, kernel2, sinc_kernel = self.forward_deg(img_gt_copy)
# clamp and round
lq = torch.clamp((lq * 255.0).round(), 0, 255) / 255.
# print(f'before crop: gt:{img_gt_copy.shape}, lq:{lq.shape}')
# random crop
# (gt, gt_usm), lq = self.paired_random_crop([img_gt_copy, gt_usm], lq, self.gt_size, self.scale)
# print(f'after crop: gt:{gt.shape}, lq:{lq.shape}')
# if uint8:
# gt, gt_usm, lq = self.tensor2np([gt, gt_usm, lq])
# return gt, gt_usm, lq, kernel1, kernel2, sinc_kernel
return lq, gt_usm # gt, kernel1, kernel2, sinc_kernel
def blur_1(self, img, kernel1):
img = filter2D(img, kernel1)
return img
def blur_2(self, img, kernel2):
if np.random.uniform() < self.second_blur_prob:
img = filter2D(img, kernel2)
return img
def resize_1(self, img):
updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob)[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.resize_range[1])
elif updown_type == 'down':
scale = np.random.uniform(self.resize_range[0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
img = F.interpolate(img, scale_factor=scale, mode=mode)
return img
def resize_2(self, img, ori_h, ori_w):
updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob2)[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.resize_range2[1])
elif updown_type == 'down':
scale = np.random.uniform(self.resize_range2[0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
img = F.interpolate(
img, size=(int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), mode=mode)
return img
def noise_1(self, img):
gray_noise_prob = self.gray_noise_prob
if np.random.uniform() < self.gaussian_noise_prob:
img = random_add_gaussian_noise_pt(img, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
img = random_add_poisson_noise_pt(
img,
scale_range=self.poisson_scale_range,
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
return img
def noise_2(self, img):
gray_noise_prob = self.gray_noise_prob2
if np.random.uniform() < self.gaussian_noise_prob2:
img = random_add_gaussian_noise_pt(
img, sigma_range=self.noise_range2, clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
img = random_add_poisson_noise_pt(
img,
scale_range=self.poisson_scale_range2,
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
return img
def jpeg_1(self, img):
jpeg_p = img.new_zeros(img.size(0)).uniform_(*self.jpeg_range)
img = torch.clamp(img, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
img = self.jpeger(img, quality=jpeg_p)
return img
def jpeg_2(self, out, ori_h, ori_w, sinc_kernel):
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if np.random.uniform() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
out = filter2D(out, sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
out = filter2D(out, sinc_kernel)
return out
def generate_first_kernel(self):
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.sinc_prob:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
return torch.FloatTensor(kernel)
def generate_second_kernel(self):
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.sinc_prob2:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
return torch.FloatTensor(kernel2)
def generate_sinc_kernel(self):
if np.random.uniform() < self.final_sinc_prob:
kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
return sinc_kernel
def np2tensor(self, imgs, bgr2rgb=False, float32=True):
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def tensor2np(self, imgs):
def _tonumpy(img):
img = img.data.cpu().numpy().squeeze(0).transpose(1,2,0) #.astype(np.float32)
img = np.uint8((img.clip(0,1) * 255.).round())
return img
if isinstance(imgs, list):
return [_tonumpy(img) for img in imgs]
else:
return _tonumpy(imgs)
def augment(self, imgs, hflip=True, rotation=True, flows=None, return_status=False):
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
return imgs
def paired_random_crop(self, img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
"""Paired random crop. Support Numpy array and Tensor inputs.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth. Default: None.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
# determine input type: Numpy array or Tensor
input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
if input_type == 'Tensor':
h_lq, w_lq = img_lqs[0].size()[-2:]
h_gt, w_gt = img_gts[0].size()[-2:]
else:
h_lq, w_lq = img_lqs[0].shape[0:2]
h_gt, w_gt = img_gts[0].shape[0:2]
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
f'multiplication of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
if input_type == 'Tensor':
img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
else:
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
if input_type == 'Tensor':
img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
else:
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
# init degradation process
degradation = Degradation(scale=4, gt_size=(480, 720))
def degradation_process(video_array):
_, _, t, _, _ = video_array.shape
# preprocess video
video_array = video_array.to(torch.float32).cpu()
video_array = (video_array + 1) * 0.5 # [-1, 1] -> [0, 1]
video_array = rearrange(video_array, "B C T H W -> (B T) C H W")
assert torch.max(video_array) <= 1 and torch.min(video_array) >= 0, "Values are NOT within [0, 1]."
# degrade
lq_list = []
gt_list = []
for video in video_array:
lq, gt = degradation(video)
lq_list.append(lq)
gt_list.append(gt)
lq = torch.cat(lq_list, dim=0)
gt = torch.cat(gt_list, dim=0)
lq = lq.clip(0, 1) * 2 - 1
gt = gt.clip(0, 1) * 2 - 1
lq = rearrange(lq, "(B T) C H W -> B C T H W", T=t).to(torch.float32)
gt = rearrange(gt, "(B T) C H W -> B C T H W", T=t).to(torch.float32)
return lq, gt
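# A minimal usage sketch (assumed shapes): degrade a batch of clips in [-1, 1].
#   video = torch.rand(1, 3, 16, 480, 720) * 2 - 1  # B C T H W
#   lq, gt = degradation_process(video)             # lq: B C T H/4 W/4, gt: B C T H W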
import cv2
import numpy as np
import random
import torch
import math
# import os
import torch.nn as nn
import sys
sys.path.append('/mnt/bn/videodataset-uswest/VSR/VSR/opensora/datasets/high_order')
from utils_ import filter2D, USMSharp
from utils_blur import circular_lowpass_kernel, random_mixed_kernels
# from utils_resize import random_resizing
from utils_noise import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from utils_jpeg import DiffJPEG
from torch.nn import functional as F
from einops import rearrange
import av
import io
class ImageCompressor:
def __init__(self):
self.params = {
'codec': ['libx264', 'h264', 'mpeg4'],
'codec_prob': [1 / 3., 1 / 3., 1 / 3.],
'bitrate': [1e4, 1e5]
}
def _ensure_even_dimensions(self, img):
# Ensure width and height are even
h, w = img.shape[:2]
if h % 2 != 0:
img = img[:-1, :]
if w % 2 != 0:
img = img[:, :-1]
return img
def _apply_random_compression(self, imgs):
# Convert PyTorch tensor to NumPy array
imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
# Ensure width and height are even
imgs = [self._ensure_even_dimensions(img) for img in imgs]
codec = random.choices(self.params['codec'], self.params['codec_prob'])[0]
bitrate = self.params['bitrate']
bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)
buf = io.BytesIO()
with av.open(buf, 'w', 'mp4') as container:
stream = container.add_stream(codec, rate=1)
stream.height = imgs[0].shape[0]
stream.width = imgs[0].shape[1]
stream.pix_fmt = 'yuv420p'
stream.bit_rate = bitrate
for img in imgs:
img = (img * 255).clip(0, 255) # Convert to [0, 255] range
img = img.astype(np.uint8)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame.pict_type = 'NONE'
for packet in stream.encode(frame):
container.mux(packet)
# Flush stream
for packet in stream.encode():
container.mux(packet)
outputs = []
with av.open(buf, 'r', 'mp4') as container:
if container.streams.video:
for frame in container.decode(**{'video': 0}):
outputs.append(frame.to_rgb().to_ndarray().astype(np.float32) / 255) # Convert back to [0, 1] range
# Convert NumPy array back to PyTorch tensor
outputs = torch.from_numpy(np.stack(outputs)).permute(0, 3, 1, 2)
return outputs
class Degradation(nn.Module):
def __init__(self, scale, gt_size):
super(Degradation, self).__init__()
### initialize the JPEG compressor and USM sharpener
self.jpeger = DiffJPEG(differentiable=False)#.cuda()
self.usm_sharpener = USMSharp()#.cuda()
# self.queue_size = 180 #opt.get('queue_size', 180)
### global settings
self.scale = scale
self.gt_size = gt_size
### the first-stage degradation hyperparameters ###
# 1. blur
self.blur_kernel_size = 21
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
self.kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
self.kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
self.sinc_prob = 0.1
self.blur_sigma = [0.2, 3] # blur_x / y_sigma
self.betag_range = [0.5, 4]
self.betap_range = [1, 2]
# 2. resize
self.updown_type = ["up", "down", "keep"]
self.mode_list = ["area", "bilinear", "bicubic"] # flags:[3,1,2]
self.resize_prob = [0.2, 0.7, 0.1] # up, down, keep
self.resize_range = [0.15, 1.5]
# 3. noise
self.gaussian_noise_prob = 0.5
self.noise_range = [1, 30]
self.poisson_scale_range = [0.05, 3]
self.gray_noise_prob = 0.4
# 4. jpeg
self.jpeg_range = [30, 95]
### the second-stage degradation hyperparameters ###
# 1. blur
self.second_blur_prob = 0.8
self.blur_kernel_size2 = 21
self.kernel_range2 = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
self.kernel_list2 = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
self.kernel_prob2 = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
self.sinc_prob2 = 0.1
self.blur_sigma2 = [0.2, 1.5]
self.betag_range2 = [0.5, 4]
self.betap_range2 = [1, 2]
# 2. resize
self.updown_type2 = ["up", "down", "keep"]
self.mode_list2 = ["area", "bilinear", "bicubic"] # flags:[3,1,2]
self.resize_prob2 = [0.3, 0.4, 0.3] # up, down, keep
self.resize_range2 = [0.3, 1.2]
# 3. noise
self.gaussian_noise_prob2 = 0.5
self.noise_range2 = [1, 25]
self.poisson_scale_range2 = [0.05, 2.5]
self.gray_noise_prob2 = 0.4
# 4. jpeg
self.jpeg_range2 = [30, 95]
self.final_sinc_prob = 0.8
# TODO: the kernel range is hard-coded for now; it should come from the config file
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
self.pulse_tensor[10, 10] = 1
# video compression
self.compressor = ImageCompressor()
@torch.no_grad()
def forward_deg(self, gt):
ori_h, ori_w = gt.size()[2:4]
gt_usm = self.usm_sharpener(gt)
gt_usm_copy = gt_usm.clone()
# generate kernel
kernel1 = self.generate_first_kernel()
kernel2 = self.generate_second_kernel()
sinc_kernel = self.generate_sinc_kernel()
# first degradation
lq = self.compressor._apply_random_compression(self.jpeg_1(self.noise_1(self.resize_1(self.blur_1(gt_usm_copy, kernel1)))))
# second degradation
# lq = self.compressor._apply_random_compression(self.jpeg_2(self.noise_2(self.resize_2(self.blur_2(lq, kernel2), ori_h,ori_w)), ori_h,ori_w, sinc_kernel))
return lq, gt_usm, kernel1, kernel2, sinc_kernel
@torch.no_grad()
def forward(self, img_gt, uint8=False):
# read hwc 0-1 numpy
# img_gt = cv2.imread(gt_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
# augment
# img_gt = self.augment(img_gt, True, True)
# numpy 0-1 hwc -> tensor 0-1 chw
# img_gt = self.np2tensor([img_gt], bgr2rgb=True, float32=True)[0]
# add batch
img_gt = img_gt.unsqueeze(0)
img_gt_copy = img_gt.clone()
# degradation pipeline
lq, gt_usm, kernel1, kernel2, sinc_kernel = self.forward_deg(img_gt_copy)
# clamp and round
lq = torch.clamp((lq * 255.0).round(), 0, 255) / 255.
# print(f'before crop: gt:{img_gt_copy.shape}, lq:{lq.shape}')
# random crop
# (gt, gt_usm), lq = self.paired_random_crop([img_gt_copy, gt_usm], lq, self.gt_size, self.scale)
# print(f'after crop: gt:{gt.shape}, lq:{lq.shape}')
# if uint8:
# gt, gt_usm, lq = self.tensor2np([gt, gt_usm, lq])
# return gt, gt_usm, lq, kernel1, kernel2, sinc_kernel
return lq, gt_usm # gt, kernel1, kernel2, sinc_kernel
def blur_1(self, img, kernel1):
img = filter2D(img, kernel1)
return img
def blur_2(self, img, kernel2):
if np.random.uniform() < self.second_blur_prob:
img = filter2D(img, kernel2)
return img
def resize_1(self, img):
updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob)[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.resize_range[1])
elif updown_type == 'down':
scale = np.random.uniform(self.resize_range[0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
img = F.interpolate(img, scale_factor=scale, mode=mode)
return img
def resize_2(self, img, ori_h, ori_w):
updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob2)[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.resize_range2[1])
elif updown_type == 'down':
scale = np.random.uniform(self.resize_range2[0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
img = F.interpolate(
img, size=(int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), mode=mode)
return img
def noise_1(self, img):
gray_noise_prob = self.gray_noise_prob
if np.random.uniform() < self.gaussian_noise_prob:
img = random_add_gaussian_noise_pt(img, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
img = random_add_poisson_noise_pt(
img,
scale_range=self.poisson_scale_range,
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
return img
def noise_2(self, img):
gray_noise_prob = self.gray_noise_prob2
if np.random.uniform() < self.gaussian_noise_prob2:
img = random_add_gaussian_noise_pt(
img, sigma_range=self.noise_range2, clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
img = random_add_poisson_noise_pt(
img,
scale_range=self.poisson_scale_range2,
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
return img
def jpeg_1(self, img):
jpeg_p = img.new_zeros(img.size(0)).uniform_(*self.jpeg_range)
img = torch.clamp(img, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
img = self.jpeger(img, quality=jpeg_p)
return img
def jpeg_2(self, out, ori_h, ori_w, sinc_kernel):
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if np.random.uniform() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
out = filter2D(out, sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
out = filter2D(out, sinc_kernel)
return out
def generate_first_kernel(self):
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.sinc_prob:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
return torch.FloatTensor(kernel)
def generate_second_kernel(self):
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.sinc_prob2:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
return torch.FloatTensor(kernel2)
def generate_sinc_kernel(self):
if np.random.uniform() < self.final_sinc_prob:
kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
return sinc_kernel
def np2tensor(self, imgs, bgr2rgb=False, float32=True):
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def tensor2np(self, imgs):
def _tonumpy(img):
img = img.data.cpu().numpy().squeeze(0).transpose(1,2,0) #.astype(np.float32)
img = np.uint8((img.clip(0,1) * 255.).round())
return img
if isinstance(imgs, list):
return [_tonumpy(img) for img in imgs]
else:
return _tonumpy(imgs)
def augment(self, imgs, hflip=True, rotation=True, flows=None, return_status=False):
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
return imgs
def paired_random_crop(self, img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
"""Paired random crop. Support Numpy array and Tensor inputs.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth. Default: None.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
# determine input type: Numpy array or Tensor
input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
if input_type == 'Tensor':
h_lq, w_lq = img_lqs[0].size()[-2:]
h_gt, w_gt = img_gts[0].size()[-2:]
else:
h_lq, w_lq = img_lqs[0].shape[0:2]
h_gt, w_gt = img_gts[0].shape[0:2]
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
f'the size of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
if input_type == 'Tensor':
img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
else:
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
if input_type == 'Tensor':
img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
else:
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
# init degradation process
degradation = Degradation(scale=4, gt_size=(480, 720))
def degradation_process(video_array):
_, _, t, _, _ = video_array.shape
# preprocess video
video_array = video_array.to(torch.float32).cpu()
video_array = (video_array + 1) * 0.5 # [-1, 1] -> [0, 1]
video_array = rearrange(video_array, "B C T H W -> (B T) C H W")
assert torch.max(video_array) <= 1 and torch.min(video_array) >= 0, "Values are NOT within [0, 1]."
# degrade
lq_list = []
gt_list = []
for video in video_array:
lq, gt = degradation(video)
lq_list.append(lq)
gt_list.append(gt)
lq = torch.cat(lq_list, dim=0)
gt = torch.cat(gt_list, dim=0)
lq = lq.clip(0, 1) * 2 - 1
gt = gt.clip(0, 1) * 2 - 1
lq = rearrange(lq, "(B T) C H W -> B C T H W", T=t).to(torch.float32)
gt = rearrange(gt, "(B T) C H W -> B C T H W", T=t).to(torch.float32)
return lq, gt
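# Minimal end-to-end sketch (a made-up clip shape; assumes the helper imports
# used above, e.g. rearrange and the basicsr-style noise/blur utils, resolve).
# T=1 is used because each frame is degraded with its own random resize scale,
# so multi-frame clips can produce mismatched shapes at torch.cat:
if __name__ == '__main__':
    dummy = torch.rand(1, 3, 1, 128, 128) * 2 - 1  # B C T H W in [-1, 1]
    lq, gt = degradation_process(dummy)
    print('lq:', lq.shape, 'gt:', gt.shape)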
import math
import numpy as np
import torch
def cubic(x):
"""cubic function used for calculate_weights_indices."""
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
(absx <= 2)).type_as(absx))
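# Note: this is the standard bicubic kernel with a = -0.5 (the convention
# MATLAB's imresize uses): (a+2)|x|^3 - (a+3)|x|^2 + 1 for |x| <= 1, and
# a|x|^3 - 5a|x|^2 + 8a|x| - 4a for 1 < |x| <= 2, zero elsewhere.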
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
"""Calculate weights and indices, used for imresize function.
Args:
in_length (int): Input length.
out_length (int): Output length.
scale (float): Scale factor.
kernel (str): Kernel type; the implementation always applies the cubic kernel.
kernel_width (int): Kernel width.
antialiasing (bool): Whether to apply anti-aliasing when downsampling.
"""
if (scale < 1) and antialiasing:
# Use a modified kernel (larger kernel width) to simultaneously
# interpolate and antialias
kernel_width = kernel_width / scale
# Output-space coordinates
x = torch.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5 + scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
p = math.ceil(kernel_width) + 2
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
out_length, p)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
# apply cubic kernel
if (scale < 1) and antialiasing:
weights = scale * cubic(distance_to_center * scale)
else:
weights = cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = torch.sum(weights, 1).view(out_length, 1)
weights = weights / weights_sum.expand(out_length, p)
# If a column in weights is all zero, get rid of it. only consider the
# first and last column.
weights_zero_tmp = torch.sum((weights == 0), 0)
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
indices = indices.narrow(1, 1, p - 2)
weights = weights.narrow(1, 1, p - 2)
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
indices = indices.narrow(1, 0, p - 2)
weights = weights.narrow(1, 0, p - 2)
weights = weights.contiguous()
indices = indices.contiguous()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
@torch.no_grad()
def imresize(img, scale, antialiasing=True):
"""imresize function same as MATLAB.
It now only supports bicubic.
The same scale applies for both height and width.
Args:
img (Tensor | Numpy array):
Tensor: Input image with shape (c, h, w), [0, 1] range.
Numpy: Input image with shape (h, w, c), [0, 1] range.
scale (float): Scale factor. The same scale applies for both height
and width.
antialiasing (bool): Whether to apply anti-aliasing when downsampling.
Default: True.
Returns:
Numpy array: Output image with shape (h, w, c), [0, 1] range, without rounding.
"""
squeeze_flag = False
if type(img).__module__ == np.__name__: # numpy type
numpy_type = True
if img.ndim == 2:
img = img[:, :, None]
squeeze_flag = True
img = torch.from_numpy(img.transpose(2, 0, 1)).float()
else:
numpy_type = False
if img.ndim == 2:
img = img.unsqueeze(0)
squeeze_flag = True
in_c, in_h, in_w = img.size()
out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
kernel_width = 4
kernel = 'cubic'
# get weights and indices
weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
antialiasing)
weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
antialiasing)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
sym_patch = img[:, :sym_len_hs, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
sym_patch = img[:, -sym_len_he:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(in_c, out_h, in_w)
kernel_width = weights_h.size(1)
for i in range(out_h):
idx = int(indices_h[i][0])
for j in range(in_c):
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
sym_patch = out_1[:, :, :sym_len_ws]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
sym_patch = out_1[:, :, -sym_len_we:]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(in_c, out_h, out_w)
kernel_width = weights_w.size(1)
for i in range(out_w):
idx = int(indices_w[i][0])
for j in range(in_c):
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
if squeeze_flag:
    out_2 = out_2.squeeze(0)
if numpy_type:
    out_2 = out_2.numpy()
    if not squeeze_flag:
        out_2 = out_2.transpose(1, 2, 0)
else:
    # tensor CHW [0, 1] -> numpy HWC [0, 1]
    out_2 = out_2.numpy()
    if not squeeze_flag:
        out_2 = out_2.transpose(1, 2, 0)
return out_2
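# Minimal sketch of imresize on a random image (arbitrary shape); a (h, w, c)
# float input in [0, 1] comes back as (ceil(h*scale), ceil(w*scale), c):
if __name__ == '__main__':
    demo = np.random.rand(32, 48, 3).astype(np.float32)
    out = imresize(demo, 0.5)  # antialiased bicubic downscale
    print(demo.shape, '->', out.shape)  # (32, 48, 3) -> (16, 24, 3)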
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import functional as F
"""
filter2D(img, kernel)
usm_sharp(img, weight=0.5, radius=50, threshold=10)
USMSharp(torch.nn.Module)
"""
def filter2D(img, kernel):
"""PyTorch version of cv2.filter2D
Args:
img (Tensor): (b, c, h, w)
kernel (Tensor): (b, k, k)
"""
k = kernel.size(-1)
b, c, h, w = img.size()
if k % 2 == 1:
img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
else:
raise ValueError('Wrong kernel size')
ph, pw = img.size()[-2:]
# if kernel.size(0) == 1:
# apply the same kernel to all batch images
img = img.contiguous().view(b * c, 1, ph, pw)
kernel = kernel.view(1, 1, k, k)
return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
# else:
# img = img.view(1, b * c, ph, pw)
# kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
# return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
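# Usage sketch for filter2D (the active branch applies one shared kernel to the
# whole batch): blur a random batch with a normalized 21x21 box kernel:
#   img = torch.rand(2, 3, 64, 64)
#   kernel = torch.ones(1, 21, 21) / (21 * 21)
#   blurred = filter2D(img, kernel)  # same (b, c, h, w) shape as img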
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
"""USM sharpening.
Input image: I; Blurry image: B.
1. sharp = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur the mask to get the soft mask.
4. Out = soft_mask * sharp + (1 - soft_mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 0.5.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int): Threshold on |I - B| scaled to [0, 255]. Default: 10.
"""
if radius % 2 == 0:
radius += 1
blur = cv2.GaussianBlur(img, (radius, radius), 0)
residual = img - blur
mask = np.abs(residual) * 255 > threshold
mask = mask.astype('float32')
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
sharp = img + weight * residual
sharp = np.clip(sharp, 0, 1)
return soft_mask * sharp + (1 - soft_mask) * img
class USMSharp(torch.nn.Module):
def __init__(self, radius=50, sigma=0):
super(USMSharp, self).__init__()
if radius % 2 == 0:
radius += 1
self.radius = radius
kernel = cv2.getGaussianKernel(radius, sigma)
kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
self.register_buffer('kernel', kernel)
def forward(self, img, weight=0.5, threshold=10):
blur = filter2D(img, self.kernel)
residual = img - blur
mask = torch.abs(residual) * 255 > threshold
mask = mask.float()
soft_mask = filter2D(mask, self.kernel)
sharp = img + weight * residual
sharp = torch.clip(sharp, 0, 1)
return soft_mask * sharp + (1 - soft_mask) * img
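# Minimal sketch comparing the numpy and tensor sharpeners on random data
# (arbitrary sizes; both expect float32 images in [0, 1]):
if __name__ == '__main__':
    img_np = np.random.rand(64, 64, 3).astype(np.float32)
    out_np = usm_sharp(img_np)  # HWC numpy in, HWC numpy out
    sharpener = USMSharp()
    img_t = torch.from_numpy(img_np.transpose(2, 0, 1)).unsqueeze(0)
    out_t = sharpener(img_t)  # BCHW tensor in, BCHW tensor out
    print(out_np.shape, tuple(out_t.shape))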
import cv2
import math
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
from scipy import special
"""
#1.Generate kernel
bivariate_Gaussian: isotropic/anisotropic
bivariate_generalized_Gaussian: isotropic/anisotropic
bivariate_plateau: isotropic/anisotropic
#2.Randomly select kernel
random_bivariate_Gaussian
random_bivariate_generalized_Gaussian
random_bivariate_plateau
#3.Randomly generate mixed kernels
random_mixed_kernels
#4.Generate blur kernels (used in the first/second degradation)
generate_kernel1
generate_kernel2
#5.Auxiliary utils
sigma_matrix2
mesh_grid
pdf2
circular_lowpass_kernel <--- sinc filter
"""
# -------------------------------------------------------------------- #
# --------------------------- Generate kernel ------------------------ #
# -------------------------------------------------------------------- #
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
isotropic (bool):
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
kernel = pdf2(sigma_matrix, grid)
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
"""Generate a bivariate generalized Gaussian kernel.
Described in `Parameter Estimation For Multivariate Generalized
Gaussian Distributions`_
by Pascal et. al (2013).
In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
.. _Parameter Estimation For Multivariate Generalized Gaussian
Distributions: https://arxiv.org/abs/1302.6498
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
"""Generate a plateau-like anisotropic kernel.
1 / (1+x^(beta))
Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
kernel = kernel / np.sum(kernel)
return kernel
# -------------------------------------------------------------------- #
# ---------------------------Random generate kernel ------------------ #
# -------------------------------------------------------------------- #
def random_bivariate_Gaussian(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation_range (tuple): [-math.pi, math.pi]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_generalized_Gaussian(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate generalized Gaussian kernels.
In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation_range (tuple): [-math.pi, math.pi]
beta_range (tuple): [0.5, 8]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
# assume beta_range[0] < 1 < beta_range[1]
if np.random.uniform() < 0.5:
beta = np.random.uniform(beta_range[0], 1)
else:
beta = np.random.uniform(1, beta_range[1])
kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_plateau(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate plateau kernels.
In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation_range (tuple): [-math.pi/2, math.pi/2]
beta_range (tuple): [1, 4]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
# TODO: this may not be proper
if np.random.uniform() < 0.5:
beta = np.random.uniform(beta_range[0], 1)
else:
beta = np.random.uniform(1, beta_range[1])
kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
# -------------------------------------------------------------------- #
# ---------------- Randomly generate mixed kernels ------------------- #
# -------------------------------------------------------------------- #
def random_mixed_kernels(kernel_list,
kernel_prob,
kernel_size=21,
sigma_x_range=(0.6, 5),
sigma_y_range=(0.6, 5),
rotation_range=(-math.pi, math.pi),
betag_range=(0.5, 8),
betap_range=(0.5, 8),
noise_range=None):
"""Randomly generate mixed kernels.
Args:
kernel_list (tuple): a list of kernel type names; supports ['iso', 'aniso',
'generalized_iso', 'generalized_aniso', 'plateau_iso',
'plateau_aniso']
kernel_prob (tuple): corresponding kernel probability for each
kernel type
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation_range (tuple): [-math.pi, math.pi]
betag_range (tuple): [0.5, 8]
betap_range (tuple): [0.5, 8]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
kernel_type = random.choices(kernel_list, kernel_prob)[0]
if kernel_type == 'iso':
kernel = random_bivariate_Gaussian(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
elif kernel_type == 'aniso':
kernel = random_bivariate_Gaussian(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
elif kernel_type == 'generalized_iso':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
betag_range,
noise_range=noise_range,
isotropic=True)
elif kernel_type == 'generalized_aniso':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
betag_range,
noise_range=noise_range,
isotropic=False)
elif kernel_type == 'plateau_iso':
kernel = random_bivariate_plateau(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
elif kernel_type == 'plateau_aniso':
kernel = random_bivariate_plateau(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
return kernel
# -------------------------------------------------------------------- #
# ----Generate blur kernels (used in the first/second degradation)---- #
# -------------------------------------------------------------------- #
def generate_kernel1(kernel_range,
sinc_prob,
kernel_list,
kernel_prob,
blur_sigma,
betag_range,
betap_range,
):
kernel_size = random.choice(kernel_range)
if np.random.uniform() < sinc_prob:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
kernel_list,
kernel_prob,
kernel_size,
blur_sigma,
blur_sigma,
[-math.pi, math.pi], # Rotation angle
betag_range,
betap_range,
noise_range=None,
)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
return kernel
def generate_kernel2(kernel_range,
sinc_prob2,
kernel_list2,
kernel_prob2,
blur_sigma2,
betag_range2,
betap_range2,):
kernel_size = random.choice(kernel_range)
if np.random.uniform() < sinc_prob2:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
kernel_list2,
kernel_prob2,
kernel_size,
blur_sigma2,
blur_sigma2,
[-math.pi, math.pi],
betag_range2,
betap_range2,
noise_range=None,
)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
return kernel2
def generate_sinc_kernel(kernel_range):
kernel_size = random.choice(kernel_range)
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
return sinc_kernel
# -------------------------------------------------------------------- #
# ------------------------- Auxiliary utils --------------------------- #
# -------------------------------------------------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
"""Calculate the rotated sigma matrix (two dimensional matrix).
Args:
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
Returns:
ndarray: Rotated sigma matrix.
"""
d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
def mesh_grid(kernel_size):
"""Generate the mesh grid, centering at zero.
Args:
kernel_size (int):
Returns:
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
xx (ndarray): with the shape (kernel_size, kernel_size)
yy (ndarray): with the shape (kernel_size, kernel_size)
"""
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
xx, yy = np.meshgrid(ax, ax)
xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
1))).reshape(kernel_size, kernel_size, 2)
return xy, xx, yy
def pdf2(sigma_matrix, grid):
"""Calculate PDF of the bivariate Gaussian distribution.
Args:
sigma_matrix (ndarray): with the shape (2, 2)
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
kernel (ndarray): un-normalized kernel.
"""
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
return kernel
def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
"""2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
Args:
cutoff (float): cutoff frequency in radians (pi is max)
kernel_size (int): horizontal and vertical size, must be odd.
pad_to (int): pad kernel size to desired size, must be odd or zero.
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
kernel = np.fromfunction(
lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
kernel = kernel / np.sum(kernel)
if pad_to > kernel_size:
pad_size = (pad_to - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
return kernel
if __name__ == '__main__':
blur_kernel_size = 21
kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob = 0.1
blur_sigma = [0.2, 3] # blur_x / y_sigma
betag_range = [0.5, 4]
betap_range = [1, 2]
kernel = generate_kernel1(kernel_range, sinc_prob, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range)
print(kernel.shape)
img = cv2.imread('../qj.png')
img = np.float32(img / 255.)
img_blur = cv2.filter2D(img, -1, kernel)
# img_blur = img_blur[:, :, ::-1]
img_blur = np.uint8((img_blur.clip(0, 1) * 255.).round())
cv2.imwrite('blur2.png', img_blur)
# plt.imshow(img_blur)
# plt.show()
import random
import itertools
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.nn import functional as F
class DiffJPEG(nn.Module):
"""This JPEG algorithm result is slightly different from cv2.
DiffJPEG supports batch processing.
Args:
differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round
"""
def __init__(self, differentiable=True):
super(DiffJPEG, self).__init__()
if differentiable:
rounding = diff_round
else:
rounding = torch.round
self.compress = CompressJpeg(rounding=rounding)
self.decompress = DeCompressJpeg(rounding=rounding)
def forward(self, x, quality):
"""
Args:
x (Tensor): Input image, bchw, rgb, [0, 1]
quality(float): Quality factor for jpeg compression scheme.
"""
factor = quality
if isinstance(factor, (int, float)):
factor = quality_to_factor(factor)
else:
for i in range(factor.size(0)):
factor[i] = quality_to_factor(factor[i])
h, w = x.size()[-2:]
h_pad, w_pad = 0, 0
# pad to a multiple of 16: chroma is subsampled by 2, then each channel is split into 8x8 blocks
if h % 16 != 0:
h_pad = 16 - h % 16
if w % 16 != 0:
w_pad = 16 - w % 16
x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
y, cb, cr = self.compress(x, factor=factor)
recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
recovered = recovered[:, :, 0:h, 0:w]
return recovered
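# Usage sketch (note CompressJpeg/DeCompressJpeg are defined further below, so
# this only works once the whole module has been loaded):
#   jpeger = DiffJPEG(differentiable=False)
#   imgs = torch.rand(4, 3, 60, 80)               # bchw, rgb, [0, 1]
#   quality = imgs.new_zeros(4).uniform_(30, 95)  # one quality factor per image
#   out = jpeger(imgs, quality=quality)           # same shape, JPEG artifacts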
#----------------------Compression----------------------#
class CompressJpeg(nn.Module):
"""Full JPEG compression algorithm
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding=torch.round):
super(CompressJpeg, self).__init__()
self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
self.c_quantize = CQuantize(rounding=rounding)
self.y_quantize = YQuantize(rounding=rounding)
def forward(self, image, factor=1):
"""
Args:
image(tensor): batch x 3 x height x width
Returns:
tuple(Tensor): y, cb, cr blocks of shape batch x num_blocks x 8 x 8 (cb/cr have a quarter of y's blocks due to chroma subsampling).
"""
y, cb, cr = self.l1(image * 255)
components = {'y': y, 'cb': cb, 'cr': cr}
for k in components.keys():
comp = self.l2(components[k])
if k in ('cb', 'cr'):
comp = self.c_quantize(comp, factor=factor)
else:
comp = self.y_quantize(comp, factor=factor)
components[k] = comp
return components['y'], components['cb'], components['cr']
class RGB2YCbCrJpeg(nn.Module):
""" Converts RGB image to YCbCr
"""
def __init__(self):
super(RGB2YCbCrJpeg, self).__init__()
matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
dtype=np.float32).T
self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
self.matrix = nn.Parameter(torch.from_numpy(matrix))
def forward(self, image):
"""
Args:
image(Tensor): batch x 3 x height x width
Returns:
Tensor: batch x height x width x 3
"""
image = image.permute(0, 2, 3, 1)
result = torch.tensordot(image, self.matrix, dims=1) + self.shift
return result.view(image.shape)
class ChromaSubsampling(nn.Module):
""" Chroma subsampling on CbCr channels
"""
def __init__(self):
super(ChromaSubsampling, self).__init__()
def forward(self, image):
"""
Args:
image(tensor): batch x height x width x 3
Returns:
y(tensor): batch x height x width
cb(tensor): batch x height/2 x width/2
cr(tensor): batch x height/2 x width/2
"""
image_2 = image.permute(0, 3, 1, 2).clone()
cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
cb = cb.permute(0, 2, 3, 1)
cr = cr.permute(0, 2, 3, 1)
return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
class BlockSplitting(nn.Module):
""" Splitting image into patches
"""
def __init__(self):
super(BlockSplitting, self).__init__()
self.k = 8
def forward(self, image):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x h*w/64 x 8 x 8
"""
height, _ = image.shape[1:3]
batch_size = image.shape[0]
image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
class DCT8x8(nn.Module):
""" Discrete Cosine Transformation
"""
def __init__(self):
super(DCT8x8, self).__init__()
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
for x, y, u, v in itertools.product(range(8), repeat=4):
tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
def forward(self, image):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
image = image - 128
result = self.scale * torch.tensordot(image, self.tensor, dims=2)
result = result.view(image.shape)
return result
class YQuantize(nn.Module):
""" JPEG Quantization for Y channel
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding):
super(YQuantize, self).__init__()
self.rounding = rounding
self.y_table = y_table
def forward(self, image, factor=1):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
if isinstance(factor, (int, float)):
image = image.float() / (self.y_table * factor)
else:
b = factor.size(0)
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
image = image.float() / table
image = self.rounding(image)
return image
class CQuantize(nn.Module):
""" JPEG Quantization for CbCr channels
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding):
super(CQuantize, self).__init__()
self.rounding = rounding
self.c_table = c_table
def forward(self, image, factor=1):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
if isinstance(factor, (int, float)):
image = image.float() / (self.c_table * factor)
else:
b = factor.size(0)
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
image = image.float() / table
image = self.rounding(image)
return image
#----------------------Decompression----------------------#
class DeCompressJpeg(nn.Module):
"""Full JPEG decompression algorithm
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding=torch.round):
super(DeCompressJpeg, self).__init__()
self.c_dequantize = CDequantize()
self.y_dequantize = YDequantize()
self.idct = iDCT8x8()
self.merging = BlockMerging()
self.chroma = ChromaUpsampling()
self.colors = YCbCr2RGBJpeg()
def forward(self, y, cb, cr, imgh, imgw, factor=1):
"""
Args:
y, cb, cr (tensor): quantized DCT blocks, batch x num_blocks x 8 x 8
imgh (int): padded image height
imgw (int): padded image width
factor (float): quantization factor
Returns:
Tensor: batch x 3 x height x width
"""
components = {'y': y, 'cb': cb, 'cr': cr}
for k in components.keys():
if k in ('cb', 'cr'):
comp = self.c_dequantize(components[k], factor=factor)
height, width = int(imgh / 2), int(imgw / 2)
else:
comp = self.y_dequantize(components[k], factor=factor)
height, width = imgh, imgw
comp = self.idct(comp)
components[k] = self.merging(comp, height, width)
#
image = self.chroma(components['y'], components['cb'], components['cr'])
image = self.colors(image)
image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
return image / 255
class YDequantize(nn.Module):
"""Dequantize Y channel
"""
def __init__(self):
super(YDequantize, self).__init__()
self.y_table = y_table
def forward(self, image, factor=1):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
if isinstance(factor, (int, float)):
out = image * (self.y_table * factor)
else:
b = factor.size(0)
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
out = image * table
return out
class CDequantize(nn.Module):
"""Dequantize CbCr channel
"""
def __init__(self):
super(CDequantize, self).__init__()
self.c_table = c_table
def forward(self, image, factor=1):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
if isinstance(factor, (int, float)):
out = image * (self.c_table * factor)
else:
b = factor.size(0)
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
out = image * table
return out
class iDCT8x8(nn.Module):
"""Inverse discrete Cosine Transformation
"""
def __init__(self):
super(iDCT8x8, self).__init__()
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
for x, y, u, v in itertools.product(range(8), repeat=4):
tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
def forward(self, image):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
image = image * self.alpha
result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
result = result.view(image.shape)
return result
class BlockMerging(nn.Module):
"""Merge patches into image
"""
def __init__(self):
super(BlockMerging, self).__init__()
def forward(self, patches, height, width):
"""
Args:
patches(tensor): batch x height*width/64 x 8 x 8
height(int)
width(int)
Returns:
Tensor: batch x height x width
"""
k = 8
batch_size = patches.shape[0]
image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
return image_transposed.contiguous().view(batch_size, height, width)
class ChromaUpsampling(nn.Module):
"""Upsample chroma layers
"""
def __init__(self):
super(ChromaUpsampling, self).__init__()
def forward(self, y, cb, cr):
"""
Args:
y(tensor): y channel image
cb(tensor): cb channel
cr(tensor): cr channel
Returns:
Tensor: batch x height x width x 3
"""
def repeat(x, k=2):
height, width = x.shape[1:3]
x = x.unsqueeze(-1)
x = x.repeat(1, 1, k, k)
x = x.view(-1, height * k, width * k)
return x
cb = repeat(cb)
cr = repeat(cr)
return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
class YCbCr2RGBJpeg(nn.Module):
"""Converts YCbCr image to RGB JPEG
"""
def __init__(self):
super(YCbCr2RGBJpeg, self).__init__()
matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
self.matrix = nn.Parameter(torch.from_numpy(matrix))
def forward(self, image):
"""
Args:
image(tensor): batch x height x width x 3
Returns:
Tensor: batch x 3 x height x width
"""
result = torch.tensordot(image + self.shift, self.matrix, dims=1)
return result.view(image.shape).permute(0, 3, 1, 2)
# ------------------------ utils ------------------------#
y_table = np.array(
[[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
[14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
[49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
dtype=np.float32).T
y_table = nn.Parameter(torch.from_numpy(y_table))
c_table = np.empty((8, 8), dtype=np.float32)
c_table.fill(99)
c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
c_table = nn.Parameter(torch.from_numpy(c_table))
def diff_round(x):
""" Differentiable rounding function
"""
return torch.round(x) + (x - torch.round(x))**3
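# diff_round tracks torch.round up to a cubic correction of at most 0.125, but
# its gradient is 3 * (x - round(x))**2 instead of the zero gradient of round(),
# which is what makes the compression pipeline differentiable, e.g.:
#   x = torch.tensor([0.4], requires_grad=True)
#   diff_round(x).backward()
#   # x.grad is 3 * 0.4**2 = 0.48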
def quality_to_factor(quality):
""" Calculate factor corresponding to quality
Args:
quality(float): Quality for jpeg compression.
Returns:
float: Compression factor.
"""
if quality < 50:
quality = 5000. / quality
else:
quality = 200. - quality * 2
return quality / 100.
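# Worked examples of the quality -> factor mapping above:
#   quality 10 -> 5000 / 10 = 500 -> factor 5.0  (strong compression)
#   quality 50 -> 200 - 100 = 100 -> factor 1.0
#   quality 90 -> 200 - 180 = 20  -> factor 0.2  (mild compression)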
if __name__ == '__main__':
def uint2single(img):
    # uint8 [0, 255] -> float32 [0., 1.]
    return np.float32(img / 255.)
def single2uint(img):
    return np.uint8((img.clip(0, 1) * 255.).round())
# random_add_jpg_compression is not defined in this file; it is assumed to come from basicsr
from basicsr.data.degradations import random_add_jpg_compression
jpeg_range2 = [30, 95]
img = cv2.imread('../qj.png')
img = uint2single(img)
img_jpeg = random_add_jpg_compression(img, jpeg_range2)
img_jpeg = single2uint(img_jpeg)
img_jpeg = img_jpeg[:, :, ::-1]  # BGR -> RGB for matplotlib
plt.imshow(img_jpeg)
plt.show()