OpenDAS / dcnv3 · Commits

Commit 41b18fd8, authored Jan 06, 2025 by zhe chen
Parent: ff20ea39

    Use pre-commit to reformat code

Changes: 390 (page 18 of 20; showing 10 changed files with 63 additions and 72 deletions)
Files changed on this page:

  segmentation/mmseg_custom/models/utils/transformer.py   +7   -8
  segmentation/ops_dcnv3/functions/dcnv3_func.py           +3   -4
  segmentation/ops_dcnv3/modules/__init__.py               +1   -1
  segmentation/ops_dcnv3/modules/dcnv3.py                  +18  -16
  segmentation/ops_dcnv3/setup.py                          +20  -25
  segmentation/ops_dcnv3/src/cuda/dcnv3_cuda.cu            +1   -1
  segmentation/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh    +1   -1
  segmentation/ops_dcnv3/test.py                           +4   -6
  segmentation/test.py                                     +4   -5
  segmentation/train.py                                    +4   -5

With whitespace-only changes hidden, several hunks below render as unmarked context lines.
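The hunks below are classic pre-commit output: __future__ imports merged onto one line, import blocks sorted and grouped (isort style), double quotes normalized to single quotes, and missing end-of-file newlines added. The repository's actual .pre-commit-config.yaml is not shown on this page, so the exact hook list is an assumption; as a sketch, the isort library alone reproduces the import reordering seen in these hunks:

    # Hedged sketch: assumes an isort-style hook (the hook list is not on this page).
    # isort.code() merges and sorts imports the same way these hunks do.
    import isort

    before = ('from __future__ import absolute_import\n'
              'from __future__ import print_function\n'
              'from __future__ import division\n'
              'import torch\n'
              'import DCNv3\n')
    print(isort.code(before))
    # from __future__ import absolute_import, division, print_function
    #
    # import DCNv3
    # import torch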
segmentation/mmseg_custom/models/utils/transformer.py (view file @ 41b18fd8)
@@ -9,17 +9,16 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as cp
 from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                       build_norm_layer, xavier_init)
-from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
-                                      TRANSFORMER_LAYER_SEQUENCE,
-                                      FEEDFORWARD_NETWORK)
 from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.registry import (FEEDFORWARD_NETWORK, TRANSFORMER_LAYER,
+                                      TRANSFORMER_LAYER_SEQUENCE)
 from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
                                          TransformerLayerSequence,
-                                         build_transformer_layer_sequence,
                                          build_attention,
-                                         build_feedforward_network)
+                                         build_feedforward_network,
+                                         build_transformer_layer_sequence)
 from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
-from mmcv.utils import to_2tuple, ConfigDict, deprecated_api_warning
+from mmcv.utils import ConfigDict, deprecated_api_warning, to_2tuple
 from torch.nn.init import normal_
 from ..builder import TRANSFORMER
@@ -319,12 +318,12 @@ class FFN(BaseModule):
         """Forward function for `FFN`.
         The function would add x to the output tensor if residue is None.
         """
         if self.with_cp and x.requires_grad:
             out = cp.checkpoint(self.layers, x)
         else:
             out = self.layers(x)
         if not self.add_identity:
             return self.dropout_layer(out)
         if identity is None:
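Note on the with_cp branch in the hunk above: torch.utils.checkpoint trades compute for memory by recomputing self.layers during the backward pass instead of caching its activations, which is why the branch is only taken when x.requires_grad. A minimal self-contained sketch of the same pattern (TinyFFN and its sizes are illustrative, not from this repository):

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint as cp

    class TinyFFN(nn.Module):
        def __init__(self, dim=16, with_cp=True):
            super().__init__()
            self.layers = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(),
                                        nn.Linear(dim * 4, dim))
            self.with_cp = with_cp

        def forward(self, x):
            # Recompute-in-backward only helps (and only applies) when grads flow.
            if self.with_cp and x.requires_grad:
                out = cp.checkpoint(self.layers, x)
            else:
                out = self.layers(x)
            return x + out  # residual add, as in FFN

    x = torch.randn(2, 16, requires_grad=True)
    TinyFFN()(x).sum().backward()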
segmentation/ops_dcnv3/functions/dcnv3_func.py (view file @ 41b18fd8)
@@ -4,16 +4,14 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import division
+from __future__ import absolute_import, division, print_function

+import DCNv3
 import torch
 import torch.nn.functional as F
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.cuda.amp import custom_bwd, custom_fwd
-import DCNv3


 class DCNv3Function(Function):
@@ -88,6 +86,7 @@ class DCNv3Function(Function):
             im2col_step_i=int(im2col_step),
         )

+
 def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
     _, H_, W_, _ = spatial_shapes
     H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
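The H_out line above is the standard convolution output-size formula with the padding contribution left to the caller (pad_h defaults to 0): the effective kernel extent under dilation is dilation_h * (kernel_h - 1) + 1, and the floor division counts how many stride steps fit. A quick numeric check with illustrative values:

    # Illustrative check of the output-size arithmetic (values are made up):
    H_, kernel_h, dilation_h, stride_h = 8, 3, 2, 1
    effective_kernel = dilation_h * (kernel_h - 1) + 1   # 2 * (3 - 1) + 1 = 5
    H_out = (H_ - effective_kernel) // stride_h + 1      # (8 - 5) // 1 + 1 = 4
    assert H_out == 4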
segmentation/ops_dcnv3/modules/__init__.py (view file @ 41b18fd8)
@@ -4,4 +4,4 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from .dcnv3 import DCNv3, DCNv3_pytorch
\ No newline at end of file
+from .dcnv3 import DCNv3, DCNv3_pytorch
segmentation/ops_dcnv3/modules/dcnv3.py (view file @ 41b18fd8)
@@ -4,22 +4,24 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import division
+from __future__ import absolute_import, division, print_function

 import warnings

 import torch
-from torch import nn
 import torch.nn.functional as F
-from torch.nn.init import xavier_uniform_, constant_
+from torch import nn
+from torch.nn.init import constant_, xavier_uniform_

 from ..functions import DCNv3Function, dcnv3_core_pytorch

 try:
     from DCNv4.functions import DCNv4Function
 except:
     warnings.warn('Now, we support DCNv4 in InternImage.')
 import math


 class to_channels_first(nn.Module):

     def __init__(self):
@@ -76,7 +78,7 @@ def build_act_layer(act_layer):
 def _is_power_of_2(n):
     if (not isinstance(n, int)) or (n < 0):
         raise ValueError(
-            "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+            'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))
     return (n & (n - 1) == 0) and n != 0
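Aside on the return line in this hunk: for a positive integer n, clearing the lowest set bit via n & (n - 1) yields zero exactly when n has a single set bit, i.e. is a power of two; the n != 0 guard excludes zero, which would otherwise pass the mask test. Checked on a few values:

    def is_power_of_2(n: int) -> bool:
        # Same bit trick as _is_power_of_2 above, minus the type check.
        return (n & (n - 1) == 0) and n != 0

    assert [is_power_of_2(n) for n in (0, 1, 2, 3, 16, 24)] == \
           [False, True, True, False, True, False]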
@@ -128,7 +130,7 @@ class DCNv3_pytorch(nn.Module):
         if not _is_power_of_2(_d_per_group):
             warnings.warn(
                 "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
-                "which is more efficient in our CUDA implementation.")
+                'which is more efficient in our CUDA implementation.')
         self.offset_scale = offset_scale
         self.channels = channels
@@ -165,7 +167,7 @@ class DCNv3_pytorch(nn.Module):
         self.input_proj = nn.Linear(channels, channels)
         self.output_proj = nn.Linear(channels, channels)
         self._reset_parameters()

         if center_feature_scale:
             self.center_feature_scale_proj_weight = nn.Parameter(
                 torch.zeros((group, channels), dtype=torch.float))
@@ -234,7 +236,7 @@ class DCNv3(nn.Module):
                  norm_layer='LN',
                  center_feature_scale=False,
                  use_dcn_v4_op=False,
                  ):
         """
         DCNv3 Module
         :param channels
@@ -257,7 +259,7 @@ class DCNv3(nn.Module):
         if not _is_power_of_2(_d_per_group):
             warnings.warn(
                 "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
-                "which is more efficient in our CUDA implementation.")
+                'which is more efficient in our CUDA implementation.')
         self.offset_scale = offset_scale
         self.channels = channels
@@ -270,7 +272,7 @@ class DCNv3(nn.Module):
         self.group_channels = channels // group
         self.offset_scale = offset_scale
         self.center_feature_scale = center_feature_scale
         self.use_dcn_v4_op = use_dcn_v4_op
         self.dw_conv = nn.Sequential(
@@ -296,7 +298,7 @@ class DCNv3(nn.Module):
         self.input_proj = nn.Linear(channels, channels)
         self.output_proj = nn.Linear(channels, channels)
         self._reset_parameters()

         if center_feature_scale:
             self.center_feature_scale_proj_weight = nn.Parameter(
                 torch.zeros((group, channels), dtype=torch.float))
@@ -329,7 +331,7 @@ class DCNv3(nn.Module):
         x1 = self.dw_conv(x1)
         offset = self.offset(x1)
         mask = self.mask(x1).reshape(N, H, W, self.group, -1)
         if not self.use_dcn_v4_op:
             mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype)
             x = DCNv3Function.apply(
@@ -349,12 +351,12 @@ class DCNv3(nn.Module):
             mask = mask.view(N, H, W, self.group, -1)
             offset_mask = torch.cat([offset, mask], -1).view(N, H, W, -1).contiguous()
             # For efficiency, the last dimension of the offset_mask tensor in dcnv4 is a multiple of 8.
             K3 = offset_mask.size(-1)
             K3_pad = int(math.ceil(K3 / 8) * 8)
             pad_dim = K3_pad - K3
             offset_mask = torch.cat([offset_mask, offset_mask.new_zeros([*offset_mask.size()[:3], pad_dim])], -1)
             x = DCNv4Function.apply(
                 x, offset_mask,
                 self.kernel_size, self.kernel_size,
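The padding logic in this hunk rounds the last dimension of offset_mask up to a multiple of 8, which the in-code comment attributes to DCNv4 efficiency (aligned memory access). A standalone sketch of the same step, with made-up tensor sizes:

    import math
    import torch

    offset_mask = torch.randn(2, 4, 4, 27)   # N, H, W, K3; K3 = 27 is illustrative
    K3 = offset_mask.size(-1)
    K3_pad = int(math.ceil(K3 / 8) * 8)      # 27 -> 32
    pad = offset_mask.new_zeros([*offset_mask.size()[:3], K3_pad - K3])
    offset_mask = torch.cat([offset_mask, pad], -1)
    assert offset_mask.size(-1) == 32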
segmentation/ops_dcnv3/setup.py (view file @ 41b18fd8)
@@ -4,39 +4,34 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-import os
 import glob
+import os

 import torch
+from setuptools import find_packages, setup
+from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

-from torch.utils.cpp_extension import CUDA_HOME
-from torch.utils.cpp_extension import CppExtension
-from torch.utils.cpp_extension import CUDAExtension
-
-from setuptools import find_packages
-from setuptools import setup
-
-requirements = ["torch", "torchvision"]
+requirements = ['torch', 'torchvision']


 def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
-    extensions_dir = os.path.join(this_dir, "src")
+    extensions_dir = os.path.join(this_dir, 'src')

-    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
-    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
-    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
+    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
+    source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))

     sources = main_file + source_cpu
     extension = CppExtension
-    extra_compile_args = {"cxx": []}
+    extra_compile_args = {'cxx': []}
     define_macros = []

     if torch.cuda.is_available() and CUDA_HOME is not None:
         extension = CUDAExtension
         sources += source_cuda
-        define_macros += [("WITH_CUDA", None)]
-        extra_compile_args["nvcc"] = [
+        define_macros += [('WITH_CUDA', None)]
+        extra_compile_args['nvcc'] = [
             # "-DCUDA_HAS_FP16=1",
             # "-D__CUDA_NO_HALF_OPERATORS__",
             # "-D__CUDA_NO_HALF_CONVERSIONS__",
@@ -49,7 +44,7 @@ def get_extensions():
     include_dirs = [extensions_dir]
     ext_modules = [
         extension(
-            "DCNv3",
+            'DCNv3',
             sources,
             include_dirs=include_dirs,
             define_macros=define_macros,
@@ -60,16 +55,16 @@ def get_extensions():
 setup(
-    name="DCNv3",
-    version="1.0",
-    author="InternImage",
-    url="https://github.com/OpenGVLab/InternImage",
+    name='DCNv3',
+    version='1.0',
+    author='InternImage',
+    url='https://github.com/OpenGVLab/InternImage',
     description=
-    "PyTorch Wrapper for CUDA Functions of DCNv3",
+    'PyTorch Wrapper for CUDA Functions of DCNv3',
     packages=find_packages(exclude=(
-        "configs",
-        "tests",
+        'configs',
+        'tests',
     )),
     ext_modules=get_extensions(),
-    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
 )
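For context: once this setup.py is built (for example with `python setup.py build_ext --inplace` or `pip install .` inside segmentation/ops_dcnv3/), the compiled operator becomes importable under the extension name 'DCNv3' declared above, which is what the `import DCNv3` in dcnv3_func.py resolves to. A hedged availability check, not code from the repository:

    import torch

    try:
        import DCNv3  # the CUDAExtension/CppExtension defined in setup.py above
    except ImportError:
        DCNv3 = None  # extension not built on this machine

    if DCNv3 is None or not torch.cuda.is_available():
        # The repo also ships a pure-PyTorch fallback, DCNv3_pytorch
        # (exported from ops_dcnv3/modules/__init__.py above).
        print('CUDA DCNv3 op unavailable; use the DCNv3_pytorch path.')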
segmentation/ops_dcnv3/src/cuda/dcnv3_cuda.cu (view file @ 41b18fd8)
@@ -171,4 +171,4 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
     } else {
         return {grad_input, grad_offset, grad_mask};
     }
-}
\ No newline at end of file
+}
segmentation/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh (view file @ 41b18fd8)
@@ -1042,4 +1042,4 @@ void dcnv3_col2im_cuda(
     if (err != cudaSuccess) {
         printf("error in dcnv3_col2im_cuda: %s\n", cudaGetErrorString(err));
     }
-}
\ No newline at end of file
+}
segmentation/ops_dcnv3/test.py (view file @ 41b18fd8)
@@ -4,17 +4,15 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import division
+from __future__ import absolute_import, division, print_function

+import math
 import time

 import torch
 import torch.nn as nn
-import math
-from torch.autograd import gradcheck
 from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
+from torch.autograd import gradcheck

 H_in, W_in = 8, 8
 N, M, D = 2, 4, 16
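For context on the imports in this test file: torch.autograd.gradcheck validates an op's analytic backward against finite differences, which is the usual way a custom autograd Function such as DCNv3Function is tested; it requires double-precision inputs with requires_grad set. A generic sketch (a plain Linear stands in, since DCNv3Function's full argument list is not visible on this page):

    import torch
    from torch.autograd import gradcheck

    op = torch.nn.Linear(4, 4).double()
    x = torch.randn(2, 4, dtype=torch.double, requires_grad=True)
    assert gradcheck(op, (x,), eps=1e-6, atol=1e-4)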
segmentation/test.py (view file @ 41b18fd8)
@@ -12,18 +12,17 @@ import time
 import warnings

 import mmcv
+import mmcv_custom  # noqa: F401,F403
+import mmseg_custom  # noqa: F401,F403
 import torch
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
-                         wrap_fp16_model, load_state_dict)
+                         load_state_dict, wrap_fp16_model)
 from mmcv.utils import DictAction
 from mmseg.apis import multi_gpu_test, single_gpu_test
 from mmseg.datasets import build_dataloader, build_dataset
 from mmseg.models import build_segmentor

-import mmcv_custom  # noqa: F401,F403
-import mmseg_custom  # noqa: F401,F403

 def parse_args():
     parser = argparse.ArgumentParser(
@@ -197,7 +196,7 @@ def main():
         load_state_dict(model.module, checkpoint['state_dict'], strict=False)
     else:
         load_state_dict(model, checkpoint['state_dict'], strict=False)
     if 'CLASSES' in checkpoint.get('meta', {}):
         model.CLASSES = checkpoint['meta']['CLASSES']
     else:
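Note on the strict=False calls in the hunk above: non-strict loading tolerates missing or unexpected keys between the checkpoint and the model (mmcv's load_state_dict logs the mismatches rather than raising). The plain-PyTorch analogue:

    import torch

    model = torch.nn.Linear(4, 4)
    ckpt = {'weight': torch.zeros(4, 4)}                 # 'bias' deliberately absent
    result = model.load_state_dict(ckpt, strict=False)   # would raise with strict=True
    print(result.missing_keys)                           # ['bias']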
segmentation/train.py (view file @ 41b18fd8)
@@ -12,20 +12,19 @@ import time
 import warnings

 import mmcv
+import mmcv_custom  # noqa: F401,F403
+import mmseg_custom  # noqa: F401,F403
 import torch
 import torch.distributed as dist
 from mmcv.cnn.utils import revert_sync_batchnorm
 from mmcv.runner import get_dist_info, init_dist
 from mmcv.utils import Config, DictAction, get_git_hash
 from mmseg import __version__
 from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
 from mmseg.datasets import build_dataset
 from mmseg.models import build_segmentor
 from mmseg.utils import (collect_env, get_device, get_root_logger,
                          setup_multi_processes)

-import mmcv_custom  # noqa: F401,F403
-import mmseg_custom  # noqa: F401,F403

 def parse_args():
@@ -231,10 +230,10 @@ def main():
     model.CLASSES = datasets[0].CLASSES
     # passing checkpoint meta for saving best checkpoint
     meta.update(cfg.checkpoint_config.meta)
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
     train_segmentor(model,
                     datasets,
                     cfg,
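Note on the allow_tf32 lines in the hunk above: these flags opt matmul and cuDNN convolution kernels into TF32 on Ampere-and-newer GPUs, keeping FP32 dynamic range while reducing mantissa precision in exchange for substantially faster training:

    import torch

    torch.backends.cuda.matmul.allow_tf32 = True  # matmuls may run in TF32
    torch.backends.cudnn.allow_tf32 = True        # cuDNN convolutions may run in TF32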