Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
S3Diff
Commits
76b9024b
"mmdet3d/structures/bbox_3d/utils.py" did not exist on "21cb2aa6fb3b6a086d10633b0ba46c4f6e340174"
Commit
76b9024b
authored
Dec 05, 2025
by
yangzhong
Browse files
git init
parents
Pipeline
#3145
failed with stages
in 0 seconds
Changes
281
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
120 additions
and
0 deletions
+120
-0
utils/wavelet_color.py
utils/wavelet_color.py
+120
-0
No files found.
utils/wavelet_color.py
0 → 100644
View file @
76b9024b
'''
# --------------------------------------------------------------------------------
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
# --------------------------------------------------------------------------------
'''
import
torch
from
PIL
import
Image
from
torch
import
Tensor
from
torch.nn
import
functional
as
F
from
torchvision.transforms
import
ToTensor
,
ToPILImage
def
adain_color_fix
(
target
:
Image
,
source
:
Image
):
# Convert images to tensors
to_tensor
=
ToTensor
()
target_tensor
=
to_tensor
(
target
).
unsqueeze
(
0
)
source_tensor
=
to_tensor
(
source
).
unsqueeze
(
0
)
# Apply adaptive instance normalization
result_tensor
=
adaptive_instance_normalization
(
target_tensor
,
source_tensor
)
# Convert tensor back to image
to_image
=
ToPILImage
()
result_image
=
to_image
(
result_tensor
.
squeeze
(
0
).
clamp_
(
0.0
,
1.0
))
return
result_image
def
wavelet_color_fix
(
target
:
Image
,
source
:
Image
):
# Convert images to tensors
to_tensor
=
ToTensor
()
target_tensor
=
to_tensor
(
target
).
unsqueeze
(
0
)
source_tensor
=
to_tensor
(
source
).
unsqueeze
(
0
)
# Apply wavelet reconstruction
result_tensor
=
wavelet_reconstruction
(
target_tensor
,
source_tensor
)
# Convert tensor back to image
to_image
=
ToPILImage
()
result_image
=
to_image
(
result_tensor
.
squeeze
(
0
).
clamp_
(
0.0
,
1.0
))
return
result_image
def
calc_mean_std
(
feat
:
Tensor
,
eps
=
1e-5
):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size
=
feat
.
size
()
assert
len
(
size
)
==
4
,
'The input feature should be 4D tensor.'
b
,
c
=
size
[:
2
]
feat_var
=
feat
.
reshape
(
b
,
c
,
-
1
).
var
(
dim
=
2
)
+
eps
feat_std
=
feat_var
.
sqrt
().
reshape
(
b
,
c
,
1
,
1
)
feat_mean
=
feat
.
reshape
(
b
,
c
,
-
1
).
mean
(
dim
=
2
).
reshape
(
b
,
c
,
1
,
1
)
return
feat_mean
,
feat_std
def
adaptive_instance_normalization
(
content_feat
:
Tensor
,
style_feat
:
Tensor
):
"""Adaptive instance normalization.
Adjust the reference features to have the similar color and illuminations
as those in the degradate features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degradate features.
"""
size
=
content_feat
.
size
()
style_mean
,
style_std
=
calc_mean_std
(
style_feat
)
content_mean
,
content_std
=
calc_mean_std
(
content_feat
)
normalized_feat
=
(
content_feat
-
content_mean
.
expand
(
size
))
/
content_std
.
expand
(
size
)
return
normalized_feat
*
style_std
.
expand
(
size
)
+
style_mean
.
expand
(
size
)
def
wavelet_blur
(
image
:
Tensor
,
radius
:
int
):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals
=
[
[
0.0625
,
0.125
,
0.0625
],
[
0.125
,
0.25
,
0.125
],
[
0.0625
,
0.125
,
0.0625
],
]
kernel
=
torch
.
tensor
(
kernel_vals
,
dtype
=
image
.
dtype
,
device
=
image
.
device
)
# add channel dimensions to the kernel to make it a 4D tensor
kernel
=
kernel
[
None
,
None
]
# repeat the kernel across all input channels
kernel
=
kernel
.
repeat
(
3
,
1
,
1
,
1
)
image
=
F
.
pad
(
image
,
(
radius
,
radius
,
radius
,
radius
),
mode
=
'replicate'
)
# apply convolution
output
=
F
.
conv2d
(
image
,
kernel
,
groups
=
3
,
dilation
=
radius
)
return
output
def
wavelet_decomposition
(
image
:
Tensor
,
levels
=
5
):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq
=
torch
.
zeros_like
(
image
)
for
i
in
range
(
levels
):
radius
=
2
**
i
low_freq
=
wavelet_blur
(
image
,
radius
)
high_freq
+=
(
image
-
low_freq
)
image
=
low_freq
return
high_freq
,
low_freq
def
wavelet_reconstruction
(
content_feat
:
Tensor
,
style_feat
:
Tensor
):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq
,
content_low_freq
=
wavelet_decomposition
(
content_feat
)
del
content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq
,
style_low_freq
=
wavelet_decomposition
(
style_feat
)
del
style_high_freq
# reconstruct the content feature with the style's high frequency
return
content_high_freq
+
style_low_freq
\ No newline at end of file
Prev
1
…
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment