Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
ef3db59a
"plugins/vscode:/vscode.git/clone" did not exist on "6717a85cc4255c7e3985d5cfdb022fbc9106830f"
Commit
ef3db59a
authored
Dec 01, 2021
by
yan.yan
Browse files
fix serious bug in weight init
parent
8aa0f1f7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
67 deletions
+63
-67
CHANGELOG.md
CHANGELOG.md
+1
-0
README.md
README.md
+1
-1
spconv/pytorch/conv.py
spconv/pytorch/conv.py
+61
-66
No files found.
CHANGELOG.md
View file @
ef3db59a
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
## [2.1.17] - 2021-11-29
## [2.1.17] - 2021-11-29
### Fixed
### Fixed
-
Fix a bug in sparse add.
-
Fix a bug in sparse add.
-
Fix a serious bug in conv weight init.
### Added
### Added
-
Add more wrong-usage checks
-
Add more wrong-usage checks
-
Add insert_exist_keys for hash table
-
Add insert_exist_keys for hash table
...
...
README.md
View file @
ef3db59a
...
@@ -48,7 +48,7 @@
...
@@ -48,7 +48,7 @@
Check [spconv 2.x algorithm introduction](docs/spconv2_algo.pdf) to understand sparse convolution algorithm in spconv 2.x!
Check [spconv 2.x algorithm introduction](docs/spconv2_algo.pdf) to understand sparse convolution algorithm in spconv 2.x!
**WARNING** Users of spconv < 2.1.4 need to upgrade to 2.1.4: it fixes a serious bug in SparseInverseConvXd.
**WARNING** Users of spconv < 2.1.17 need to upgrade to 2.1.17: it fixes a bug in conv weight init that causes the std of the initialized weights to be too large.
## Breaking changes in Spconv 2.x
## Breaking changes in Spconv 2.x
...
...
spconv/pytorch/conv.py
View file @
ef3db59a
...
@@ -33,41 +33,9 @@ from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndice
...
@@ -33,41 +33,9 @@ from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndice
from
spconv.pytorch.modules
import
SparseModule
from
spconv.pytorch.modules
import
SparseModule
from
spconv.constants
import
FILTER_HWIO
from
spconv.constants
import
FILTER_HWIO
from
spconv.utils
import
nullcontext
from
spconv.utils
import
nullcontext
from
torch.nn.init
import
calculate_gain
def
_calculate_fan_in_and_fan_out_hwio
(
tensor
,
algo
:
ConvAlgo
):
dimensions
=
tensor
.
ndimension
()
if
dimensions
<
2
:
raise
ValueError
(
"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
)
if
dimensions
==
2
:
# Linear
fan_in
=
tensor
.
size
(
-
2
)
fan_out
=
tensor
.
size
(
-
1
)
else
:
if
algo
==
ConvAlgo
.
Native
:
if
FILTER_HWIO
:
num_input_fmaps
=
tensor
.
size
(
-
2
)
num_output_fmaps
=
tensor
.
size
(
-
1
)
else
:
num_input_fmaps
=
tensor
.
size
(
-
1
)
num_output_fmaps
=
tensor
.
size
(
-
2
)
receptive_field_size
=
1
if
tensor
.
dim
()
>
2
:
receptive_field_size
=
tensor
[...,
0
,
0
].
numel
()
else
:
num_input_fmaps
=
tensor
.
size
(
-
1
)
num_output_fmaps
=
tensor
.
size
(
0
)
receptive_field_size
=
1
if
tensor
.
dim
()
>
2
:
receptive_field_size
=
int
(
np
.
prod
(
tensor
.
shape
[
1
:
-
1
]))
fan_in
=
num_input_fmaps
*
receptive_field_size
fan_out
=
num_output_fmaps
*
receptive_field_size
return
fan_in
,
fan_out
class
SparseConvolution
(
SparseModule
):
class
SparseConvolution
(
SparseModule
):
...
@@ -99,15 +67,18 @@ class SparseConvolution(SparseModule):
...
@@ -99,15 +67,18 @@ class SparseConvolution(SparseModule):
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
kernel_size
=
expand_nd
(
ndim
,
kernel_size
)
self
.
kernel_size
=
expand_nd
(
ndim
,
kernel_size
)
kv
=
int
(
np
.
prod
(
kernel_size
))
self
.
stride
=
expand_nd
(
ndim
,
stride
)
kv_stride
=
int
(
np
.
prod
(
stride
))
kv
=
int
(
np
.
prod
(
self
.
kernel_size
))
kv_stride
=
int
(
np
.
prod
(
self
.
stride
))
self
.
dilation
=
expand_nd
(
ndim
,
dilation
)
self
.
padding
=
expand_nd
(
ndim
,
padding
)
self
.
conv1x1
=
kv
==
1
self
.
conv1x1
=
kv
==
1
# TODO we should deprecate support for ksize == 1 but stride != 1.
# TODO we should deprecate support for ksize == 1 but stride != 1.
if
not
subm
:
if
not
subm
:
self
.
conv1x1
&=
kv_stride
==
1
self
.
conv1x1
&=
kv_stride
==
1
self
.
stride
=
expand_nd
(
ndim
,
stride
)
if
self
.
conv1x1
:
self
.
padding
=
expand_nd
(
ndim
,
padding
)
assert
self
.
padding
==
[
0
]
*
ndim
,
"padding must be zero for 1x1 conv (k=1,s=1)"
self
.
dilation
=
expand_nd
(
ndim
,
dilation
)
self
.
transposed
=
transposed
self
.
transposed
=
transposed
self
.
inverse
=
inverse
self
.
inverse
=
inverse
self
.
output_padding
=
expand_nd
(
ndim
,
output_padding
)
self
.
output_padding
=
expand_nd
(
ndim
,
output_padding
)
...
@@ -165,20 +136,39 @@ class SparseConvolution(SparseModule):
...
@@ -165,20 +136,39 @@ class SparseConvolution(SparseModule):
s
+=
f
', algo=
{
self
.
algo
}
'
s
+=
f
', algo=
{
self
.
algo
}
'
return
s
.
format
(
**
self
.
__dict__
)
return
s
.
format
(
**
self
.
__dict__
)
def
_calculate_fan_in_and_fan_out
(
self
):
receptive_field_size
=
1
# math.prod is not always available, accumulate the product manually
# we could use functools.reduce but that is not supported by TorchScript
for
s
in
self
.
kernel_size
:
receptive_field_size
*=
s
fan_in
=
self
.
in_channels
*
receptive_field_size
fan_out
=
self
.
out_channels
*
receptive_field_size
return
fan_in
,
fan_out
def
_calculate_correct_fan
(
self
,
mode
):
mode
=
mode
.
lower
()
valid_modes
=
[
'fan_in'
,
'fan_out'
]
if
mode
not
in
valid_modes
:
raise
ValueError
(
"Mode {} not supported, please use one of {}"
.
format
(
mode
,
valid_modes
))
fan_in
,
fan_out
=
self
.
_calculate_fan_in_and_fan_out
()
return
fan_in
if
mode
==
'fan_in'
else
fan_out
def
_custom_kaiming_uniform_
(
self
,
tensor
,
a
=
0
,
mode
=
'fan_in'
,
nonlinearity
=
'leaky_relu'
):
r
"""same as torch.init.kaiming_uniform_, with KRSC layout support
"""
fan
=
self
.
_calculate_correct_fan
(
mode
)
gain
=
calculate_gain
(
nonlinearity
,
a
)
std
=
gain
/
math
.
sqrt
(
fan
)
bound
=
math
.
sqrt
(
3.0
)
*
std
# Calculate uniform bounds from standard deviation
with
torch
.
no_grad
():
return
tensor
.
uniform_
(
-
bound
,
bound
)
def
reset_parameters
(
self
):
def
reset_parameters
(
self
):
n
=
self
.
in_channels
self
.
_custom_kaiming_uniform_
(
self
.
weight
,
a
=
math
.
sqrt
(
5
))
# following commented code is used to make weight different layout have same value
# if self.algo != ConvAlgo.Native:
# weight2 = self.weight.data.permute(1, 2, 3, 0,
# 4).contiguous().clone()
# init.uniform_(weight2, 0, 0.001)
# self.weight.data[:] = weight2.permute(3, 0, 1, 2, 4)
# else:
# init.uniform_(self.weight, 0, 0.001)
init
.
kaiming_uniform_
(
self
.
weight
,
a
=
math
.
sqrt
(
0.005
))
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
fan_in
,
_
=
_calculate_fan_in_and_fan_out_hwio
(
fan_in
,
_
=
self
.
_calculate_fan_in_and_fan_out
()
self
.
weight
,
self
.
algo
)
bound
=
1
/
math
.
sqrt
(
fan_in
)
bound
=
1
/
math
.
sqrt
(
fan_in
)
init
.
uniform_
(
self
.
bias
,
-
bound
,
bound
)
init
.
uniform_
(
self
.
bias
,
-
bound
,
bound
)
...
@@ -271,14 +261,14 @@ class SparseConvolution(SparseModule):
...
@@ -271,14 +261,14 @@ class SparseConvolution(SparseModule):
indice_pairs
=
datas
.
indice_pairs
indice_pairs
=
datas
.
indice_pairs
indice_pair_num
=
datas
.
indice_pair_num
indice_pair_num
=
datas
.
indice_pair_num
out_spatial_shape
=
datas
.
spatial_shape
out_spatial_shape
=
datas
.
spatial_shape
assert
indice_pair_num
.
shape
[
0
]
==
np
.
prod
(
assert
datas
.
ksize
==
self
.
kernel_size
,
"inverse conv must have same kernel size as its couple conv"
self
.
kernel_size
),
"inverse conv must have same kernel size as its couple conv"
else
:
else
:
if
self
.
indice_key
is
not
None
and
datas
is
not
None
:
if
self
.
indice_key
is
not
None
and
datas
is
not
None
:
outids
=
datas
.
out_indices
outids
=
datas
.
out_indices
indice_pairs
=
datas
.
indice_pairs
indice_pairs
=
datas
.
indice_pairs
indice_pair_num
=
datas
.
indice_pair_num
indice_pair_num
=
datas
.
indice_pair_num
assert
self
.
subm
,
"only support reuse subm indices"
self
.
_check_subm_reuse_valid
(
input
,
spatial_shape
,
datas
)
else
:
else
:
if
input
.
benchmark
:
if
input
.
benchmark
:
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -369,19 +359,8 @@ class SparseConvolution(SparseModule):
...
@@ -369,19 +359,8 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits
=
datas
.
mask_argsort_fwd_splits
mask_argsort_fwd_splits
=
datas
.
mask_argsort_fwd_splits
mask_argsort_bwd_splits
=
datas
.
mask_argsort_bwd_splits
mask_argsort_bwd_splits
=
datas
.
mask_argsort_bwd_splits
masks
=
datas
.
masks
masks
=
datas
.
masks
assert
datas
.
is_subm
,
"only support reuse subm indices"
assert
self
.
subm
,
"only support reuse subm indices"
if
self
.
kernel_size
!=
datas
.
ksize
:
self
.
_check_subm_reuse_valid
(
input
,
spatial_shape
,
datas
)
raise
ValueError
(
f
"subm with same indice_key must have same kernel"
f
" size, expect
{
datas
.
ksize
}
, this layer
{
self
.
kernel_size
}
"
)
if
self
.
dilation
!=
datas
.
dilation
:
raise
ValueError
(
f
"subm with same indice_key must have same dilation"
f
", expect
{
datas
.
dilation
}
, this layer
{
self
.
dilation
}
"
)
if
input
.
spatial_shape
!=
datas
.
spatial_shape
:
raise
ValueError
(
f
"subm with same indice_key must have same spatial structure"
f
", expect
{
datas
.
spatial_shape
}
, input
{
spatial_shape
}
"
)
if
input
.
indices
.
shape
[
0
]
!=
datas
.
indices
.
shape
[
0
]:
raise
ValueError
(
f
"subm with same indice_key must have same num of indices"
f
", expect
{
datas
.
indices
.
shape
[
0
]
}
, input
{
input
.
indices
.
shape
[
0
]
}
"
)
else
:
else
:
with
input
.
_timer
.
namespace
(
"gen_pairs"
):
with
input
.
_timer
.
namespace
(
"gen_pairs"
):
...
@@ -471,6 +450,22 @@ class SparseConvolution(SparseModule):
...
@@ -471,6 +450,22 @@ class SparseConvolution(SparseModule):
return
out_tensor
return
out_tensor
def
_check_subm_reuse_valid
(
self
,
inp
:
SparseConvTensor
,
spatial_shape
:
List
[
int
],
datas
:
Union
[
ImplicitGemmIndiceData
,
IndiceData
]):
assert
datas
.
is_subm
,
"only support reuse subm indices"
if
self
.
kernel_size
!=
datas
.
ksize
:
raise
ValueError
(
f
"subm with same indice_key must have same kernel"
f
" size, expect
{
datas
.
ksize
}
, this layer
{
self
.
kernel_size
}
"
)
if
self
.
dilation
!=
datas
.
dilation
:
raise
ValueError
(
f
"subm with same indice_key must have same dilation"
f
", expect
{
datas
.
dilation
}
, this layer
{
self
.
dilation
}
"
)
if
inp
.
spatial_shape
!=
datas
.
spatial_shape
:
raise
ValueError
(
f
"subm with same indice_key must have same spatial structure"
f
", expect
{
datas
.
spatial_shape
}
, input
{
spatial_shape
}
"
)
if
inp
.
indices
.
shape
[
0
]
!=
datas
.
indices
.
shape
[
0
]:
raise
ValueError
(
f
"subm with same indice_key must have same num of indices"
f
", expect
{
datas
.
indices
.
shape
[
0
]
}
, input
{
inp
.
indices
.
shape
[
0
]
}
"
)
class
SparseConv1d
(
SparseConvolution
):
class
SparseConv1d
(
SparseConvolution
):
def
__init__
(
self
,
def
__init__
(
self
,
in_channels
,
in_channels
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment