ModelZoo / STAR · Commit 1f5da520
Authored Dec 05, 2025 by yangzhong: "git init"
Pipeline #3144 failed with stages in 0 seconds
Changes: 326 · Pipelines: 1

Showing 20 changed files with 3568 additions and 0 deletions (+3568, -0):

utils_data/opensora/datasets/datasets_panda50m.py (+181, -0)
utils_data/opensora/datasets/datasets_panda50m_dense.py (+191, -0)
utils_data/opensora/datasets/datasets_path2text.py (+175, -0)
utils_data/opensora/datasets/datasets_webvid.py (+166, -0)
utils_data/opensora/datasets/datasets_webvid10m.py (+198, -0)
utils_data/opensora/datasets/high_order/README.md (+3, -0)
utils_data/opensora/datasets/high_order/RealESRGAN_Deg_pipeline.py (+420, -0)
utils_data/opensora/datasets/high_order/__pycache__/degrade_video.cpython-39.pyc (+0, -0)
utils_data/opensora/datasets/high_order/__pycache__/utils_.cpython-311.pyc (+0, -0)
utils_data/opensora/datasets/high_order/__pycache__/utils_.cpython-39.pyc (+0, -0)
utils_data/opensora/datasets/high_order/__pycache__/utils_blur.cpython-311.pyc (+0, -0)
utils_data/opensora/datasets/high_order/__pycache__/utils_blur.cpython-39.pyc (+0, -0)
utils_data/opensora/datasets/high_order/__pycache__/utils_jpeg.cpython-39.pyc (+0, -0)
utils_data/opensora/datasets/high_order/__pycache__/utils_noise.cpython-39.pyc (+0, -0)
utils_data/opensora/datasets/high_order/degrade_video.py (+498, -0)
utils_data/opensora/datasets/high_order/degrade_video_mid.py (+500, -0)
utils_data/opensora/datasets/high_order/matlab_functions.py (+178, -0)
utils_data/opensora/datasets/high_order/utils_.py (+86, -0)
utils_data/opensora/datasets/high_order/utils_blur.py (+498, -0)
utils_data/opensora/datasets/high_order/utils_jpeg.py (+474, -0)
utils_data/opensora/datasets/datasets_panda50m.py (new file, mode 100644)

import csv
csv.field_size_limit(5000000)
import os

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader

from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
import ast
import pandas as pd
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def get_transforms_video(resolution=256):
    transform_video = transforms.Compose(
        [
            video_transforms.ToTensorVideo(),  # TCHW
            video_transforms.RandomHorizontalFlipVideo(),
            video_transforms.UCFCenterCropVideo(resolution),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform_video


def get_transforms_image(image_size=256):
    transform = transforms.Compose(
        [
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform


# open-sora-plan + magictime dataset
class DatasetFromCSV(torch.utils.data.Dataset):
    """Load videos according to a csv file.

    Args:
        num_frames (int): the number of video frames to load.
        frame_interval (int): interval between sampled frames.
        transform (callable): aligns different videos to a specified size.
    """

    def __init__(self, csv_path, num_frames=16, frame_interval=1, transform=None, root=None):
        # video_samples = []
        # with open(csv_path, "r") as f:
        #     reader = csv.reader(f)
        #     # csv_list = list(reader)
        #     for idx, v_s in enumerate(reader):
        #         vid_path = v_s[0]
        #         vid_caption = v_s[1]
        #         if os.path.exists(vid_path):
        #             video_samples.append([vid_path, vid_caption])
        #         if idx % 1000 == 0:
        #             print(idx)
        video_samples = pd.read_csv(csv_path)
        self.samples = video_samples
        print('video num:', self.samples.shape[0])
        self.is_video = True
        self.transform = transform
        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
        self.root = root

    def getitem(self, index):
        sample = self.samples.iloc[index].values
        path = sample[0]
        text = sample[1]
        if self.is_video:
            is_exit = os.path.exists(path)
            if is_exit:
                vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
                total_frames = len(vframes)
            else:
                total_frames = 0
            # skip missing or too-short videos by advancing to the next sample
            loop_index = index
            while total_frames < self.num_frames or not is_exit:
                # print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exist!!!")
                loop_index += 1
                if loop_index >= self.samples.shape[0]:
                    loop_index = 0
                sample = self.samples.iloc[loop_index].values
                path = sample[0]
                text = sample[1]
                is_exit = os.path.exists(path)
                if is_exit:
                    vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
                    total_frames = len(vframes)
                else:
                    total_frames = 0
            # video exists and total_frames >= self.num_frames
            # Sampling video frames
            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
            assert (
                end_frame_ind - start_frame_ind >= self.num_frames
            ), f"{path} with index {index} does not have enough frames."
            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
            # print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
            video = vframes[frame_indice]
            video = self.transform(video)  # T C H W
        else:
            image = pil_loader(path)
            image = self.transform(image)
            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)
        return {"video": video, "text": text}

    def __getitem__(self, index):
        for _ in range(10):
            try:
                return self.getitem(index)
            except Exception as e:
                print(e)
                index = np.random.randint(len(self))
        raise RuntimeError("Too many bad data.")

    def __len__(self):
        return self.samples.shape[0]


if __name__ == '__main__':
    data_path = '/mnt/bn/videodataset-uswest/VDiT/dataset/panda50m/panda70m_training_full.csv'
    root = '/mnt/bn/videodataset-uswest/panda70m'
    dataset = DatasetFromCSV(
        data_path,
        transform=get_transforms_video(),
        num_frames=16,
        frame_interval=3,
        root=root,
    )
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=1)
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=sampler,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )
    for video_data in loader:
        print(video_data)
\ No newline at end of file
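The loaders in this commit depend on video_transforms.TemporalRandomCrop, which is not part of this diff. A minimal sketch of the assumed contract (a hypothetical stand-in, not the actual opensora implementation), together with the np.linspace sampling used in getitem:

import random
import numpy as np

class TemporalRandomCrop:
    # Hypothetical stand-in: picks a random window of `size` frames and
    # returns (begin, end) indices into the decoded video.
    def __init__(self, size):
        self.size = size  # num_frames * frame_interval

    def __call__(self, total_frames):
        begin = random.randint(0, max(0, total_frames - self.size))
        return begin, min(begin + self.size, total_frames)

# e.g. total_frames=100, num_frames=16, frame_interval=3: a 48-frame window,
# from which np.linspace picks 16 roughly evenly spaced frame indices
start, end = TemporalRandomCrop(16 * 3)(100)
frame_indice = np.linspace(start, end - 1, 16, dtype=int)
print(start, end, frame_indice)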
utils_data/opensora/datasets/datasets_panda50m_dense.py (new file, mode 100644)

import csv
csv.field_size_limit(8000000)
import os

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader

from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
import ast
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def get_transforms_video(resolution=256):
    transform_video = transforms.Compose(
        [
            video_transforms.ToTensorVideo(),  # TCHW
            video_transforms.RandomHorizontalFlipVideo(),
            video_transforms.UCFCenterCropVideo(resolution),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform_video


def get_transforms_image(image_size=256):
    transform = transforms.Compose(
        [
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform


# open-sora-plan + magictime dataset
class DatasetFromCSV(torch.utils.data.Dataset):
    """Load videos according to a csv file.

    Args:
        num_frames (int): the number of video frames to load.
        frame_interval (int): interval between sampled frames.
        transform (callable): aligns different videos to a specified size.
    """

    def __init__(self, csv_path, num_frames=16, frame_interval=1, transform=None, root=None):
        video_samples = []
        with open(csv_path, "r") as f:
            reader = csv.reader(f)
            # csv_list = list(reader)
            print('csv read end')
            parts_list = os.listdir(root)
            for idx, v_s in enumerate(reader):
                if idx > 0:  # skip the csv header row
                    vid_name = v_s[0]
                    vid_captions_str = v_s[3]
                    vid_captions = ast.literal_eval(vid_captions_str)
                    # vid_captions = vid_captions_str.split('\'')[1::2]
                    # find the shard ("part") directory that holds this video's clips
                    for part in parts_list:
                        vids_path = os.path.join(root, part, vid_name)
                        if os.path.isdir(vids_path):
                            for ic, cap in enumerate(vid_captions):
                                vid_path = os.path.join(root, part, vid_name, vid_name + "_" + str(ic) + ".mp4")
                                if os.path.exists(vid_path):
                                    video_samples.append([vid_path, cap])
                            break
                if idx % 1000 == 0:
                    print('read video', idx)
        self.samples = video_samples
        print('video num:', len(self.samples))
        self.is_video = True
        self.transform = transform
        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
        self.root = root

    def getitem(self, index):
        sample = self.samples[index]
        path = sample[0]
        text = sample[1]
        if self.is_video:
            is_exit = os.path.exists(path)
            if is_exit:
                vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
                total_frames = len(vframes)
            else:
                total_frames = 0
            # skip missing or too-short videos by advancing to the next sample
            loop_index = index
            while total_frames < self.num_frames or not is_exit:
                # print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exist!!!")
                loop_index += 1
                if loop_index >= len(self.samples):
                    loop_index = 0
                sample = self.samples[loop_index]
                path = sample[0]
                text = sample[1]
                is_exit = os.path.exists(path)
                if is_exit:
                    vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
                    total_frames = len(vframes)
                else:
                    total_frames = 0
            # video exists and total_frames >= self.num_frames
            # Sampling video frames
            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
            assert (
                end_frame_ind - start_frame_ind >= self.num_frames
            ), f"{path} with index {index} does not have enough frames."
            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
            # print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
            video = vframes[frame_indice]
            video = self.transform(video)  # T C H W
        else:
            image = pil_loader(path)
            image = self.transform(image)
            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)
        return {"video": video, "text": text}

    def __getitem__(self, index):
        for _ in range(10):
            try:
                return self.getitem(index)
            except Exception as e:
                print(e)
                index = np.random.randint(len(self))
        raise RuntimeError("Too many bad data.")

    def __len__(self):
        return len(self.samples)


if __name__ == '__main__':
    data_path = '/mnt/bn/videodataset-uswest/VDiT/dataset/panda50m/panda70m_training_full.csv'
    root = '/mnt/bn/videodataset-uswest/panda70m'
    dataset = DatasetFromCSV(
        data_path,
        transform=get_transforms_video(),
        num_frames=16,
        frame_interval=3,
        root=root,
    )
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=1)
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=sampler,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )
    for video_data in loader:
        print(video_data)
\ No newline at end of file
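The dense loader assumes a Panda-70M-style csv in which column 0 is a video id and column 3 is a Python list literal of per-clip captions, with clip files named <vid>_<i>.mp4 under a shard directory. A small illustration with made-up values:

import ast

v_s = ["vid123", "url", "timestamps", "['a dog runs', 'a dog jumps']"]  # hypothetical row
vid_name = v_s[0]
vid_captions = ast.literal_eval(v_s[3])  # ['a dog runs', 'a dog jumps']
for ic, cap in enumerate(vid_captions):
    # caption ic pairs with <root>/<part>/vid123/vid123_<ic>.mp4
    print(f"{vid_name}/{vid_name}_{ic}.mp4 ->", cap)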
utils_data/opensora/datasets/datasets_path2text.py (new file, mode 100644)

import csv
import os

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader

from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import ipdb


def get_transforms_video(resolution=256):
    transform_video = transforms.Compose(
        [
            video_transforms.ToTensorVideo(),  # TCHW
            video_transforms.RandomHorizontalFlipVideo(),
            video_transforms.UCFCenterCropVideo(resolution),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform_video


def get_transforms_image(image_size=256):
    transform = transforms.Compose(
        [
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform


# open-sora-plan + magictime dataset
class DatasetFromCSV(torch.utils.data.Dataset):
    """Load videos according to a csv file.

    Args:
        num_frames (int): the number of video frames to load.
        frame_interval (int): interval between sampled frames.
        transform (callable): aligns different videos to a specified size.
    """

    def __init__(self, csv_path, num_frames=16, frame_interval=1, transform=None, root=None):
        video_samples = []
        with open(csv_path, "r") as f:
            reader = csv.reader(f)
            csv_list = list(reader)
        for v_s in csv_list[1:]:  # skip the csv header row
            vid_path = v_s[0]
            vid_caption = v_s[1]
            if os.path.exists(vid_path):
                video_samples.append([vid_path, vid_caption])
        self.samples = video_samples
        self.is_video = True
        self.transform = transform
        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
        self.root = root

    def getitem(self, index):
        sample = self.samples[index]
        path = sample[0]
        text = sample[1]
        if self.is_video:
            is_exit = os.path.exists(path)
            if is_exit:
                vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
                total_frames = len(vframes)
            else:
                total_frames = 0
            # skip missing or too-short videos by advancing to the next sample
            loop_index = index
            while total_frames < self.num_frames or not is_exit:
                # print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exist!!!")
                loop_index += 1
                if loop_index >= len(self.samples):
                    loop_index = 0
                sample = self.samples[loop_index]
                path = sample[0]
                text = sample[1]
                is_exit = os.path.exists(path)
                if is_exit:
                    vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
                    total_frames = len(vframes)
                else:
                    total_frames = 0
            # video exists and total_frames >= self.num_frames
            # Sampling video frames
            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
            assert (
                end_frame_ind - start_frame_ind >= self.num_frames
            ), f"{path} with index {index} does not have enough frames."
            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
            # print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
            video = vframes[frame_indice]
            video = self.transform(video)  # T C H W
        else:
            image = pil_loader(path)
            image = self.transform(image)
            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)
        return {"video": video, "text": text}

    def __getitem__(self, index):
        for _ in range(10):
            try:
                return self.getitem(index)
            except Exception as e:
                print(e)
                index = np.random.randint(len(self))
        raise RuntimeError("Too many bad data.")

    def __len__(self):
        return len(self.samples)


if __name__ == '__main__':
    data_path = '/mnt/bn/yh-volume0/dataset/CelebvHQ/CelebvHQ_caption_llava-34B.csv'
    root = '/mnt/bn/yh-volume0/dataset/CelebvHQ/35666'
    dataset = DatasetFromCSV(
        data_path,
        transform=get_transforms_video(),
        num_frames=16,
        frame_interval=3,
        root=root,
    )
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=1)
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=sampler,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )
    for video_data in loader:
        print(video_data)
\ No newline at end of file
utils_data/opensora/datasets/datasets_webvid.py (new file, mode 100644)

import csv
import os

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader

from . import video_transforms
from .utils import center_crop_arr


def get_transforms_video(resolution=256):
    transform_video = transforms.Compose(
        [
            video_transforms.ToTensorVideo(),  # TCHW
            video_transforms.RandomHorizontalFlipVideo(),
            video_transforms.UCFCenterCropVideo(resolution),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform_video


def get_transforms_image(image_size=256):
    transform = transforms.Compose(
        [
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform


class DatasetFromCSV(torch.utils.data.Dataset):
    """Load videos according to a csv file.

    Args:
        num_frames (int): the number of video frames to load.
        frame_interval (int): interval between sampled frames.
        transform (callable): aligns different videos to a specified size.
    """

    def __init__(self, csv_path, num_frames=16, frame_interval=1, transform=None, root=None):
        self.csv_path = csv_path
        with open(csv_path, "r") as f:
            reader = csv.reader(f)
            csv_list = list(reader)
        all_samples = csv_list[1:]  # no head, 10727607
        # sample_samples = random.sample(all_samples, 400000)  # 400k = 366k + 20k + 20k
        sample_samples = []
        for i_s, sample in enumerate(all_samples):
            if i_s % 25 == 0:
                if sample[2] != '0':
                    sample_samples.append(sample)
        self.samples = sample_samples  # 429105
        ext = self.samples[0][0].split(".")[-1]
        if ext.lower() in ("mp4", "avi", "mov", "mkv"):
            self.is_video = True
        else:
            assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
            self.is_video = False
        self.transform = transform
        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
        self.root = root

    def getitem(self, index):
        sample = self.samples[index]
        path = sample[0]
        if self.root:
            path = os.path.join(self.root, path)
        text = sample[-1]
        if self.is_video:
            # old
            # vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
            # total_frames = len(vframes)
            # # Sampling video frames
            # start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
            # assert (
            #     end_frame_ind - start_frame_ind >= self.num_frames
            # ), f"{path} with index {index} has not enough frames."
            # frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
            # video = vframes[frame_indice]
            # video = self.transform(video)  # T C H W
            # new
            is_exit = os.path.exists(path)
            if is_exit:
                vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
                total_frames = len(vframes)
            else:
                total_frames = 0
            # skip missing or too-short videos by advancing to the next sample
            loop_index = index
            while total_frames < self.num_frames or not is_exit:
                print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exist!!!")
                loop_index += 1
                if loop_index >= len(self.samples):
                    loop_index = 0
                sample = self.samples[loop_index]
                path = sample[0]
                if self.root:
                    path = os.path.join(self.root, path)
                text = sample[-1]
                is_exit = os.path.exists(path)
                if is_exit:
                    vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
                    total_frames = len(vframes)
                else:
                    total_frames = 0
            # video exists and total_frames >= self.num_frames
            # Sampling video frames
            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
            assert (
                end_frame_ind - start_frame_ind >= self.num_frames
            ), f"{path} with index {index} does not have enough frames."
            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
            print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
            video = vframes[frame_indice]
            video = self.transform(video)  # T C H W
        else:
            image = pil_loader(path)
            image = self.transform(image)
            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)
        return {"video": video, "text": text}

    def __getitem__(self, index):
        for _ in range(10):
            try:
                return self.getitem(index)
            except Exception as e:
                print(e)
                index = np.random.randint(len(self))
        raise RuntimeError("Too many bad data.")

    def __len__(self):
        return len(self.samples)
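A quick check of the 1-in-25 subsampling above: with the 10,727,607 rows noted in the "no head" comment, keeping indices 0, 25, 50, ... yields at most 429,105 rows, matching the "# 429105" annotation (the sample[2] != '0' filter can only reduce this further):

total_rows = 10_727_607            # row count from the csv comment
kept = (total_rows - 1) // 25 + 1  # indices 0, 25, 50, ...
print(kept)                        # 429105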
utils_data/opensora/datasets/datasets_webvid10m.py (new file, mode 100644)

import csv
import os

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader

from . import video_transforms
from .utils import center_crop_arr
# import video_transforms
# from utils import center_crop_arr
import json
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import ipdb


def get_transforms_video(resolution=256):
    transform_video = transforms.Compose(
        [
            video_transforms.ToTensorVideo(),  # TCHW
            video_transforms.RandomHorizontalFlipVideo(),
            video_transforms.UCFCenterCropVideo(resolution),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform_video


def get_transforms_image(image_size=256):
    transform = transforms.Compose(
        [
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform


class DatasetFromCSV(torch.utils.data.Dataset):
    """Load videos according to a csv file.

    Args:
        num_frames (int): the number of video frames to load.
        frame_interval (int): interval between sampled frames.
        transform (callable): aligns different videos to a specified size.
    """

    def __init__(self, csv_path, num_frames=16, frame_interval=1, transform=None, root=None):
        self.csv_path = csv_path
        with open(csv_path, "r") as f:
            reader = csv.reader(f)
            csv_list = list(reader)
        all_samples = csv_list[1:]  # no head, 10727607
        sample_samples = []
        for i_s, sample in enumerate(all_samples):
            if sample[2] != '0':
                sample_samples.append(sample)
        print('samples num:', len(sample_samples))
        self.samples = sample_samples  # 10727337
        self.is_video = True
        self.transform = transform
        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
        self.root = root

    def getitem(self, index):
        sample = self.samples[index]
        path = sample[0]
        if self.root:
            path = os.path.join(self.root, path)
        text = sample[-1]
        # path = "/mnt/bn/yh-volume0/dataset/webvid/raw/videos/train/videos_new/013501_013550-33142969.mp4"
        if self.is_video:
            # old
            # vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
            # total_frames = len(vframes)
            # # Sampling video frames
            # start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
            # assert (
            #     end_frame_ind - start_frame_ind >= self.num_frames
            # ), f"{path} with index {index} has not enough frames."
            # frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
            # video = vframes[frame_indice]
            # video = self.transform(video)  # T C H W
            # new
            is_exit = os.path.exists(path)
            if is_exit:
                vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
                total_frames = len(vframes)
            else:
                total_frames = 0
            # skip missing or too-short videos by advancing to the next sample
            loop_index = index
            while total_frames < self.num_frames or not is_exit:
                # print("total_frames:", total_frames, "<", self.num_frames, ", or", path, "does not exist!!!")
                loop_index += 1
                if loop_index >= len(self.samples):
                    loop_index = 0
                sample = self.samples[loop_index]
                path = sample[0]
                if self.root:
                    path = os.path.join(self.root, path)
                text = sample[-1]
                is_exit = os.path.exists(path)
                if is_exit:
                    vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
                    total_frames = len(vframes)
                else:
                    total_frames = 0
            # video exists and total_frames >= self.num_frames
            # Sampling video frames
            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
            assert (
                end_frame_ind - start_frame_ind >= self.num_frames
            ), f"{path} with index {index} does not have enough frames."
            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
            # print("total_frames:", total_frames, "frame_indice:", frame_indice, "sample:", sample)
            video = vframes[frame_indice]
            video = self.transform(video)  # T C H W
        else:
            image = pil_loader(path)
            image = self.transform(image)
            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)
        # print('video shape:', video.shape, 'text:', text, 'video path:', path)
        return {"video": video, "text": text}

    def __getitem__(self, index):
        for _ in range(10):
            try:
                return self.getitem(index)
            except Exception as e:
                print(e)
                index = np.random.randint(len(self))
        raise RuntimeError("Too many bad data.")

    def __len__(self):
        return len(self.samples)


if __name__ == '__main__':
    data_path = '/mnt/bn/yh-volume0/dataset/webvid/raw/webvid_csv/train.csv'
    root = '/mnt/bn/yh-volume0/dataset/webvid/raw/videos/train/videos_new'
    dataset = DatasetFromCSV(
        data_path,
        transform=get_transforms_video(),
        num_frames=16,
        frame_interval=3,
        root=root,
    )
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=1)
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=sampler,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )
    for video_data in loader:
        print(video_data)
\ No newline at end of file
utils_data/opensora/datasets/high_order/README.md (new file, mode 100644)

Real-ESRGAN degradation dataset pipeline. You can use this pipeline to generate your own degraded datasets.
Note: this project is derived from https://github.com/xinntao/Real-ESRGAN
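A minimal driver sketch for the pipeline defined in RealESRGAN_Deg_pipeline.py below; the (B, C, H, W) layout and [0, 1] value range are assumptions inferred from how forward uses gt, and the random tensor stands in for a real HR image:

import torch
from RealESRGAN_Deg_pipeline import Degradation

deg = Degradation(scale=4, gt_size=256)
gt = torch.rand(1, 3, 256, 256)  # B C H W in [0, 1] (assumed)
lq, gt_usm = deg(gt)             # degraded LQ image + USM-sharpened GT
print(lq.shape, gt_usm.shape)    # lq is spatially downscaled by `scale`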
utils_data/opensora/datasets/high_order/RealESRGAN_Deg_pipeline.py (new file, mode 100644)

import argparse
import cv2
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import math
import os
import torch.nn as nn
from PIL import Image
from utils_ import filter2D, USMSharp
from utils_blur import circular_lowpass_kernel, random_mixed_kernels
from utils_resize import random_resizing
from utils_noise import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from utils_jpeg import DiffJPEG
# from matlab_functions import imresize
# from torchvision import transforms
from torch.nn import functional as F


class Degradation(nn.Module):

    def __init__(self, scale, gt_size):
        super(Degradation, self).__init__()
        ### initialize the JPEG class
        self.jpeger = DiffJPEG(differentiable=False)  # .cuda()
        self.usm_sharpener = USMSharp()  # .cuda()
        # self.queue_size = 180  # opt.get('queue_size', 180)

        ### global settings
        self.scale = scale
        self.gt_size = gt_size

        ### first-degradation hyperparameters ###
        # 1. blur
        self.blur_kernel_size = 21
        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        self.kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        self.kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        self.sinc_prob = 0.1
        self.blur_sigma = [0.2, 3]  # blur_x / y_sigma
        self.betag_range = [0.5, 4]
        self.betap_range = [1, 2]
        # 2. resize
        self.updown_type = ["up", "down", "keep"]
        self.mode_list = ["area", "bilinear", "bicubic"]  # flags: [3, 1, 2]
        self.resize_prob = [0.2, 0.7, 0.1]  # up, down, keep
        self.resize_range = [0.15, 1.5]
        # 3. noise
        self.gaussian_noise_prob = 0.5
        self.noise_range = [1, 30]
        self.poisson_scale_range = [0.05, 3]
        self.gray_noise_prob = 0.4
        # 4. jpeg
        self.jpeg_range = [30, 95]

        ### second-degradation hyperparameters ###
        # 1. blur
        self.second_blur_prob = 0.8
        self.blur_kernel_size2 = 21
        self.kernel_range2 = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        self.kernel_list2 = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        self.kernel_prob2 = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        self.sinc_prob2 = 0.1
        self.blur_sigma2 = [0.2, 1.5]
        self.betag_range2 = [0.5, 4]
        self.betap_range2 = [1, 2]
        # 2. resize
        self.updown_type2 = ["up", "down", "keep"]
        self.mode_list2 = ["area", "bilinear", "bicubic"]  # flags: [3, 1, 2]
        self.resize_prob2 = [0.3, 0.4, 0.3]  # up, down, keep
        self.resize_range2 = [0.3, 1.2]
        # 3. noise
        self.gaussian_noise_prob2 = 0.5
        self.noise_range2 = [1, 25]
        self.poisson_scale_range2 = [0.05, 2.5]
        self.gray_noise_prob2 = 0.4
        # 4. jpeg
        self.jpeg_range2 = [30, 95]

        self.final_sinc_prob = 0.8
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with a pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    @torch.no_grad()
    def forward(self, gt):
        ori_h, ori_w = gt.size()[2:4]
        gt_usm = self.usm_sharpener(gt)
        gt_usm_copy = gt_usm.clone()
        # generate kernels
        kernel1 = self.generate_first_kernel()
        kernel2 = self.generate_second_kernel()
        sinc_kernel = self.generate_sinc_kernel()
        # first degradation: blur -> resize -> noise -> jpeg
        lq = self.jpeg_1(self.noise_1(self.resize_1(self.blur_1(gt_usm_copy, kernel1))))
        # second degradation: blur -> resize -> noise -> jpeg (+ final sinc)
        lq = self.jpeg_2(self.noise_2(self.resize_2(self.blur_2(lq, kernel2), ori_h, ori_w)), ori_h, ori_w, sinc_kernel)
        return lq, gt_usm  # , kernel1, kernel2, sinc_kernel

    # NOTE: this stray decorator ends up applying to blur_1 below, because the
    # path-based forward it originally belonged to is commented out.
    @torch.no_grad()
    # def forward(self, gt_path, uint8=False):
    #     # read hwc 0-1 numpy
    #     img_gt = cv2.imread(gt_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
    #     # augment
    #     img_gt = self.augment(img_gt, True, True)
    #     # numpy 0-1 hwc -> tensor 0-1 chw
    #     img_gt = self.np2tensor([img_gt], bgr2rgb=True, float32=True)[0]
    #     # add batch
    #     img_gt = img_gt.unsqueeze(0)
    #     img_gt_copy = img_gt.clone()
    #     # degradation pipeline
    #     lq, gt_usm, kernel1, kernel2, sinc_kernel = self.forward_deg(img_gt)
    #     # clamp and round
    #     lq = torch.clamp((lq * 255.0).round(), 0, 255) / 255.
    #     print(f'before crop: gt:{img_gt_copy.shape}, lq:{lq.shape}')
    #     # random crop
    #     (gt, gt_usm), lq = self.paired_random_crop([img_gt_copy, gt_usm], lq, self.gt_size, self.scale)
    #     print(f'after crop: gt:{gt.shape}, lq:{lq.shape}')
    #
    #     if uint8:
    #         gt, gt_usm, lq = self.tensor2np([gt, gt_usm, lq])
    #         return gt, gt_usm, lq, kernel1, kernel2, sinc_kernel
    #
    #     return gt, gt_usm, lq, kernel1, kernel2, sinc_kernel
    def blur_1(self, img, kernel1):
        img = filter2D(img, kernel1)
        return img

    def blur_2(self, img, kernel2):
        if np.random.uniform() < self.second_blur_prob:
            img = filter2D(img, kernel2)
        return img

    def resize_1(self, img):
        updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob)[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.resize_range[1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.resize_range[0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        img = F.interpolate(img, scale_factor=scale, mode=mode)
        return img

    def resize_2(self, img, ori_h, ori_w):
        updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob2)[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.resize_range2[1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.resize_range2[0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        img = F.interpolate(img, size=(int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), mode=mode)
        return img

    def noise_1(self, img):
        gray_noise_prob = self.gray_noise_prob
        if np.random.uniform() < self.gaussian_noise_prob:
            img = random_add_gaussian_noise_pt(
                img, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            img = random_add_poisson_noise_pt(
                img, scale_range=self.poisson_scale_range, gray_prob=gray_noise_prob, clip=True, rounds=False)
        return img

    def noise_2(self, img):
        gray_noise_prob = self.gray_noise_prob2
        if np.random.uniform() < self.gaussian_noise_prob2:
            img = random_add_gaussian_noise_pt(
                img, sigma_range=self.noise_range2, clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            img = random_add_poisson_noise_pt(
                img, scale_range=self.poisson_scale_range2, gray_prob=gray_noise_prob, clip=True, rounds=False)
        return img

    def jpeg_1(self, img):
        jpeg_p = img.new_zeros(img.size(0)).uniform_(*self.jpeg_range)
        # clamp to [0, 1], otherwise the JPEGer will produce unpleasant artifacts
        img = torch.clamp(img, 0, 1)
        img = self.jpeger(img, quality=jpeg_p)
        return img

    def jpeg_2(self, out, ori_h, ori_w, sinc_kernel):
        # JPEG compression + the final sinc filter.
        # We also need to resize images to the desired size. We group [resize back + sinc filter]
        # together as one operation and consider two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, other combinations (sinc + JPEG + resize) introduce twisted lines.
        if np.random.uniform() < 0.5:
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
            out = filter2D(out, sinc_kernel)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
        else:
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
            out = filter2D(out, sinc_kernel)
        return out

    def generate_first_kernel(self):
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma,
                [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel to 21x21
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
        return torch.FloatTensor(kernel)

    def generate_second_kernel(self):
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob2:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2,
                [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)
        # pad kernel to 21x21
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
        return torch.FloatTensor(kernel2)

    def generate_sinc_kernel(self):
        if np.random.uniform() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor
        return sinc_kernel

    def np2tensor(self, imgs, bgr2rgb=False, float32=True):

        def _totensor(img, bgr2rgb, float32):
            if img.shape[2] == 3 and bgr2rgb:
                if img.dtype == 'float64':
                    img = img.astype('float32')
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = torch.from_numpy(img.transpose(2, 0, 1))
            if float32:
                img = img.float()
            return img

        if isinstance(imgs, list):
            return [_totensor(img, bgr2rgb, float32) for img in imgs]
        else:
            return _totensor(imgs, bgr2rgb, float32)

    def tensor2np(self, imgs):

        def _tonumpy(img):
            img = img.data.cpu().numpy().squeeze(0).transpose(1, 2, 0)  # .astype(np.float32)
            img = np.uint8((img.clip(0, 1) * 255.).round())
            return img

        if isinstance(imgs, list):
            return [_tonumpy(img) for img in imgs]
        else:
            return _tonumpy(imgs)

    def augment(self, imgs, hflip=True, rotation=True, flows=None, return_status=False):
        hflip = hflip and random.random() < 0.5
        vflip = rotation and random.random() < 0.5
        rot90 = rotation and random.random() < 0.5

        def _augment(img):
            if hflip:  # horizontal
                cv2.flip(img, 1, img)
            if vflip:  # vertical
                cv2.flip(img, 0, img)
            if rot90:
                img = img.transpose(1, 0, 2)
            return img

        if not isinstance(imgs, list):
            imgs = [imgs]
        imgs = [_augment(img) for img in imgs]
        if len(imgs) == 1:
            imgs = imgs[0]
        return imgs

    def paired_random_crop(self, img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
        """Paired random crop. Supports NumPy array and Tensor inputs.

        It crops lists of lq and gt images at corresponding locations.

        Args:
            img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
                should have the same shape. If the input is an ndarray, it will
                be transformed to a list containing itself.
            img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
                should have the same shape. If the input is an ndarray, it will
                be transformed to a list containing itself.
            gt_patch_size (int): GT patch size.
            scale (int): Scale factor.
            gt_path (str): Path to ground-truth. Default: None.

        Returns:
            list[ndarray] | ndarray: GT images and LQ images. If the returned results
            only have one element, just return ndarray.
        """
        if not isinstance(img_gts, list):
            img_gts = [img_gts]
        if not isinstance(img_lqs, list):
            img_lqs = [img_lqs]
        # determine input type: NumPy array or Tensor
        input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

        if input_type == 'Tensor':
            h_lq, w_lq = img_lqs[0].size()[-2:]
            h_gt, w_gt = img_gts[0].size()[-2:]
        else:
            h_lq, w_lq = img_lqs[0].shape[0:2]
            h_gt, w_gt = img_gts[0].shape[0:2]
        lq_patch_size = gt_patch_size // scale

        if h_gt != h_lq * scale or w_gt != w_lq * scale:
            raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                             f'multiplication of LQ ({h_lq}, {w_lq}).')
        if h_lq < lq_patch_size or w_lq < lq_patch_size:
            raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                             f'({lq_patch_size}, {lq_patch_size}). '
                             f'Please remove {gt_path}.')

        # randomly choose top and left coordinates for lq patch
        top = random.randint(0, h_lq - lq_patch_size)
        left = random.randint(0, w_lq - lq_patch_size)

        # crop lq patch
        if input_type == 'Tensor':
            img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
        else:
            img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

        # crop the corresponding gt patch
        top_gt, left_gt = int(top * scale), int(left * scale)
        if input_type == 'Tensor':
            img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
        else:
            img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
        if len(img_gts) == 1:
            img_gts = img_gts[0]
        if len(img_lqs) == 1:
            img_lqs = img_lqs[0]
        return img_gts, img_lqs


if __name__ == '__main__':
    # print(os.path.abspath(os.path.join(__file__, os.path.pardir)))
    deg_pipeline = Degradation(scale=4, gt_size=256)
    # gt_path = r'J:\Dataset\SR\Real_ESRGAN\DF2K_multiscale_sub\0052T0_s024.png'
    gt_path = './Dataset/train/DIV2K/HR/'
    # NOTE: this demo matches the commented-out path-based forward above; the
    # active forward expects a (B, C, H, W) tensor and returns only (lq, gt_usm).
    gt, gt_usm, lq, kernel1, kernel2, sinc_kernel = deg_pipeline(gt_path, uint8=True)
    cv2.imwrite('lq.png', lq)
    cv2.imwrite('gt.png', gt)
    cv2.imwrite('gt_usm.png', gt_usm)
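One note on the final sinc step: when it is skipped, generate_sinc_kernel falls back to pulse_tensor, a 21x21 delta kernel, so the subsequent filter2D call leaves the image unchanged. A tiny self-contained check of that identity using a plain depthwise convolution (filter2D itself lives in utils_, which is not shown in this diff):

import torch
import torch.nn.functional as F

pulse = torch.zeros(21, 21)
pulse[10, 10] = 1  # delta kernel: all mass at the center tap
img = torch.rand(1, 3, 64, 64)
out = F.conv2d(img, pulse.expand(3, 1, 21, 21), padding=10, groups=3)
print(torch.allclose(out, img))  # True: convolving with a pulse adds no blur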
utils_data/opensora/datasets/high_order/__pycache__/degrade_video.cpython-39.pyc (new binary file, mode 100644): File added
utils_data/opensora/datasets/high_order/__pycache__/utils_.cpython-311.pyc (new binary file, mode 100644): File added
utils_data/opensora/datasets/high_order/__pycache__/utils_.cpython-39.pyc (new binary file, mode 100644): File added
utils_data/opensora/datasets/high_order/__pycache__/utils_blur.cpython-311.pyc (new binary file, mode 100644): File added
utils_data/opensora/datasets/high_order/__pycache__/utils_blur.cpython-39.pyc (new binary file, mode 100644): File added
utils_data/opensora/datasets/high_order/__pycache__/utils_jpeg.cpython-39.pyc (new binary file, mode 100644): File added
utils_data/opensora/datasets/high_order/__pycache__/utils_noise.cpython-39.pyc (new binary file, mode 100644): File added
utils_data/opensora/datasets/high_order/degrade_video.py (new file, mode 100644)

import cv2
import numpy as np
import random
import torch
import math
import torch.nn as nn
import sys
sys.path.append('./opensora/datasets/high_order')
from utils_ import filter2D, USMSharp
from utils_blur import circular_lowpass_kernel, random_mixed_kernels
from utils_noise import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from utils_jpeg import DiffJPEG
from torch.nn import functional as F
from einops import rearrange
import av
import io


class ImageCompressor:

    def __init__(self):
        self.params = {
            'codec': ['libx264', 'h264', 'mpeg4'],
            'codec_prob': [1 / 3., 1 / 3., 1 / 3.],
            'bitrate': [1e4, 1e5]
        }

    def _ensure_even_dimensions(self, img):
        # Ensure width and height are even (required by yuv420p)
        h, w = img.shape[:2]
        if h % 2 != 0:
            img = img[:-1, :]
        if w % 2 != 0:
            img = img[:, :-1]
        return img

    def _apply_random_compression(self, imgs):
        # Convert PyTorch tensor (T, C, H, W) to a list of NumPy HWC frames
        imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
        # Ensure width and height are even
        imgs = [self._ensure_even_dimensions(img) for img in imgs]
        codec = random.choices(self.params['codec'], self.params['codec_prob'])[0]
        bitrate = self.params['bitrate']
        bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)
        buf = io.BytesIO()
        # encode all frames into an in-memory mp4 at the sampled bitrate
        with av.open(buf, 'w', 'mp4') as container:
            stream = container.add_stream(codec, rate=1)
            stream.height = imgs[0].shape[0]
            stream.width = imgs[0].shape[1]
            stream.pix_fmt = 'yuv420p'
            stream.bit_rate = bitrate
            for img in imgs:
                img = (img * 255).clip(0, 255)  # convert to [0, 255] range
                img = img.astype(np.uint8)
                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
                frame.pict_type = 'NONE'
                for packet in stream.encode(frame):
                    container.mux(packet)
            # flush the stream
            for packet in stream.encode():
                container.mux(packet)
        # decode the compressed frames back into float arrays
        outputs = []
        with av.open(buf, 'r', 'mp4') as container:
            if container.streams.video:
                for frame in container.decode(**{'video': 0}):
                    outputs.append(frame.to_rgb().to_ndarray().astype(np.float32) / 255)  # back to [0, 1]
        # Convert NumPy frames back to a PyTorch tensor (T, C, H, W)
        outputs = torch.tensor(outputs).permute(0, 3, 1, 2)
        return outputs
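# A minimal round-trip sketch of _apply_random_compression above, assuming a
# (T, C, H, W) float tensor in [0, 1] (frame count and size are made up):
#     compressor = ImageCompressor()
#     frames = torch.rand(8, 3, 64, 64)
#     degraded = compressor._apply_random_compression(frames)
#     # -> same layout, with mp4 codec artifacts baked into every frame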
class Degradation(nn.Module):

    def __init__(self, scale, gt_size):
        super(Degradation, self).__init__()
        ### initialize the JPEG class
        self.jpeger = DiffJPEG(differentiable=False)  # .cuda()
        self.usm_sharpener = USMSharp()  # .cuda()
        # self.queue_size = 180  # opt.get('queue_size', 180)

        ### global settings
        self.scale = scale
        self.gt_size = gt_size

        ### first-degradation hyperparameters ###
        # 1. blur
        self.blur_kernel_size = 21
        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        self.kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        self.kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        self.sinc_prob = 0.1
        self.blur_sigma = [0.2, 3]  # blur_x / y_sigma
        self.betag_range = [0.5, 4]
        self.betap_range = [1, 2]
        # 2. resize
        self.updown_type = ["up", "down", "keep"]
        self.mode_list = ["area", "bilinear", "bicubic"]  # flags: [3, 1, 2]
        self.resize_prob = [0.2, 0.7, 0.1]  # up, down, keep
        self.resize_range = [0.15, 1.5]
        # 3. noise
        self.gaussian_noise_prob = 0.5
        self.noise_range = [1, 30]
        self.poisson_scale_range = [0.05, 3]
        self.gray_noise_prob = 0.4
        # 4. jpeg
        self.jpeg_range = [30, 95]

        ### second-degradation hyperparameters ###
        # 1. blur
        self.second_blur_prob = 0.8
        self.blur_kernel_size2 = 21
        self.kernel_range2 = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        self.kernel_list2 = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        self.kernel_prob2 = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        self.sinc_prob2 = 0.1
        self.blur_sigma2 = [0.2, 1.5]
        self.betag_range2 = [0.5, 4]
        self.betap_range2 = [1, 2]
        # 2. resize
        self.updown_type2 = ["up", "down", "keep"]
        self.mode_list2 = ["area", "bilinear", "bicubic"]  # flags: [3, 1, 2]
        self.resize_prob2 = [0.3, 0.4, 0.3]  # up, down, keep
        self.resize_range2 = [0.3, 1.2]
        # 3. noise
        self.gaussian_noise_prob2 = 0.5
        self.noise_range2 = [1, 25]
        self.poisson_scale_range2 = [0.05, 2.5]
        self.gray_noise_prob2 = 0.4
        # 4. jpeg
        self.jpeg_range2 = [30, 95]

        self.final_sinc_prob = 0.8
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with a pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

        # video compression
        self.compressor = ImageCompressor()

    @torch.no_grad()
    def forward_deg(self, gt):
        ori_h, ori_w = gt.size()[2:4]
        gt_usm = self.usm_sharpener(gt)
        gt_usm_copy = gt_usm.clone()
        # generate kernels
        kernel1 = self.generate_first_kernel()
        kernel2 = self.generate_second_kernel()
        sinc_kernel = self.generate_sinc_kernel()
        # first degradation, followed by random video compression
        lq = self.compressor._apply_random_compression(
            self.jpeg_1(self.noise_1(self.resize_1(self.blur_1(gt_usm_copy, kernel1)))))
        # second degradation, followed by random video compression
        lq = self.compressor._apply_random_compression(
            self.jpeg_2(self.noise_2(self.resize_2(self.blur_2(lq, kernel2), ori_h, ori_w)), ori_h, ori_w, sinc_kernel))
        return lq, gt_usm, kernel1, kernel2, sinc_kernel

    @torch.no_grad()
    def forward(self, img_gt, uint8=False):
        # read hwc 0-1 numpy
        # img_gt = cv2.imread(gt_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
        # augment
        # img_gt = self.augment(img_gt, True, True)
        # numpy 0-1 hwc -> tensor 0-1 chw
        # img_gt = self.np2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        # add batch
        img_gt = img_gt.unsqueeze(0)
        img_gt_copy = img_gt.clone()
        # degradation pipeline
        lq, gt_usm, kernel1, kernel2, sinc_kernel = self.forward_deg(img_gt_copy)
        # clamp and round
        lq = torch.clamp((lq * 255.0).round(), 0, 255) / 255.
        # print(f'before crop: gt:{img_gt_copy.shape}, lq:{lq.shape}')
        # random crop
        # (gt, gt_usm), lq = self.paired_random_crop([img_gt_copy, gt_usm], lq, self.gt_size, self.scale)
        # print(f'after crop: gt:{gt.shape}, lq:{lq.shape}')
        # if uint8:
        #     gt, gt_usm, lq = self.tensor2np([gt, gt_usm, lq])
        #     return gt, gt_usm, lq, kernel1, kernel2, sinc_kernel
        return lq, gt_usm  # gt, kernel1, kernel2, sinc_kernel

    def blur_1(self, img, kernel1):
        img = filter2D(img, kernel1)
        return img

    def blur_2(self, img, kernel2):
        if np.random.uniform() < self.second_blur_prob:
            img = filter2D(img, kernel2)
        return img

    def resize_1(self, img):
        updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob)[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.resize_range[1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.resize_range[0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        img = F.interpolate(img, scale_factor=scale, mode=mode)
        return img

    def resize_2(self, img, ori_h, ori_w):
        updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob2)[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.resize_range2[1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.resize_range2[0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        img = F.interpolate(img, size=(int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), mode=mode)
        return img

    def noise_1(self, img):
        gray_noise_prob = self.gray_noise_prob
        if np.random.uniform() < self.gaussian_noise_prob:
            img = random_add_gaussian_noise_pt(
                img, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            img = random_add_poisson_noise_pt(
                img, scale_range=self.poisson_scale_range, gray_prob=gray_noise_prob, clip=True, rounds=False)
        return img

    def noise_2(self, img):
        gray_noise_prob = self.gray_noise_prob2
        if np.random.uniform() < self.gaussian_noise_prob2:
            img = random_add_gaussian_noise_pt(
                img, sigma_range=self.noise_range2, clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            img = random_add_poisson_noise_pt(
                img, scale_range=self.poisson_scale_range2, gray_prob=gray_noise_prob, clip=True, rounds=False)
        return img

    def jpeg_1(self, img):
        jpeg_p = img.new_zeros(img.size(0)).uniform_(*self.jpeg_range)
        # clamp to [0, 1], otherwise the JPEGer will produce unpleasant artifacts
        img = torch.clamp(img, 0, 1)
        img = self.jpeger(img, quality=jpeg_p)
        return img

    def jpeg_2(self, out, ori_h, ori_w, sinc_kernel):
        # JPEG compression + the final sinc filter.
        # We also need to resize images to the desired size. We group [resize back + sinc filter]
        # together as one operation and consider two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, other combinations (sinc + JPEG + resize) introduce twisted lines.
        if np.random.uniform() < 0.5:
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
            out = filter2D(out, sinc_kernel)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
        else:
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
            out = filter2D(out, sinc_kernel)
        return out

    def generate_first_kernel(self):
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma,
                [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel to 21x21
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
        return torch.FloatTensor(kernel)

    def generate_second_kernel(self):
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob2:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2,
                [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)
        # pad kernel to 21x21
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
        return torch.FloatTensor(kernel2)

    def generate_sinc_kernel(self):
        if np.random.uniform() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor
        return sinc_kernel

    def np2tensor(self, imgs, bgr2rgb=False, float32=True):

        def _totensor(img, bgr2rgb, float32):
            if img.shape[2] == 3 and bgr2rgb:
                if img.dtype == 'float64':
                    img = img.astype('float32')
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = torch.from_numpy(img.transpose(2, 0, 1))
            if float32:
                img = img.float()
            return img

        if isinstance(imgs, list):
            return [_totensor(img, bgr2rgb, float32) for img in imgs]
        else:
            return _totensor(imgs, bgr2rgb, float32)

    def tensor2np(self, imgs):

        def _tonumpy(img):
            img = img.data.cpu().numpy().squeeze(0).transpose(1, 2, 0)  # .astype(np.float32)
            img = np.uint8((img.clip(0, 1) * 255.).round())
            return img

        if isinstance(imgs, list):
            return [_tonumpy(img) for img in imgs]
        else:
            return _tonumpy(imgs)

    def augment(self, imgs, hflip=True, rotation=True, flows=None, return_status=False):
        hflip = hflip and random.random() < 0.5
        vflip = rotation and random.random() < 0.5
        rot90 = rotation and random.random() < 0.5

        def _augment(img):
            if hflip:  # horizontal
                cv2.flip(img, 1, img)
            if vflip:  # vertical
                cv2.flip(img, 0, img)
            if rot90:
                img = img.transpose(1, 0, 2)
            return img

        if not isinstance(imgs, list):
            imgs = [imgs]
        imgs = [_augment(img) for img in imgs]
        if len(imgs) == 1:
            imgs = imgs[0]
        return imgs

    def paired_random_crop(self, img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
        """Paired random crop. Supports NumPy array and Tensor inputs.

        It crops lists of lq and gt images at corresponding locations.

        Args:
            img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
                should have the same shape. If the input is an ndarray, it will
                be transformed to a list containing itself.
            img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
                should have the same shape. If the input is an ndarray, it will
                be transformed to a list containing itself.
            gt_patch_size (int): GT patch size.
            scale (int): Scale factor.
            gt_path (str): Path to ground-truth. Default: None.

        Returns:
            list[ndarray] | ndarray: GT images and LQ images. If the returned results
            only have one element, just return ndarray.
        """
        if not isinstance(img_gts, list):
            img_gts = [img_gts]
        if not isinstance(img_lqs, list):
            img_lqs = [img_lqs]
        # determine input type: NumPy array or Tensor
        input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

        if input_type == 'Tensor':
            h_lq, w_lq = img_lqs[0].size()[-2:]
            h_gt, w_gt = img_gts[0].size()[-2:]
        else:
            h_lq, w_lq = img_lqs[0].shape[0:2]
            h_gt, w_gt = img_gts[0].shape[0:2]
        lq_patch_size = gt_patch_size // scale

        if h_gt != h_lq * scale or w_gt != w_lq * scale:
            raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                             f'multiplication of LQ ({h_lq}, {w_lq}).')
        if h_lq < lq_patch_size or w_lq < lq_patch_size:
            raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                             f'({lq_patch_size}, {lq_patch_size}). '
                             f'Please remove {gt_path}.')

        # randomly choose top and left coordinates for lq patch
top
=
random
.
randint
(
0
,
h_lq
-
lq_patch_size
)
left
=
random
.
randint
(
0
,
w_lq
-
lq_patch_size
)
# crop lq patch
if
input_type
==
'Tensor'
:
img_lqs
=
[
v
[:,
:,
top
:
top
+
lq_patch_size
,
left
:
left
+
lq_patch_size
]
for
v
in
img_lqs
]
else
:
img_lqs
=
[
v
[
top
:
top
+
lq_patch_size
,
left
:
left
+
lq_patch_size
,
...]
for
v
in
img_lqs
]
# crop corresponding gt patch
top_gt
,
left_gt
=
int
(
top
*
scale
),
int
(
left
*
scale
)
if
input_type
==
'Tensor'
:
img_gts
=
[
v
[:,
:,
top_gt
:
top_gt
+
gt_patch_size
,
left_gt
:
left_gt
+
gt_patch_size
]
for
v
in
img_gts
]
else
:
img_gts
=
[
v
[
top_gt
:
top_gt
+
gt_patch_size
,
left_gt
:
left_gt
+
gt_patch_size
,
...]
for
v
in
img_gts
]
if
len
(
img_gts
)
==
1
:
img_gts
=
img_gts
[
0
]
if
len
(
img_lqs
)
==
1
:
img_lqs
=
img_lqs
[
0
]
return
img_gts
,
img_lqs
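
For a concrete picture of the crop contract above, a minimal sketch (shapes are illustrative, not from the repository): GT must be exactly scale x the LQ size, and the returned patches are gt_patch_size and gt_patch_size // scale on a side.

deg = Degradation(scale=4, gt_size=(480, 720))
gt = torch.rand(1, 3, 480, 720)    # GT is 4x the LQ resolution, as the mismatch check requires
lq = torch.rand(1, 3, 120, 180)
gt_crop, lq_crop = deg.paired_random_crop(gt, lq, gt_patch_size=256, scale=4)
# gt_crop: (1, 3, 256, 256); lq_crop: (1, 3, 64, 64), cropped at the corresponding location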
# init degradation process
degradation = Degradation(scale=4, gt_size=(480, 720))


def degradation_process(video_array):
    _, _, t, _, _ = video_array.shape

    # preprocess video
    video_array = video_array.to(torch.float32).cpu()
    video_array = (video_array + 1) * 0.5  # [-1, 1] -> [0, 1]
    video_array = rearrange(video_array, "B C T H W -> (B T) C H W")
    assert torch.max(video_array) <= 1 and torch.min(video_array) >= 0, "Values are NOT within [0, 1]."

    # degrade
    lq_list = []
    gt_list = []
    for video in video_array:
        lq, gt = degradation(video)
        lq_list.append(lq)
        gt_list.append(gt)
    lq = torch.cat(lq_list, dim=0)
    gt = torch.cat(gt_list, dim=0)

    lq = lq.clip(0, 1) * 2 - 1
    gt = gt.clip(0, 1) * 2 - 1
    lq = rearrange(lq, "(B T) C H W -> B C T H W", T=t).to(torch.float32)
    gt = rearrange(gt, "(B T) C H W -> B C T H W", T=t).to(torch.float32)
    return lq, gt
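
As a usage sketch for the entry point above (the clip shape is made up; the function expects values in [-1, 1] laid out B C T H W):

video = torch.rand(1, 3, 8, 480, 720) * 2 - 1   # hypothetical clip: B=1, C=3, T=8
lq, gt = degradation_process(video)
# lq and gt come back in [-1, 1] with layout B C T H W; lq carries the synthetic
# degradations and is spatially downscaled when the resize-back step is active.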
utils_data/opensora/datasets/high_order/degrade_video_mid.py  0 → 100644
from re import T
import cv2
import numpy as np
import random
import torch
import math
# import os
import torch.nn as nn
import sys
sys.path.append('/mnt/bn/videodataset-uswest/VSR/VSR/opensora/datasets/high_order')
from utils_ import filter2D, USMSharp
from utils_blur import circular_lowpass_kernel, random_mixed_kernels
# from utils_resize import random_resizing
from utils_noise import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from utils_jpeg import DiffJPEG
from torch.nn import functional as F
from einops import rearrange
import av
import io


class ImageCompressor:

    def __init__(self):
        self.params = {
            'codec': ['libx264', 'h264', 'mpeg4'],
            'codec_prob': [1 / 3., 1 / 3., 1 / 3.],
            'bitrate': [1e4, 1e5]
        }

    def _ensure_even_dimensions(self, img):
        # Ensure width and height are even
        h, w = img.shape[:2]
        if h % 2 != 0:
            img = img[:-1, :]
        if w % 2 != 0:
            img = img[:, :-1]
        return img

    def _apply_random_compression(self, imgs):
        # Convert PyTorch tensor to NumPy array
        imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
        # Ensure width and height are even
        imgs = [self._ensure_even_dimensions(img) for img in imgs]

        codec = random.choices(self.params['codec'], self.params['codec_prob'])[0]
        bitrate = self.params['bitrate']
        bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)

        buf = io.BytesIO()
        with av.open(buf, 'w', 'mp4') as container:
            stream = container.add_stream(codec, rate=1)
            stream.height = imgs[0].shape[0]
            stream.width = imgs[0].shape[1]
            stream.pix_fmt = 'yuv420p'
            stream.bit_rate = bitrate

            for img in imgs:
                img = (img * 255).clip(0, 255)  # Convert to [0, 255] range
                img = img.astype(np.uint8)
                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
                frame.pict_type = 'NONE'
                for packet in stream.encode(frame):
                    container.mux(packet)

            # Flush stream
            for packet in stream.encode():
                container.mux(packet)

        outputs = []
        with av.open(buf, 'r', 'mp4') as container:
            if container.streams.video:
                for frame in container.decode(**{'video': 0}):
                    # Convert back to [0, 1] range
                    outputs.append(frame.to_rgb().to_ndarray().astype(np.float32) / 255)

        # Convert NumPy array back to PyTorch tensor
        outputs = torch.tensor(outputs).permute(0, 3, 1, 2)
        return outputs
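
A minimal round-trip through the compressor (frame count and resolution are made up; PyAV must be installed):

compressor = ImageCompressor()
frames = torch.rand(8, 3, 128, 128)                       # (T, C, H, W) in [0, 1], even H/W
compressed = compressor._apply_random_compression(frames)
# compressed: (8, 3, 128, 128) in [0, 1], now carrying codec/bitrate artifacts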
class Degradation(nn.Module):

    def __init__(self, scale, gt_size):
        super(Degradation, self).__init__()
        ### initialize JPEG class
        self.jpeger = DiffJPEG(differentiable=False)  # .cuda()
        self.usm_sharpener = USMSharp()  # .cuda()
        # self.queue_size = 180  # opt.get('queue_size', 180)

        ### global settings
        self.scale = scale
        self.gt_size = gt_size

        ### the first degradation hyperparameters ###
        # 1. blur
        self.blur_kernel_size = 21
        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        self.kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        self.kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        self.sinc_prob = 0.1
        self.blur_sigma = [0.2, 3]  # blur_x / y_sigma
        self.betag_range = [0.5, 4]
        self.betap_range = [1, 2]
        # 2. resize
        self.updown_type = ["up", "down", "keep"]
        self.mode_list = ["area", "bilinear", "bicubic"]  # flags: [3, 1, 2]
        self.resize_prob = [0.2, 0.7, 0.1]  # up, down, keep
        self.resize_range = [0.15, 1.5]
        # 3. noise
        self.gaussian_noise_prob = 0.5
        self.noise_range = [1, 30]
        self.poisson_scale_range = [0.05, 3]
        self.gray_noise_prob = 0.4
        # 4. jpeg
        self.jpeg_range = [30, 95]

        ### the second degradation hyperparameters ###
        # 1. blur
        self.second_blur_prob = 0.8
        self.blur_kernel_size2 = 21
        self.kernel_range2 = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        self.kernel_list2 = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        self.kernel_prob2 = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        self.sinc_prob2 = 0.1
        self.blur_sigma2 = [0.2, 1.5]
        self.betag_range2 = [0.5, 4]
        self.betap_range2 = [1, 2]
        # 2. resize
        self.updown_type2 = ["up", "down", "keep"]
        self.mode_list2 = ["area", "bilinear", "bicubic"]  # flags: [3, 1, 2]
        self.resize_prob2 = [0.3, 0.4, 0.3]  # up, down, keep
        self.resize_range2 = [0.3, 1.2]
        # 3. noise
        self.gaussian_noise_prob2 = 0.5
        self.noise_range2 = [1, 25]
        self.poisson_scale_range2 = [0.05, 2.5]
        self.gray_noise_prob2 = 0.4
        # 4. jpeg
        self.jpeg_range2 = [30, 95]

        self.final_sinc_prob = 0.8
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

        # video compression
        self.compressor = ImageCompressor()

    @torch.no_grad()
    def forward_deg(self, gt):
        ori_h, ori_w = gt.size()[2:4]
        gt_usm = self.usm_sharpener(gt)
        gt_usm_copy = gt_usm.clone()

        # generate kernels
        kernel1 = self.generate_first_kernel()
        kernel2 = self.generate_second_kernel()
        sinc_kernel = self.generate_sinc_kernel()

        # first degradation
        lq = self.compressor._apply_random_compression(
            self.jpeg_1(self.noise_1(self.resize_1(self.blur_1(gt_usm_copy, kernel1)))))
        # second degradation
        # lq = self.compressor._apply_random_compression(self.jpeg_2(self.noise_2(self.resize_2(self.blur_2(lq, kernel2), ori_h, ori_w)), ori_h, ori_w, sinc_kernel))
        return lq, gt_usm, kernel1, kernel2, sinc_kernel

    @torch.no_grad()
    def forward(self, img_gt, uint8=False):
        # read hwc 0-1 numpy
        # img_gt = cv2.imread(gt_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
        # augment
        # img_gt = self.augment(img_gt, True, True)
        # numpy 0-1 hwc -> tensor 0-1 chw
        # img_gt = self.np2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        # add batch
        img_gt = img_gt.unsqueeze(0)
        img_gt_copy = img_gt.clone()

        # degradation pipeline
        lq, gt_usm, kernel1, kernel2, sinc_kernel = self.forward_deg(img_gt_copy)

        # clamp and round
        lq = torch.clamp((lq * 255.0).round(), 0, 255) / 255.

        # print(f'before crop: gt:{img_gt_copy.shape}, lq:{lq.shape}')
        # random crop
        # (gt, gt_usm), lq = self.paired_random_crop([img_gt_copy, gt_usm], lq, self.gt_size, self.scale)
        # print(f'after crop: gt:{gt.shape}, lq:{lq.shape}')
        # if uint8:
        #     gt, gt_usm, lq = self.tensor2np([gt, gt_usm, lq])
        # return gt, gt_usm, lq, kernel1, kernel2, sinc_kernel
        return lq, gt_usm  # gt, kernel1, kernel2, sinc_kernel

    def blur_1(self, img, kernel1):
        img = filter2D(img, kernel1)
        return img

    def blur_2(self, img, kernel2):
        if np.random.uniform() < self.second_blur_prob:
            img = filter2D(img, kernel2)
        return img

    def resize_1(self, img):
        updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob)[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.resize_range[1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.resize_range[0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        img = F.interpolate(img, scale_factor=scale, mode=mode)
        return img

    def resize_2(self, img, ori_h, ori_w):
        updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob2)[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.resize_range2[1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.resize_range2[0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        img = F.interpolate(
            img, size=(int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), mode=mode)
        return img

    def noise_1(self, img):
        gray_noise_prob = self.gray_noise_prob
        if np.random.uniform() < self.gaussian_noise_prob:
            img = random_add_gaussian_noise_pt(
                img, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            img = random_add_poisson_noise_pt(
                img, scale_range=self.poisson_scale_range, gray_prob=gray_noise_prob, clip=True, rounds=False)
        return img

    def noise_2(self, img):
        gray_noise_prob = self.gray_noise_prob2
        if np.random.uniform() < self.gaussian_noise_prob2:
            img = random_add_gaussian_noise_pt(
                img, sigma_range=self.noise_range2, clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            img = random_add_poisson_noise_pt(
                img, scale_range=self.poisson_scale_range2, gray_prob=gray_noise_prob, clip=True, rounds=False)
        return img

    def jpeg_1(self, img):
        jpeg_p = img.new_zeros(img.size(0)).uniform_(*self.jpeg_range)
        img = torch.clamp(img, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
        img = self.jpeger(img, quality=jpeg_p)
        return img

    def jpeg_2(self, out, ori_h, ori_w, sinc_kernel):
        # JPEG compression + the final sinc filter
        # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
        # as one operation.
        # We consider two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, we find other combinations (sinc + JPEG + resize) will introduce twisted lines.
        if np.random.uniform() < 0.5:
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
            out = filter2D(out, sinc_kernel)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
        else:
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // self.scale, ori_w // self.scale), mode=mode)
            out = filter2D(out, sinc_kernel)
        return out

    def generate_first_kernel(self):
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
        return torch.FloatTensor(kernel)

    def generate_second_kernel(self):
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob2:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
        return torch.FloatTensor(kernel2)

    def generate_sinc_kernel(self):
        if np.random.uniform() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor
        return sinc_kernel

    def np2tensor(self, imgs, bgr2rgb=False, float32=True):

        def _totensor(img, bgr2rgb, float32):
            if img.shape[2] == 3 and bgr2rgb:
                if img.dtype == 'float64':
                    img = img.astype('float32')
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = torch.from_numpy(img.transpose(2, 0, 1))
            if float32:
                img = img.float()
            return img

        if isinstance(imgs, list):
            return [_totensor(img, bgr2rgb, float32) for img in imgs]
        else:
            return _totensor(imgs, bgr2rgb, float32)

    def tensor2np(self, imgs):

        def _tonumpy(img):
            img = img.data.cpu().numpy().squeeze(0).transpose(1, 2, 0)  # .astype(np.float32)
            img = np.uint8((img.clip(0, 1) * 255.).round())
            return img

        if isinstance(imgs, list):
            return [_tonumpy(img) for img in imgs]
        else:
            return _tonumpy(imgs)

    def augment(self, imgs, hflip=True, rotation=True, flows=None, return_status=False):
        hflip = hflip and random.random() < 0.5
        vflip = rotation and random.random() < 0.5
        rot90 = rotation and random.random() < 0.5

        def _augment(img):
            if hflip:  # horizontal
                cv2.flip(img, 1, img)
            if vflip:  # vertical
                cv2.flip(img, 0, img)
            if rot90:
                img = img.transpose(1, 0, 2)
            return img

        if not isinstance(imgs, list):
            imgs = [imgs]
        imgs = [_augment(img) for img in imgs]
        if len(imgs) == 1:
            imgs = imgs[0]
        return imgs

    def paired_random_crop(self, img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
        """Paired random crop. Supports Numpy array and Tensor inputs.

        It crops lists of lq and gt images with corresponding locations.

        Args:
            img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
                should have the same shape. If the input is an ndarray, it will
                be transformed to a list containing itself.
            img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
                should have the same shape. If the input is an ndarray, it will
                be transformed to a list containing itself.
            gt_patch_size (int): GT patch size.
            scale (int): Scale factor.
            gt_path (str): Path to ground-truth. Default: None.

        Returns:
            list[ndarray] | ndarray: GT images and LQ images. If returned results
                only have one element, just return ndarray.
        """
        if not isinstance(img_gts, list):
            img_gts = [img_gts]
        if not isinstance(img_lqs, list):
            img_lqs = [img_lqs]

        # determine input type: Numpy array or Tensor
        input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

        if input_type == 'Tensor':
            h_lq, w_lq = img_lqs[0].size()[-2:]
            h_gt, w_gt = img_gts[0].size()[-2:]
        else:
            h_lq, w_lq = img_lqs[0].shape[0:2]
            h_gt, w_gt = img_gts[0].shape[0:2]
        lq_patch_size = gt_patch_size // scale

        if h_gt != h_lq * scale or w_gt != w_lq * scale:
            raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                             f'multiplication of LQ ({h_lq}, {w_lq}).')
        if h_lq < lq_patch_size or w_lq < lq_patch_size:
            raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                             f'({lq_patch_size}, {lq_patch_size}). '
                             f'Please remove {gt_path}.')

        # randomly choose top and left coordinates for lq patch
        top = random.randint(0, h_lq - lq_patch_size)
        left = random.randint(0, w_lq - lq_patch_size)

        # crop lq patch
        if input_type == 'Tensor':
            img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
        else:
            img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

        # crop corresponding gt patch
        top_gt, left_gt = int(top * scale), int(left * scale)
        if input_type == 'Tensor':
            img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
        else:
            img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
        if len(img_gts) == 1:
            img_gts = img_gts[0]
        if len(img_lqs) == 1:
            img_lqs = img_lqs[0]
        return img_gts, img_lqs


# init degradation process
degradation = Degradation(scale=4, gt_size=(480, 720))


def degradation_process(video_array):
    _, _, t, _, _ = video_array.shape

    # preprocess video
    video_array = video_array.to(torch.float32).cpu()
    video_array = (video_array + 1) * 0.5  # [-1, 1] -> [0, 1]
    video_array = rearrange(video_array, "B C T H W -> (B T) C H W")
    assert torch.max(video_array) <= 1 and torch.min(video_array) >= 0, "Values are NOT within [0, 1]."

    # degrade
    lq_list = []
    gt_list = []
    for video in video_array:
        lq, gt = degradation(video)
        lq_list.append(lq)
        gt_list.append(gt)
    lq = torch.cat(lq_list, dim=0)
    gt = torch.cat(gt_list, dim=0)

    lq = lq.clip(0, 1) * 2 - 1
    gt = gt.clip(0, 1) * 2 - 1
    lq = rearrange(lq, "(B T) C H W -> B C T H W", T=t).to(torch.float32)
    gt = rearrange(gt, "(B T) C H W -> B C T H W", T=t).to(torch.float32)
    return lq, gt
utils_data/opensora/datasets/high_order/matlab_functions.py  0 → 100644
import math
import numpy as np
import torch


def cubic(x):
    """cubic function used for calculate_weights_indices."""
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
        -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * (absx <= 2)).type_as(absx))


def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Calculate weights and indices, used for the imresize function.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
    """
    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. Only consider the
    # first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


@torch.no_grad()
def imresize(img, scale, antialiasing=True):
    """imresize function same as MATLAB.

    It now only supports bicubic.
    The same scale applies for both height and width.

    Args:
        img (Tensor | Numpy array):
            Tensor: Input image with shape (c, h, w), [0, 1] range.
            Numpy: Input image with shape (h, w, c), [0, 1] range.
        scale (float): Scale factor. The same scale applies for both height
            and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
            Default: True.

    Returns:
        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
    """
    squeeze_flag = False
    if type(img).__module__ == np.__name__:  # numpy type
        numpy_type = True
        if img.ndim == 2:
            img = img[:, :, None]
            squeeze_flag = True
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    else:
        numpy_type = False
        if img.ndim == 2:
            img = img.unsqueeze(0)
            squeeze_flag = True
    in_c, in_h, in_w = img.size()
    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
    kernel_width = 4
    kernel = 'cubic'

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
                                                                             antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
                                                                             antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)

    sym_patch = img[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    if squeeze_flag:
        out_2 = out_2.squeeze(0)
    if numpy_type:
        out_2 = out_2.numpy()
        if not squeeze_flag:
            out_2 = out_2.transpose(1, 2, 0)
    else:
        # tensor CHW [0,1] -> numpy HWC [0,1]
        # (guarded in an else branch so numpy inputs are not converted twice)
        out_2 = out_2.numpy().transpose((1, 2, 0))
    return out_2
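
A quick sanity check of imresize (input values are made up; note that, as committed above, tensor inputs are also returned as numpy HWC arrays):

img = np.random.rand(64, 48, 3).astype(np.float32)   # HWC, [0, 1]
small = imresize(img, 0.25)                           # MATLAB-style bicubic, antialiased
# small.shape == (16, 12, 3)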
utils_data/opensora/datasets/high_order/utils_.py  0 → 100644
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import functional as F

"""
filter2D(img, kernel)
usm_sharp(img, weight=0.5, radius=50, threshold=10)
USMSharp(torch.nn.Module)
"""


def filter2D(img, kernel):
    """PyTorch version of cv2.filter2D

    Args:
        img (Tensor): (b, c, h, w)
        kernel (Tensor): (b, k, k)
    """
    k = kernel.size(-1)
    b, c, h, w = img.size()
    if k % 2 == 1:
        img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
    else:
        raise ValueError('Wrong kernel size')

    ph, pw = img.size()[-2:]

    # if kernel.size(0) == 1:
    # apply the same kernel to all batch images
    img = img.contiguous().view(b * c, 1, ph, pw)
    kernel = kernel.view(1, 1, k, k)
    return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
    # else:
    #     img = img.view(1, b * c, ph, pw)
    #     kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
    #     return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)


def usm_sharp(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening.

    Input image: I; Blurry image: B.
    1. sharp = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * sharp + (1 - Mask) * I

    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int):
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    sharp = img + weight * residual
    sharp = np.clip(sharp, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img


class USMSharp(torch.nn.Module):

    def __init__(self, radius=50, sigma=0):
        super(USMSharp, self).__init__()
        if radius % 2 == 0:
            radius += 1
        self.radius = radius
        kernel = cv2.getGaussianKernel(radius, sigma)
        kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
        self.register_buffer('kernel', kernel)

    def forward(self, img, weight=0.5, threshold=10):
        blur = filter2D(img, self.kernel)
        residual = img - blur

        mask = torch.abs(residual) * 255 > threshold
        mask = mask.float()
        soft_mask = filter2D(mask, self.kernel)

        sharp = img + weight * residual
        sharp = torch.clip(sharp, 0, 1)
        return soft_mask * sharp + (1 - soft_mask) * img
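
A short sketch of the batched sharpener (values are random, for shape-checking only):

sharpener = USMSharp()
batch = torch.rand(2, 3, 64, 64)   # (b, c, h, w) in [0, 1]
out = sharpener(batch)              # weight=0.5, threshold=10 by default
# out has the same shape and stays in [0, 1] thanks to the final clip and soft-mask blend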
utils_data/opensora/datasets/high_order/utils_blur.py  0 → 100644
import cv2
import math
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
from scipy import special

"""
#1. Generate kernel
    bivariate_Gaussian: isotropic/anisotropic
    bivariate_generalized_Gaussian: isotropic/anisotropic
    bivariate_plateau: isotropic/anisotropic
#2. Randomly select kernel
    random_bivariate_Gaussian
    random_bivariate_generalized_Gaussian
    random_bivariate_plateau
#3. Randomly generate mixed kernels
    random_mixed_kernels
#4. Generate blur kernels (used in the first/second degradation)
    generate_kernel1
    generate_kernel2
#5. Auxiliary utils
    sigma_matrix2
    mesh_grid
    pdf2
    circular_lowpass_kernel  <--- sinc filter
"""


# -------------------------------------------------------------------- #
# --------------------------- Generate kernel ------------------------ #
# -------------------------------------------------------------------- #
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool):

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a bivariate generalized Gaussian kernel.

    Described in `Parameter Estimation For Multivariate Generalized
    Gaussian Distributions`_ by Pascal et. al (2013).

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None

    Returns:
        kernel (ndarray): normalized kernel.

    .. _Parameter Estimation For Multivariate Generalized Gaussian
    Distributions: https://arxiv.org/abs/1302.6498
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a plateau-like anisotropic kernel.

    1 / (1 + x^(beta))

    Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
    kernel = kernel / np.sum(kernel)
    return kernel


# -------------------------------------------------------------------- #
# ------------------------ Randomly generate kernel ------------------ #
# -------------------------------------------------------------------- #
def random_bivariate_Gaussian(kernel_size, sigma_x_range, sigma_y_range, rotation_range,
                              noise_range=None, isotropic=True):
    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_generalized_Gaussian(kernel_size, sigma_x_range, sigma_y_range, rotation_range,
                                          beta_range, noise_range=None, isotropic=True):
    """Randomly generate bivariate generalized Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        beta_range (tuple): [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # assume beta_range[0] < 1 < beta_range[1]
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_plateau(kernel_size, sigma_x_range, sigma_y_range, rotation_range,
                             beta_range, noise_range=None, isotropic=True):
    """Randomly generate bivariate plateau kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi/2, math.pi/2]
        beta_range (tuple): [1, 4]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # TODO: this may be not proper
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel


# -------------------------------------------------------------------- #
# ---------------- Randomly generate mixed kernels ------------------- #
# -------------------------------------------------------------------- #
def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         betag_range=(0.5, 8),
                         betap_range=(0.5, 8),
                         noise_range=None):
    """Randomly generate mixed kernels.

    Args:
        kernel_list (tuple): a list of kernel type names,
            support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
            'plateau_aniso']
        kernel_prob (tuple): corresponding kernel probability for each
            kernel type
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        beta_range (tuple): [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
    elif kernel_type == 'generalized_iso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=True)
    elif kernel_type == 'generalized_aniso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=False)
    elif kernel_type == 'plateau_iso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
    return kernel


# -------------------------------------------------------------------- #
# ----Generate blur kernels (used in the first/second degradation)---- #
# -------------------------------------------------------------------- #
def generate_kernel1(kernel_range, sinc_prob, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range):
    kernel_size = random.choice(kernel_range)
    if np.random.uniform() < sinc_prob:
        # this sinc filter setting is for kernels ranging from [7, 21]
        if kernel_size < 13:
            omega_c = np.random.uniform(np.pi / 3, np.pi)
        else:
            omega_c = np.random.uniform(np.pi / 5, np.pi)
        kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
    else:
        kernel = random_mixed_kernels(
            kernel_list,
            kernel_prob,
            kernel_size,
            blur_sigma,
            blur_sigma,
            [-math.pi, math.pi],  # rotation angle
            betag_range,
            betap_range,
            noise_range=None,
        )
    # pad kernel
    pad_size = (21 - kernel_size) // 2
    kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
    return kernel


def generate_kernel2(kernel_range, sinc_prob2, kernel_list2, kernel_prob2, blur_sigma2, betag_range2, betap_range2):
    kernel_size = random.choice(kernel_range)
    if np.random.uniform() < sinc_prob2:
        if kernel_size < 13:
            omega_c = np.random.uniform(np.pi / 3, np.pi)
        else:
            omega_c = np.random.uniform(np.pi / 5, np.pi)
        kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
    else:
        kernel2 = random_mixed_kernels(
            kernel_list2,
            kernel_prob2,
            kernel_size,
            blur_sigma2,
            blur_sigma2,
            [-math.pi, math.pi],
            betag_range2,
            betap_range2,
            noise_range=None,
        )
    # pad kernel
    pad_size = (21 - kernel_size) // 2
    kernel = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
    return kernel


def generate_sinc_kernel(kernel_range):
    kernel_size = random.choice(kernel_range)
    if kernel_size < 13:
        omega_c = np.random.uniform(np.pi / 3, np.pi)
    else:
        omega_c = np.random.uniform(np.pi / 5, np.pi)
    sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
    return sinc_kernel


# -------------------------------------------------------------------- #
# --------------------------- Auxiliary utils ------------------------ #
# -------------------------------------------------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).

    Args:
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.

    Returns:
        ndarray: Rotated sigma matrix.
    """
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))


def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.

    Args:
        kernel_size (int):

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
                    yy.reshape(kernel_size * kernel_size, 1))).reshape(kernel_size, kernel_size, 2)
    return xy, xx, yy


def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarray): un-normalized kernel.
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel


def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
    """2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter

    Args:
        cutoff (float): cutoff frequency in radians (pi is max)
        kernel_size (int): horizontal and vertical size, must be odd.
        pad_to (int): pad kernel size to desired size, must be odd or zero.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    kernel = np.fromfunction(
        lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
            (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
                (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
    kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
    kernel = kernel / np.sum(kernel)
    if pad_to > kernel_size:
        pad_size = (pad_to - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
    return kernel


if __name__ == '__main__':
    blur_kernel_size = 21
    kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
    kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob = 0.1
    blur_sigma = [0.2, 3]  # blur_x / y_sigma
    betag_range = [0.5, 4]
    betap_range = [1, 2]

    kernel = generate_kernel1(kernel_range, sinc_prob, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range)
    print(kernel.shape)

    img = cv2.imread('../qj.png')
    img = np.float32(img / 255.)
    img_blur = cv2.filter2D(img, -1, kernel)
    # img_blur = img_blur[:, :, ::-1]
    img_blur = np.uint8((img_blur.clip(0, 1) * 255.).round())
    cv2.imwrite('blur2.png', img_blur)
    # plt.imshow(img_blur)
    # plt.show()
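
A small property check on the kernels above (a sketch; the sum-to-one property follows from the kernel / np.sum(kernel) normalization in each generator, and zero-padding preserves it):

k = circular_lowpass_kernel(np.pi / 2, 13, pad_to=21)   # sinc kernel, zero-padded to 21x21
assert k.shape == (21, 21)
assert np.isclose(k.sum(), 1.0)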
utils_data/opensora/datasets/high_order/utils_jpeg.py  0 → 100644
import random
import itertools
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.nn import functional as F


class DiffJPEG(nn.Module):
    """This JPEG algorithm result is slightly different from cv2.
    DiffJPEG supports batch processing.

    Args:
        differentiable (bool): If True, uses a custom differentiable rounding function;
            if False, uses standard torch.round.
    """

    def __init__(self, differentiable=True):
        super(DiffJPEG, self).__init__()
        if differentiable:
            rounding = diff_round
        else:
            rounding = torch.round

        self.compress = CompressJpeg(rounding=rounding)
        self.decompress = DeCompressJpeg(rounding=rounding)

    def forward(self, x, quality):
        """
        Args:
            x (Tensor): Input image, bchw, rgb, [0, 1]
            quality (float): Quality factor for the jpeg compression scheme.
        """
        factor = quality
        if isinstance(factor, (int, float)):
            factor = quality_to_factor(factor)
        else:
            for i in range(factor.size(0)):
                factor[i] = quality_to_factor(factor[i])
        h, w = x.size()[-2:]
        h_pad, w_pad = 0, 0
        # pad to a multiple of 16: blocks are 8x8 and the chroma channels are subsampled by 2
        if h % 16 != 0:
            h_pad = 16 - h % 16
        if w % 16 != 0:
            w_pad = 16 - w % 16
        x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)

        y, cb, cr = self.compress(x, factor=factor)
        recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
        recovered = recovered[:, :, 0:h, 0:w]
        return recovered
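
For reference, a minimal call pattern (batch size and quality range are illustrative; diff_round, quality_to_factor and the y/c quantization tables are expected to be defined elsewhere in this file):

jpeger = DiffJPEG(differentiable=False)
x = torch.rand(4, 3, 120, 180)                 # bchw, rgb, [0, 1]
quality = x.new_zeros(4).uniform_(30, 95)      # per-image quality, as in Degradation.jpeg_1
y = jpeger(x, quality=quality)                 # same shape back, JPEG artifacts applied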
#----------------------Compression----------------------#
class
CompressJpeg
(
nn
.
Module
):
"""Full JPEG compression algorithm
Args:
rounding(function): rounding function to use
"""
def
__init__
(
self
,
rounding
=
torch
.
round
):
super
(
CompressJpeg
,
self
).
__init__
()
self
.
l1
=
nn
.
Sequential
(
RGB2YCbCrJpeg
(),
ChromaSubsampling
())
self
.
l2
=
nn
.
Sequential
(
BlockSplitting
(),
DCT8x8
())
self
.
c_quantize
=
CQuantize
(
rounding
=
rounding
)
self
.
y_quantize
=
YQuantize
(
rounding
=
rounding
)
def
forward
(
self
,
image
,
factor
=
1
):
"""
Args:
image(tensor): batch x 3 x height x width
Returns:
dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8.
"""
y
,
cb
,
cr
=
self
.
l1
(
image
*
255
)
components
=
{
'y'
:
y
,
'cb'
:
cb
,
'cr'
:
cr
}
for
k
in
components
.
keys
():
comp
=
self
.
l2
(
components
[
k
])
if
k
in
(
'cb'
,
'cr'
):
comp
=
self
.
c_quantize
(
comp
,
factor
=
factor
)
else
:
comp
=
self
.
y_quantize
(
comp
,
factor
=
factor
)
components
[
k
]
=
comp
return
components
[
'y'
],
components
[
'cb'
],
components
[
'cr'
]
class
RGB2YCbCrJpeg
(
nn
.
Module
):
""" Converts RGB image to YCbCr
"""
def
__init__
(
self
):
super
(
RGB2YCbCrJpeg
,
self
).
__init__
()
matrix
=
np
.
array
([[
0.299
,
0.587
,
0.114
],
[
-
0.168736
,
-
0.331264
,
0.5
],
[
0.5
,
-
0.418688
,
-
0.081312
]],
dtype
=
np
.
float32
).
T
self
.
shift
=
nn
.
Parameter
(
torch
.
tensor
([
0.
,
128.
,
128.
]))
self
.
matrix
=
nn
.
Parameter
(
torch
.
from_numpy
(
matrix
))
def
forward
(
self
,
image
):
"""
Args:
image(Tensor): batch x 3 x height x width
Returns:
Tensor: batch x height x width x 3
"""
image
=
image
.
permute
(
0
,
2
,
3
,
1
)
result
=
torch
.
tensordot
(
image
,
self
.
matrix
,
dims
=
1
)
+
self
.
shift
return
result
.
view
(
image
.
shape
)
class
ChromaSubsampling
(
nn
.
Module
):
""" Chroma subsampling on CbCr channels
"""
def
__init__
(
self
):
super
(
ChromaSubsampling
,
self
).
__init__
()
def
forward
(
self
,
image
):
"""
Args:
image(tensor): batch x height x width x 3
Returns:
y(tensor): batch x height x width
cb(tensor): batch x height/2 x width/2
cr(tensor): batch x height/2 x width/2
"""
image_2
=
image
.
permute
(
0
,
3
,
1
,
2
).
clone
()
cb
=
F
.
avg_pool2d
(
image_2
[:,
1
,
:,
:].
unsqueeze
(
1
),
kernel_size
=
2
,
stride
=
(
2
,
2
),
count_include_pad
=
False
)
cr
=
F
.
avg_pool2d
(
image_2
[:,
2
,
:,
:].
unsqueeze
(
1
),
kernel_size
=
2
,
stride
=
(
2
,
2
),
count_include_pad
=
False
)
cb
=
cb
.
permute
(
0
,
2
,
3
,
1
)
cr
=
cr
.
permute
(
0
,
2
,
3
,
1
)
return
image
[:,
:,
:,
0
],
cb
.
squeeze
(
3
),
cr
.
squeeze
(
3
)
class
BlockSplitting
(
nn
.
Module
):
""" Splitting image into patches
"""
def
__init__
(
self
):
super
(
BlockSplitting
,
self
).
__init__
()
self
.
k
=
8
def
forward
(
self
,
image
):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x h*w/64 x h x w
"""
height
,
_
=
image
.
shape
[
1
:
3
]
batch_size
=
image
.
shape
[
0
]
image_reshaped
=
image
.
view
(
batch_size
,
height
//
self
.
k
,
self
.
k
,
-
1
,
self
.
k
)
image_transposed
=
image_reshaped
.
permute
(
0
,
1
,
3
,
2
,
4
)
return
image_transposed
.
contiguous
().
view
(
batch_size
,
-
1
,
self
.
k
,
self
.
k
)
class
DCT8x8
(
nn
.
Module
):
""" Discrete Cosine Transformation
"""
def
__init__
(
self
):
super
(
DCT8x8
,
self
).
__init__
()
tensor
=
np
.
zeros
((
8
,
8
,
8
,
8
),
dtype
=
np
.
float32
)
for
x
,
y
,
u
,
v
in
itertools
.
product
(
range
(
8
),
repeat
=
4
):
tensor
[
x
,
y
,
u
,
v
]
=
np
.
cos
((
2
*
x
+
1
)
*
u
*
np
.
pi
/
16
)
*
np
.
cos
((
2
*
y
+
1
)
*
v
*
np
.
pi
/
16
)
alpha
=
np
.
array
([
1.
/
np
.
sqrt
(
2
)]
+
[
1
]
*
7
)
self
.
tensor
=
nn
.
Parameter
(
torch
.
from_numpy
(
tensor
).
float
())
self
.
scale
=
nn
.
Parameter
(
torch
.
from_numpy
(
np
.
outer
(
alpha
,
alpha
)
*
0.25
).
float
())
def
forward
(
self
,
image
):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
image
=
image
-
128
result
=
self
.
scale
*
torch
.
tensordot
(
image
,
self
.
tensor
,
dims
=
2
)
result
.
view
(
image
.
shape
)
return
result
class
YQuantize
(
nn
.
Module
):
""" JPEG Quantization for Y channel
Args:
rounding(function): rounding function to use
"""
def
__init__
(
self
,
rounding
):
super
(
YQuantize
,
self
).
__init__
()
self
.
rounding
=
rounding
self
.
y_table
=
y_table
def
forward
(
self
,
image
,
factor
=
1
):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
if
isinstance
(
factor
,
(
int
,
float
)):
image
=
image
.
float
()
/
(
self
.
y_table
*
factor
)
else
:
b
=
factor
.
size
(
0
)
table
=
self
.
y_table
.
expand
(
b
,
1
,
8
,
8
)
*
factor
.
view
(
b
,
1
,
1
,
1
)
image
=
image
.
float
()
/
table
image
=
self
.
rounding
(
image
)
return
image
class
CQuantize
(
nn
.
Module
):
""" JPEG Quantization for CbCr channels
Args:
rounding(function): rounding function to use
"""
def
__init__
(
self
,
rounding
):
super
(
CQuantize
,
self
).
__init__
()
self
.
rounding
=
rounding
self
.
c_table
=
c_table
def
forward
(
self
,
image
,
factor
=
1
):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
if
isinstance
(
factor
,
(
int
,
float
)):
image
=
image
.
float
()
/
(
self
.
c_table
*
factor
)
else
:
b
=
factor
.
size
(
0
)
table
=
self
.
c_table
.
expand
(
b
,
1
,
8
,
8
)
*
factor
.
view
(
b
,
1
,
1
,
1
)
image
=
image
.
float
()
/
table
image
=
self
.
rounding
(
image
)
return
image
# ------------------------ Decompression ------------------------#
class DeCompressJpeg(nn.Module):
    """Full JPEG decompression algorithm.

    Args:
        rounding(function): rounding function to use
    """

    def __init__(self, rounding=torch.round):
        super(DeCompressJpeg, self).__init__()
        self.c_dequantize = CDequantize()
        self.y_dequantize = YDequantize()
        self.idct = iDCT8x8()
        self.merging = BlockMerging()
        self.chroma = ChromaUpsampling()
        self.colors = YCbCr2RGBJpeg()

    def forward(self, y, cb, cr, imgh, imgw, factor=1):
        """
        Args:
            y, cb, cr (tensor): quantized DCT coefficients, batch x h*w/64 x 8 x 8
            imgh(int): height of the original image
            imgw(int): width of the original image
            factor(float): quantization factor
        Returns:
            Tensor: batch x 3 x height x width
        """
        components = {'y': y, 'cb': cb, 'cr': cr}
        for k in components.keys():
            if k in ('cb', 'cr'):
                # chroma planes are half resolution (4:2:0 subsampling)
                comp = self.c_dequantize(components[k], factor=factor)
                height, width = int(imgh / 2), int(imgw / 2)
            else:
                comp = self.y_dequantize(components[k], factor=factor)
                height, width = imgh, imgw
            comp = self.idct(comp)
            components[k] = self.merging(comp, height, width)
        image = self.chroma(components['y'], components['cb'], components['cr'])
        image = self.colors(image)
        image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
        return image / 255
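
# Illustrative usage (a sketch, not part of the original file), assuming a
# 4:2:0 layout for a 64x64 image: y carries 64 blocks, cb/cr carry 16 each
# (from the 32x32 chroma planes).
#   y, cb, cr = torch.rand(1, 64, 8, 8), torch.rand(1, 16, 8, 8), torch.rand(1, 16, 8, 8)
#   rgb = DeCompressJpeg()(y, cb, cr, imgh=64, imgw=64, factor=1)
#   rgb.shape   # torch.Size([1, 3, 64, 64]), values clipped to [0, 1]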
class YDequantize(nn.Module):
    """Dequantize the Y channel."""

    def __init__(self):
        super(YDequantize, self).__init__()
        self.y_table = y_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x height x width
        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            out = image * (self.y_table * factor)
        else:
            b = factor.size(0)
            table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            out = image * table
        return out
class CDequantize(nn.Module):
    """Dequantize the CbCr channels."""

    def __init__(self):
        super(CDequantize, self).__init__()
        self.c_table = c_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x height x width
        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            out = image * (self.c_table * factor)
        else:
            b = factor.size(0)
            table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            out = image * table
        return out
class iDCT8x8(nn.Module):
    """Inverse discrete Cosine Transformation."""

    def __init__(self):
        super(iDCT8x8, self).__init__()
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width
        Returns:
            Tensor: batch x height x width
        """
        image = image * self.alpha
        result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
        result = result.view(image.shape)
        return result
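
# Round-trip sanity check (a sketch, not part of the original file): the
# DCT8x8/iDCT8x8 pair is orthonormal, so it inverts up to float error.
#   blocks = torch.rand(1, 4, 8, 8) * 255
#   rec = iDCT8x8()(DCT8x8()(blocks))
#   torch.allclose(rec, blocks, atol=1e-2)   # ~True, within float32 error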
class BlockMerging(nn.Module):
    """Merge 8x8 patches back into an image."""

    def __init__(self):
        super(BlockMerging, self).__init__()

    def forward(self, patches, height, width):
        """
        Args:
            patches(tensor): batch x height*width/64 x 8 x 8
            height(int)
            width(int)
        Returns:
            Tensor: batch x height x width
        """
        k = 8
        batch_size = patches.shape[0]
        image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
        return image_transposed.contiguous().view(batch_size, height, width)
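
# BlockMerging is the exact inverse of BlockSplitting (illustrative check, not
# part of the original file):
#   x = torch.rand(1, 32, 32)
#   torch.equal(BlockMerging()(BlockSplitting()(x), 32, 32), x)   # True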
class ChromaUpsampling(nn.Module):
    """Upsample chroma layers."""

    def __init__(self):
        super(ChromaUpsampling, self).__init__()

    def forward(self, y, cb, cr):
        """
        Args:
            y(tensor): y channel image
            cb(tensor): cb channel
            cr(tensor): cr channel
        Returns:
            Tensor: batch x height x width x 3
        """

        def repeat(x, k=2):
            # 2x nearest-neighbor upsampling via repeat + view
            height, width = x.shape[1:3]
            x = x.unsqueeze(-1)
            x = x.repeat(1, 1, k, k)
            x = x.view(-1, height * k, width * k)
            return x

        cb = repeat(cb)
        cr = repeat(cr)
        return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
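
# Illustrative shapes (a sketch, not part of the original file), for 4:2:0
# chroma at half the luma resolution:
#   y, cb, cr = torch.rand(1, 64, 64), torch.rand(1, 32, 32), torch.rand(1, 32, 32)
#   ChromaUpsampling()(y, cb, cr).shape   # torch.Size([1, 64, 64, 3])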
class YCbCr2RGBJpeg(nn.Module):
    """Converts a YCbCr image to RGB (JPEG / ITU-R BT.601 conversion)."""

    def __init__(self):
        super(YCbCr2RGBJpeg, self).__init__()
        matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
        self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width x 3
        Returns:
            Tensor: batch x 3 x height x width
        """
        result = torch.tensordot(image + self.shift, self.matrix, dims=1)
        return result.view(image.shape).permute(0, 3, 1, 2)
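
# Worked example (illustrative, not part of the original file): the shift and
# matrix above implement
#   R = Y + 1.402 * (Cr - 128)
#   G = Y - 0.344136 * (Cb - 128) - 0.714136 * (Cr - 128)
#   B = Y + 1.772 * (Cb - 128)
# so mid-gray maps to mid-gray: (Y, Cb, Cr) = (128, 128, 128) -> (R, G, B) = (128, 128, 128).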
# ------------------------ utils ------------------------#
y_table = np.array(
    [[16, 11, 10, 16, 24, 40, 51, 61],
     [12, 12, 14, 19, 26, 58, 60, 55],
     [14, 13, 16, 24, 40, 57, 69, 56],
     [14, 17, 22, 29, 51, 87, 80, 62],
     [18, 22, 37, 56, 68, 109, 103, 77],
     [24, 35, 55, 64, 81, 104, 113, 92],
     [49, 64, 78, 87, 103, 121, 120, 101],
     [72, 92, 95, 98, 112, 100, 103, 99]],
    dtype=np.float32).T
y_table = nn.Parameter(torch.from_numpy(y_table))
c_table = np.empty((8, 8), dtype=np.float32)
c_table.fill(99)
c_table[:4, :4] = np.array([[17, 18, 24, 47],
                            [18, 21, 26, 66],
                            [24, 26, 56, 99],
                            [47, 66, 99, 99]]).T
c_table = nn.Parameter(torch.from_numpy(c_table))
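
# These are the standard JPEG luminance and chrominance quantization tables
# (ITU-T T.81, Annex K), transposed to match the (x, y) ordering used above.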
def diff_round(x):
    """Differentiable rounding function.

    torch.round has zero gradient everywhere; the cubic term restores a
    nonzero gradient of 3 * (x - round(x))**2 while staying close to round(x).
    """
    return torch.round(x) + (x - torch.round(x))**3
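
# Illustrative gradient check (not part of the original file):
#   x = torch.tensor([2.3], requires_grad=True)
#   diff_round(x).backward()
#   x.grad   # tensor([0.2700]) == 3 * (2.3 - round(2.3))**2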
def quality_to_factor(quality):
    """Calculate the quantization factor corresponding to a JPEG quality.

    Args:
        quality(float): Quality for jpeg compression.
    Returns:
        float: Compression factor.
    """
    if quality < 50:
        quality = 5000. / quality
    else:
        quality = 200. - quality * 2
    return quality / 100.
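
# Worked values (illustrative, not part of the original file): lower quality
# gives a larger factor, i.e. coarser quantization.
#   quality_to_factor(10)   # 5.0
#   quality_to_factor(50)   # 1.0
#   quality_to_factor(95)   # 0.1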
if __name__ == '__main__':

    def uint2single(img):
        # uint8 [0, 255] -> float32 [0., 1.]
        return np.float32(img / 255.)

    def single2uint(img):
        # float32 [0., 1.] -> uint8 [0, 255]
        return np.uint8((img.clip(0, 1) * 255.).round())

    jpeg_range2 = [30, 95]
    img = cv2.imread('../qj.png')
    img = uint2single(img)
    img_jpeg = random_add_jpg_compression(img)
    img_jpeg = single2uint(img_jpeg)
    # img_jpeg = random_add_jpg_compression(img, [30, 95])
    img_jpeg = img_jpeg[:, :, ::-1]  # OpenCV loads BGR; flip to RGB for matplotlib
    plt.imshow(img_jpeg)
    plt.show()