Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
IDM-VTON_pytorch
Commits
d5096d86
Commit
d5096d86
authored
Jun 14, 2024
by
mashun1
Browse files
idmvton
parents
Pipeline
#1220
canceled with stages
Changes
292
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
460 additions
and
0 deletions
+460
-0
BasicSR/tests/data/lq.lmdb/lock.mdb
BasicSR/tests/data/lq.lmdb/lock.mdb
+0
-0
BasicSR/tests/data/lq.lmdb/meta_info.txt
BasicSR/tests/data/lq.lmdb/meta_info.txt
+2
-0
BasicSR/tests/data/meta_info_gt.txt
BasicSR/tests/data/meta_info_gt.txt
+2
-0
BasicSR/tests/data/meta_info_pair.txt
BasicSR/tests/data/meta_info_pair.txt
+2
-0
BasicSR/tests/test_archs/test_basicvsr_arch.py
BasicSR/tests/test_archs/test_basicvsr_arch.py
+41
-0
BasicSR/tests/test_archs/test_discriminator_arch.py
BasicSR/tests/test_archs/test_discriminator_arch.py
+29
-0
BasicSR/tests/test_archs/test_duf_arch.py
BasicSR/tests/test_archs/test_duf_arch.py
+45
-0
BasicSR/tests/test_archs/test_ecbsr_arch.py
BasicSR/tests/test_archs/test_ecbsr_arch.py
+89
-0
BasicSR/tests/test_archs/test_srresnet_arch.py
BasicSR/tests/test_archs/test_srresnet_arch.py
+19
-0
BasicSR/tests/test_data/test_paired_image_dataset.py
BasicSR/tests/test_data/test_paired_image_dataset.py
+97
-0
BasicSR/tests/test_data/test_single_image_dataset.py
BasicSR/tests/test_data/test_single_image_dataset.py
+62
-0
BasicSR/tests/test_losses/test_losses.py
BasicSR/tests/test_losses/test_losses.py
+72
-0
No files found.
Too many changes to show.
To preserve performance only
292 of 292+
files are displayed.
Plain diff
Email patch
BasicSR/tests/data/lq.lmdb/lock.mdb
0 → 100644
View file @
d5096d86
File added
BasicSR/tests/data/lq.lmdb/meta_info.txt
0 → 100644
View file @
d5096d86
baboon.png (120,123,3) 1
comic.png (80,60,3) 1
BasicSR/tests/data/meta_info_gt.txt
0 → 100644
View file @
d5096d86
baboon.png
comic.png
BasicSR/tests/data/meta_info_pair.txt
0 → 100644
View file @
d5096d86
gt/baboon.png, lq/baboon.png
gt/comic.png, lq/comic.png
BasicSR/tests/test_archs/test_basicvsr_arch.py
0 → 100644
View file @
d5096d86
import
torch
from
basicsr.archs.basicvsr_arch
import
BasicVSR
,
ConvResidualBlocks
,
IconVSR
def
test_basicvsr
():
"""Test arch: BasicVSR."""
# model init and forward
net
=
BasicVSR
(
num_feat
=
12
,
num_block
=
2
,
spynet_path
=
None
).
cuda
()
img
=
torch
.
rand
((
1
,
2
,
3
,
64
,
64
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
2
,
3
,
256
,
256
)
def
test_convresidualblocks
():
"""Test block: ConvResidualBlocks."""
# model init and forward
net
=
ConvResidualBlocks
(
num_in_ch
=
3
,
num_out_ch
=
8
,
num_block
=
2
).
cuda
()
img
=
torch
.
rand
((
1
,
3
,
16
,
16
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
8
,
16
,
16
)
def
test_iconvsr
():
"""Test arch: IconVSR."""
# model init and forward
net
=
IconVSR
(
num_feat
=
8
,
num_block
=
1
,
keyframe_stride
=
2
,
temporal_padding
=
2
,
spynet_path
=
None
,
edvr_path
=
None
).
cuda
()
img
=
torch
.
rand
((
1
,
6
,
3
,
64
,
64
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
6
,
3
,
256
,
256
)
# --------------------------- temporal padding 3 ------------------------- #
net
=
IconVSR
(
num_feat
=
8
,
num_block
=
1
,
keyframe_stride
=
2
,
temporal_padding
=
3
,
spynet_path
=
None
,
edvr_path
=
None
).
cuda
()
img
=
torch
.
rand
((
1
,
8
,
3
,
64
,
64
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
8
,
3
,
256
,
256
)
BasicSR/tests/test_archs/test_discriminator_arch.py
0 → 100644
View file @
d5096d86
import
pytest
import
torch
from
basicsr.archs.discriminator_arch
import
VGGStyleDiscriminator
def
test_vggstylediscriminator
():
"""Test arch: VGGStyleDiscriminator."""
# model init and forward
net
=
VGGStyleDiscriminator
(
num_in_ch
=
3
,
num_feat
=
4
).
cuda
()
img
=
torch
.
rand
((
1
,
3
,
128
,
128
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
1
)
# ----------------------- input_size is 256 x 256------------------------ #
net
=
VGGStyleDiscriminator
(
num_in_ch
=
3
,
num_feat
=
4
,
input_size
=
256
).
cuda
()
img
=
torch
.
rand
((
1
,
3
,
256
,
256
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
1
)
# ----------------------- input feature size is not identical to input_size------------------------- #
with
pytest
.
raises
(
AssertionError
):
img
=
torch
.
rand
((
1
,
3
,
128
,
128
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
# ----------------------- input_size is not 128 or 256------------------------- #
with
pytest
.
raises
(
AssertionError
):
net
=
VGGStyleDiscriminator
(
num_in_ch
=
3
,
num_feat
=
4
,
input_size
=
64
)
BasicSR/tests/test_archs/test_duf_arch.py
0 → 100644
View file @
d5096d86
import
pytest
import
torch
from
basicsr.archs.duf_arch
import
DUF
,
DynamicUpsamplingFilter
def
test_duf
():
"""Test arch: DUF."""
# model init and forward
net
=
DUF
(
scale
=
4
,
num_layer
=
16
,
adapt_official_weights
=
False
).
cuda
()
img
=
torch
.
rand
((
1
,
7
,
3
,
48
,
48
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
3
,
192
,
192
)
# ----------------- test scale x3, num_layer=28 ---------------------- #
net
=
DUF
(
scale
=
3
,
num_layer
=
28
,
adapt_official_weights
=
True
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
3
,
144
,
144
)
# ----------------- test scale x2, num_layer=52 ---------------------- #
net
=
DUF
(
scale
=
2
,
num_layer
=
52
,
adapt_official_weights
=
True
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
3
,
96
,
96
)
# ----------------- unsupported num_layers ---------------------- #
with
pytest
.
raises
(
ValueError
):
net
=
DUF
(
scale
=
2
,
num_layer
=
4
,
adapt_official_weights
=
True
)
def
test_dynamicupsamplingfilter
():
"""Test block: DynamicUpsamplingFilter"""
net
=
DynamicUpsamplingFilter
(
filter_size
=
(
3
,
3
)).
cuda
()
img
=
torch
.
rand
((
2
,
3
,
12
,
12
),
dtype
=
torch
.
float32
).
cuda
()
filters
=
torch
.
rand
((
2
,
9
,
2
,
12
,
12
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
,
filters
)
assert
output
.
shape
==
(
2
,
6
,
12
,
12
)
# ----------------- wrong filter_size type ---------------------- #
with
pytest
.
raises
(
TypeError
):
DynamicUpsamplingFilter
(
filter_size
=
4
)
# ----------------- wrong filter_size shape ---------------------- #
with
pytest
.
raises
(
ValueError
):
DynamicUpsamplingFilter
(
filter_size
=
(
3
,
3
,
3
))
BasicSR/tests/test_archs/test_ecbsr_arch.py
0 → 100644
View file @
d5096d86
import
pytest
import
torch
from
basicsr.archs.ecbsr_arch
import
ECB
,
ECBSR
,
SeqConv3x3
def
test_ecbsr
():
"""Test arch: ECBSR."""
# model init and forward
net
=
ECBSR
(
num_in_ch
=
1
,
num_out_ch
=
1
,
num_block
=
1
,
num_channel
=
4
,
with_idt
=
False
,
act_type
=
'prelu'
,
scale
=
4
).
cuda
()
img
=
torch
.
rand
((
1
,
1
,
12
,
12
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
1
,
48
,
48
)
# ----------------- test 3 channels ---------------------- #
net
=
ECBSR
(
num_in_ch
=
3
,
num_out_ch
=
3
,
num_block
=
1
,
num_channel
=
4
,
with_idt
=
True
,
act_type
=
'rrelu'
,
scale
=
2
).
cuda
()
img
=
torch
.
rand
((
1
,
3
,
12
,
12
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
3
,
24
,
24
)
def
test_seqconv3x3
():
"""Test block: SeqConv3x3."""
# model init and forward
net
=
SeqConv3x3
(
seq_type
=
'conv1x1-conv3x3'
,
in_channels
=
2
,
out_channels
=
2
,
depth_multiplier
=
2
).
cuda
()
img
=
torch
.
rand
((
1
,
2
,
12
,
12
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
2
,
12
,
12
)
# test rep_params
conv
=
torch
.
nn
.
Conv2d
(
2
,
2
,
3
,
1
,
1
).
cuda
()
conv
.
weight
.
data
,
conv
.
bias
.
data
=
net
.
rep_params
()
output_rep
=
conv
(
img
)
assert
output_rep
.
shape
==
(
1
,
2
,
12
,
12
)
# whether the two results are close
assert
torch
.
allclose
(
output
,
output_rep
,
rtol
=
1e-5
,
atol
=
1e-5
)
# ----------------- test rep_params with conv1x1-laplacian seq ---------------------- #
net
=
SeqConv3x3
(
seq_type
=
'conv1x1-laplacian'
,
in_channels
=
4
,
out_channels
=
4
,
depth_multiplier
=
3
).
cuda
()
img
=
torch
.
rand
((
1
,
4
,
12
,
12
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
4
,
12
,
12
)
# test rep_params
conv
=
torch
.
nn
.
Conv2d
(
4
,
4
,
3
,
1
,
1
).
cuda
()
conv
.
weight
.
data
,
conv
.
bias
.
data
=
net
.
rep_params
()
output_rep
=
conv
(
img
)
assert
output_rep
.
shape
==
(
1
,
4
,
12
,
12
)
# whether the two results are close
assert
torch
.
allclose
(
output
,
output_rep
,
rtol
=
1e-5
,
atol
=
1e-5
)
# ----------------- unsupported type ---------------------- #
with
pytest
.
raises
(
ValueError
):
SeqConv3x3
(
seq_type
=
'noseq'
,
in_channels
=
1
,
out_channels
=
1
)
def
test_ecb
():
"""Test block: ECB."""
# model init and forward
net
=
ECB
(
in_channels
=
2
,
out_channels
=
2
,
depth_multiplier
=
1
,
act_type
=
'softplus'
,
with_idt
=
False
).
cuda
()
img
=
torch
.
rand
((
1
,
2
,
12
,
12
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
2
,
12
,
12
)
# test rep_params
net
=
net
.
eval
()
output_rep
=
net
(
img
)
assert
output_rep
.
shape
==
(
1
,
2
,
12
,
12
)
# whether the two results are close
assert
torch
.
allclose
(
output
,
output_rep
,
rtol
=
1e-5
,
atol
=
1e-5
)
# ----------------- linear activation function and identity---------------------- #
net
=
ECB
(
in_channels
=
2
,
out_channels
=
2
,
depth_multiplier
=
1
,
act_type
=
'linear'
,
with_idt
=
True
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
2
,
12
,
12
)
# test rep_params
net
=
net
.
eval
()
output_rep
=
net
(
img
)
assert
output_rep
.
shape
==
(
1
,
2
,
12
,
12
)
# whether the two results are close
assert
torch
.
allclose
(
output
,
output_rep
,
rtol
=
1e-5
,
atol
=
1e-5
)
# ----------------- relu activation function---------------------- #
net
=
ECB
(
in_channels
=
2
,
out_channels
=
2
,
depth_multiplier
=
1
,
act_type
=
'relu'
,
with_idt
=
False
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
2
,
12
,
12
)
# ----------------- unsupported type ---------------------- #
with
pytest
.
raises
(
ValueError
):
ECB
(
in_channels
=
2
,
out_channels
=
2
,
depth_multiplier
=
1
,
act_type
=
'unknown'
,
with_idt
=
False
)
BasicSR/tests/test_archs/test_srresnet_arch.py
0 → 100644
View file @
d5096d86
import
torch
from
basicsr.archs.srresnet_arch
import
MSRResNet
def
test_msrresnet
():
"""Test arch: MSRResNet."""
# model init and forward
net
=
MSRResNet
(
num_in_ch
=
3
,
num_out_ch
=
3
,
num_feat
=
12
,
num_block
=
2
,
upscale
=
4
).
cuda
()
img
=
torch
.
rand
((
1
,
3
,
16
,
16
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
3
,
64
,
64
)
# ----------------- the x3 case ---------------------- #
net
=
MSRResNet
(
num_in_ch
=
1
,
num_out_ch
=
1
,
num_feat
=
4
,
num_block
=
1
,
upscale
=
3
).
cuda
()
img
=
torch
.
rand
((
1
,
1
,
16
,
16
),
dtype
=
torch
.
float32
).
cuda
()
output
=
net
(
img
)
assert
output
.
shape
==
(
1
,
1
,
48
,
48
)
BasicSR/tests/test_data/test_paired_image_dataset.py
0 → 100644
View file @
d5096d86
import
yaml
from
basicsr.data.paired_image_dataset
import
PairedImageDataset
def
test_pairedimagedataset
():
"""Test dataset: PairedImageDataset"""
opt_str
=
r
"""
name: Test
type: PairedImageDataset
dataroot_gt: tests/data/gt
dataroot_lq: tests/data/lq
meta_info_file: tests/data/meta_info_gt.txt
filename_tmpl: '{}'
io_backend:
type: disk
scale: 4
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
gt_size: 128
use_hflip: true
use_rot: true
phase: train
"""
opt
=
yaml
.
safe_load
(
opt_str
)
dataset
=
PairedImageDataset
(
opt
)
assert
dataset
.
io_backend_opt
[
'type'
]
==
'disk'
# io backend
assert
len
(
dataset
)
==
2
# whether to read correct meta info
assert
dataset
.
mean
==
[
0.5
,
0.5
,
0.5
]
# test __getitem__
result
=
dataset
.
__getitem__
(
0
)
# check returned keys
expected_keys
=
[
'lq'
,
'gt'
,
'lq_path'
,
'gt_path'
]
assert
set
(
expected_keys
).
issubset
(
set
(
result
.
keys
()))
# check shape and contents
assert
result
[
'gt'
].
shape
==
(
3
,
128
,
128
)
assert
result
[
'lq'
].
shape
==
(
3
,
32
,
32
)
assert
result
[
'lq_path'
]
==
'tests/data/lq/baboon.png'
assert
result
[
'gt_path'
]
==
'tests/data/gt/baboon.png'
# ------------------ test filename_tmpl -------------------- #
opt
.
pop
(
'filename_tmpl'
)
opt
[
'io_backend'
]
=
dict
(
type
=
'disk'
)
dataset
=
PairedImageDataset
(
opt
)
assert
dataset
.
filename_tmpl
==
'{}'
# ------------------ test scan folder mode -------------------- #
opt
.
pop
(
'meta_info_file'
)
opt
[
'io_backend'
]
=
dict
(
type
=
'disk'
)
dataset
=
PairedImageDataset
(
opt
)
assert
dataset
.
io_backend_opt
[
'type'
]
==
'disk'
# io backend
assert
len
(
dataset
)
==
2
# whether to correctly scan folders
# ------------------ test lmdb backend and with y channel-------------------- #
opt
[
'dataroot_gt'
]
=
'tests/data/gt.lmdb'
opt
[
'dataroot_lq'
]
=
'tests/data/lq.lmdb'
opt
[
'io_backend'
]
=
dict
(
type
=
'lmdb'
)
opt
[
'color'
]
=
'y'
opt
[
'mean'
]
=
[
0.5
]
opt
[
'std'
]
=
[
0.5
]
dataset
=
PairedImageDataset
(
opt
)
assert
dataset
.
io_backend_opt
[
'type'
]
==
'lmdb'
# io backend
assert
len
(
dataset
)
==
2
# whether to read correct meta info
assert
dataset
.
std
==
[
0.5
]
# test __getitem__
result
=
dataset
.
__getitem__
(
1
)
# check returned keys
expected_keys
=
[
'lq'
,
'gt'
,
'lq_path'
,
'gt_path'
]
assert
set
(
expected_keys
).
issubset
(
set
(
result
.
keys
()))
# check shape and contents
assert
result
[
'gt'
].
shape
==
(
1
,
128
,
128
)
assert
result
[
'lq'
].
shape
==
(
1
,
32
,
32
)
assert
result
[
'lq_path'
]
==
'comic'
assert
result
[
'gt_path'
]
==
'comic'
# ------------------ test case: val/test mode -------------------- #
opt
[
'phase'
]
=
'test'
opt
[
'io_backend'
]
=
dict
(
type
=
'lmdb'
)
dataset
=
PairedImageDataset
(
opt
)
# test __getitem__
result
=
dataset
.
__getitem__
(
0
)
# check returned keys
expected_keys
=
[
'lq'
,
'gt'
,
'lq_path'
,
'gt_path'
]
assert
set
(
expected_keys
).
issubset
(
set
(
result
.
keys
()))
# check shape and contents
assert
result
[
'gt'
].
shape
==
(
1
,
480
,
492
)
assert
result
[
'lq'
].
shape
==
(
1
,
120
,
123
)
assert
result
[
'lq_path'
]
==
'baboon'
assert
result
[
'gt_path'
]
==
'baboon'
BasicSR/tests/test_data/test_single_image_dataset.py
0 → 100644
View file @
d5096d86
import
yaml
from
basicsr.data.single_image_dataset
import
SingleImageDataset
def
test_singleimagedataset
():
"""Test dataset: SingleImageDataset"""
opt_str
=
r
"""
name: Test
type: SingleImageDataset
dataroot_lq: tests/data/lq
meta_info_file: tests/data/meta_info_gt.txt
io_backend:
type: disk
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
"""
opt
=
yaml
.
safe_load
(
opt_str
)
dataset
=
SingleImageDataset
(
opt
)
assert
dataset
.
io_backend_opt
[
'type'
]
==
'disk'
# io backend
assert
len
(
dataset
)
==
2
# whether to read correct meta info
assert
dataset
.
mean
==
[
0.5
,
0.5
,
0.5
]
# test __getitem__
result
=
dataset
.
__getitem__
(
0
)
# check returned keys
expected_keys
=
[
'lq'
,
'lq_path'
]
assert
set
(
expected_keys
).
issubset
(
set
(
result
.
keys
()))
# check shape and contents
assert
result
[
'lq'
].
shape
==
(
3
,
120
,
123
)
assert
result
[
'lq_path'
]
==
'tests/data/lq/baboon.png'
# ------------------ test scan folder mode -------------------- #
opt
.
pop
(
'meta_info_file'
)
opt
[
'io_backend'
]
=
dict
(
type
=
'disk'
)
dataset
=
SingleImageDataset
(
opt
)
assert
dataset
.
io_backend_opt
[
'type'
]
==
'disk'
# io backend
assert
len
(
dataset
)
==
2
# whether to correctly scan folders
# ------------------ test lmdb backend and with y channel-------------------- #
opt
[
'dataroot_lq'
]
=
'tests/data/lq.lmdb'
opt
[
'io_backend'
]
=
dict
(
type
=
'lmdb'
)
opt
[
'color'
]
=
'y'
opt
[
'mean'
]
=
[
0.5
]
opt
[
'std'
]
=
[
0.5
]
dataset
=
SingleImageDataset
(
opt
)
assert
dataset
.
io_backend_opt
[
'type'
]
==
'lmdb'
# io backend
assert
len
(
dataset
)
==
2
# whether to read correct meta info
assert
dataset
.
std
==
[
0.5
]
# test __getitem__
result
=
dataset
.
__getitem__
(
1
)
# check returned keys
expected_keys
=
[
'lq'
,
'lq_path'
]
assert
set
(
expected_keys
).
issubset
(
set
(
result
.
keys
()))
# check shape and contents
assert
result
[
'lq'
].
shape
==
(
1
,
90
,
60
)
assert
result
[
'lq_path'
]
==
'comic'
BasicSR/tests/test_losses/test_losses.py
0 → 100644
View file @
d5096d86
import
pytest
import
torch
from
basicsr.losses.basic_loss
import
CharbonnierLoss
,
L1Loss
,
MSELoss
,
WeightedTVLoss
@
pytest
.
mark
.
parametrize
(
'loss_class'
,
[
L1Loss
,
MSELoss
,
CharbonnierLoss
])
def
test_pixellosses
(
loss_class
):
"""Test loss: pixel losses"""
pred
=
torch
.
rand
((
1
,
3
,
4
,
4
),
dtype
=
torch
.
float32
)
target
=
torch
.
rand
((
1
,
3
,
4
,
4
),
dtype
=
torch
.
float32
)
loss
=
loss_class
(
loss_weight
=
1.0
,
reduction
=
'mean'
)
out
=
loss
(
pred
,
target
,
weight
=
None
)
assert
isinstance
(
out
,
torch
.
Tensor
)
assert
out
.
shape
==
torch
.
Size
([])
# -------------------- test with other reduction -------------------- #
# reduction = none
loss
=
loss_class
(
loss_weight
=
1.0
,
reduction
=
'none'
)
out
=
loss
(
pred
,
target
,
weight
=
None
)
assert
isinstance
(
out
,
torch
.
Tensor
)
assert
out
.
shape
==
(
1
,
3
,
4
,
4
)
# test with spatial weights
weight
=
torch
.
rand
((
1
,
3
,
4
,
4
),
dtype
=
torch
.
float32
)
out
=
loss
(
pred
,
target
,
weight
=
weight
)
assert
isinstance
(
out
,
torch
.
Tensor
)
assert
out
.
shape
==
(
1
,
3
,
4
,
4
)
# reduction = sum
loss
=
loss_class
(
loss_weight
=
1.0
,
reduction
=
'sum'
)
out
=
loss
(
pred
,
target
,
weight
=
None
)
assert
isinstance
(
out
,
torch
.
Tensor
)
assert
out
.
shape
==
torch
.
Size
([])
# -------------------- test unsupported loss reduction -------------------- #
with
pytest
.
raises
(
ValueError
):
loss_class
(
loss_weight
=
1.0
,
reduction
=
'unknown'
)
def
test_weightedtvloss
():
"""Test loss: WeightedTVLoss"""
pred
=
torch
.
rand
((
1
,
3
,
4
,
4
),
dtype
=
torch
.
float32
)
loss
=
WeightedTVLoss
(
loss_weight
=
1.0
,
reduction
=
'mean'
)
out
=
loss
(
pred
,
weight
=
None
)
assert
isinstance
(
out
,
torch
.
Tensor
)
assert
out
.
shape
==
torch
.
Size
([])
# test with spatial weights
weight
=
torch
.
rand
((
1
,
3
,
4
,
4
),
dtype
=
torch
.
float32
)
out
=
loss
(
pred
,
weight
=
weight
)
assert
isinstance
(
out
,
torch
.
Tensor
)
assert
out
.
shape
==
torch
.
Size
([])
# -------------------- test reduction = sum-------------------- #
loss
=
WeightedTVLoss
(
loss_weight
=
1.0
,
reduction
=
'sum'
)
out
=
loss
(
pred
,
weight
=
None
)
assert
isinstance
(
out
,
torch
.
Tensor
)
assert
out
.
shape
==
torch
.
Size
([])
# test with spatial weights
weight
=
torch
.
rand
((
1
,
3
,
4
,
4
),
dtype
=
torch
.
float32
)
out
=
loss
(
pred
,
weight
=
weight
)
assert
isinstance
(
out
,
torch
.
Tensor
)
assert
out
.
shape
==
torch
.
Size
([])
# -------------------- test unsupported loss reduction -------------------- #
with
pytest
.
raises
(
ValueError
):
WeightedTVLoss
(
loss_weight
=
1.0
,
reduction
=
'unknown'
)
with
pytest
.
raises
(
ValueError
):
WeightedTVLoss
(
loss_weight
=
1.0
,
reduction
=
'none'
)
Prev
1
…
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment