Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
aa26c99e
Commit
aa26c99e
authored
Dec 29, 2022
by
yan.yan
Browse files
working on quantization
parent
ee8c9465
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
965 additions
and
133 deletions
+965
-133
spconv/pytorch/quantization/__init__.py
spconv/pytorch/quantization/__init__.py
+0
-0
spconv/pytorch/quantization/conv_fused.py
spconv/pytorch/quantization/conv_fused.py
+475
-0
spconv/pytorch/quantization/fake_q.py
spconv/pytorch/quantization/fake_q.py
+23
-0
spconv/pytorch/quantization/fuse_mapping.py
spconv/pytorch/quantization/fuse_mapping.py
+130
-0
spconv/pytorch/quantization/intrinsic.py
spconv/pytorch/quantization/intrinsic.py
+35
-0
spconv/pytorch/quantization/modules.py
spconv/pytorch/quantization/modules.py
+21
-0
spconv/pytorch/quantization/utils.py
spconv/pytorch/quantization/utils.py
+52
-0
spconv/test_utils.py
spconv/test_utils.py
+5
-5
test/test_all_algo.py
test/test_all_algo.py
+224
-128
No files found.
spconv/pytorch/quantization/__init__.py
0 → 100644
View file @
aa26c99e
spconv/pytorch/quantization/conv_fused.py
0 → 100644
View file @
aa26c99e
# torch.ao.nn.intrinsic.qat.modules.conv_fused
import
math
import
torch
import
torch.nn
as
nn
import
torch.ao.nn.intrinsic
as
nni
import
torch.ao.nn.qat
as
nnqat
import
torch.nn.functional
as
F
from
torch.nn
import
init
from
torch.nn.utils
import
fuse_conv_bn_weights
from
torch.nn.modules.utils
import
_single
,
_pair
,
_triple
from
torch.nn.parameter
import
Parameter
from
typing
import
TypeVar
from
spconv.pytorch.conv
import
SparseConvolution
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
spconv.core
import
ConvAlgo
from
cumm
import
tensorview
as
tv
from
spconv.pytorch.core
import
SparseConvTensor
import
spconv.pytorch.quantization.intrinsic
as
snni
# Type variable standing for the float module a QAT module is created from;
# bounded by SparseConvolution. Used below as the placeholder value of
# ``_SparseConvBn._FLOAT_MODULE`` (subclasses override it with a real class).
MOD = TypeVar('MOD', bound=SparseConvolution)
class _SparseConvBn(SparseConvolution, nni._FusedModule):
    """Sparse convolution fused with BatchNorm1d for quantization aware training.

    The module owns the convolution weight/bias (via ``SparseConvolution``),
    a ``nn.BatchNorm1d`` stored as ``self.bn`` and a weight fake-quantizer
    created from ``qconfig.weight()``. The structure mirrors
    ``torch.ao.nn.intrinsic.qat.modules.conv_fused``.
    """

    # Serialization version, consumed by ``_load_from_state_dict``.
    _version = 2
    # Float module type accepted by ``from_float``; overridden by subclasses.
    _FLOAT_MODULE = MOD
    # Float convolution type produced by ``to_float``.
    _FLOAT_CONV_MODULE = SparseConvolution

    def __init__(self,
                 # SparseConvolution args
                 ndim: int,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
                 stride: Union[int, List[int], Tuple[int, ...]] = 1,
                 padding: Union[int, List[int], Tuple[int, ...]] = 0,
                 dilation: Union[int, List[int], Tuple[int, ...]] = 1,
                 groups: Union[int, List[int], Tuple[int, ...]] = 1,
                 bias: bool = True,
                 subm: bool = False,
                 output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
                 transposed: bool = False,
                 inverse: bool = False,
                 indice_key: Optional[str] = None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 record_voxel_count: bool = False,
                 act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
                 act_alpha: float = 0,
                 act_beta: float = 0,
                 name=None,
                 # BatchNormNd args
                 # num_features: out_channels
                 eps=1e-05,
                 momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        """Build the convolution, the batch norm and the weight fake-quantizer.

        The convolution arguments are forwarded to ``SparseConvolution``;
        ``eps`` and ``momentum`` configure the internal ``nn.BatchNorm1d``.

        Args:
            freeze_bn: when True (or when the module is not in training mode)
                the batch-norm statistics are frozen right after construction.
            qconfig: quantization config; must be truthy. Its ``weight``
                factory creates ``self.weight_fake_quant``.
        """
        # The parent is built without a bias; the conv bias is owned by this
        # module and created below so it can be applied before the batch norm.
        SparseConvolution.__init__(self,
                                   ndim,
                                   in_channels,
                                   out_channels,
                                   kernel_size,
                                   stride,
                                   padding,
                                   dilation,
                                   groups,
                                   bias=False,
                                   subm=subm,
                                   output_padding=output_padding,
                                   transposed=transposed,
                                   inverse=inverse,
                                   indice_key=indice_key,
                                   algo=algo,
                                   fp32_accum=fp32_accum,
                                   record_voxel_count=record_voxel_count,
                                   act_type=act_type,
                                   act_alpha=act_alpha,
                                   act_beta=act_beta,
                                   name=name)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm1d(out_channels, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

        self._enable_slow_path_for_better_numerical_stability = False

    def reset_running_stats(self):
        """Reset the running statistics of the internal batch norm."""
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        """Re-initialize batch-norm state and, if present, the conv bias."""
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # note: below is actually for conv, not BN
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        """Delegate parameter initialization to the parent class."""
        super(_SparseConvBn, self).reset_parameters()

    def update_bn_stats(self):
        """Let the batch norm update its running statistics; returns self."""
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        """Stop the batch norm from updating its statistics; returns self."""
        self.freeze_bn = True
        self.bn.training = False
        return self

    def _forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
        """Run the fused conv + bn forward pass.

        Only the approximate path is supported; the assert rejects the slow
        path, which is marked as not implemented. NOTE(review): if asserts are
        disabled and the slow path flag is set, ``add_input`` is not forwarded.
        """
        assert not self._enable_slow_path_for_better_numerical_stability
        if self._enable_slow_path_for_better_numerical_stability:
            return self._forward_slow(input)
        return self._forward_approximate(input, add_input)

    def _forward_approximate(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
        """Approximated method to fuse conv and bn. It requires only one forward pass.
        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std

        Args:
            input: sparse tensor fed to the convolution.
            add_input: optional sparse tensor whose features are added to the
                batch-norm output.

        Returns:
            The convolution output tensor with its features replaced by the
            batch-normed (and optionally summed) features.
        """
        assert self.bn.running_var is not None
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        # The per-output-channel scale is applied along dim 0 of the weight.
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        # The conv output features go straight into BatchNorm1d, i.e. they are
        # expected to be 2-D (points, channels) -- assumption, confirm against
        # SparseConvTensor. Per-channel vectors must therefore broadcast as
        # (1, C); a shape with len(weight.shape) dims (as in the dense torch
        # implementation) would broadcast to the wrong result shape.
        channel_shape = [1, -1]
        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
        # using zero bias here since the bias for original conv
        # will be added later
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias, dtype=input.features.dtype)
        else:
            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.features.dtype)
        conv_spt = self._conv_forward(input, scaled_weight, zero_bias)
        conv = conv_spt.features
        # Undo the weight scaling so the batch norm sees the plain conv output.
        # (Per the _forward_slow notes, this path does not handle bn.weight == 0.)
        conv_orig = conv / scale_factor.reshape(channel_shape)
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape(channel_shape)
        conv = self.bn(conv_orig)
        if add_input is not None:
            conv = conv + add_input.features
        conv_spt = conv_spt.replace_feature(conv)
        return conv_spt

    def _forward_slow(self, input: SparseConvTensor):
        """
        TODO not implemented for now
        A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
        It requires two forward passes but handles the case bn.weight == 0

        Conv: Y = WX + B_c
        Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c

        Batch statistics:
          mean_Y = Y.mean()
                 = Y0.mean() + B_c
          var_Y = (Y - mean_Y)^2.mean()
                = (Y0 - Y0.mean())^2.mean()

        BN (r: bn.weight, beta: bn.bias):
          Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
            = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta

        Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
          Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
            = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta

        Fused Conv BN inference (running_std = sqrt(running_var + eps)):
          Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta

        QAT with fused conv bn:
          Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
                  = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
          Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
        """
        # NOTE(review): bias_shape / avg_dims below are derived from the
        # weight rank, as in the dense torch implementation. They presumably
        # need the same (points, channels) treatment as _forward_approximate
        # before this path is enabled -- confirm when implementing the TODO.
        assert self.bn.running_var is not None
        assert self.bn.running_mean is not None

        # using zero bias here since the bias for original conv
        # will be added later
        zero_bias = torch.zeros(self.out_channels, device=self.weight.device, dtype=input.features.dtype)

        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        conv_out = torch.Tensor()
        if self.bn.training:
            # needed to compute batch mean/std
            conv_spt = self._conv_forward(input, self.weight, zero_bias)
            conv_out = conv_spt.features
            # update bn statistics
            with torch.no_grad():
                conv_out_bias = (
                    conv_out if self.bias is None else conv_out + self.bias.reshape(bias_shape)
                )
                self.bn(conv_out_bias)

        # fused conv + bn without bias using bn running statistics
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape)
        )
        # fused conv without bias for inference: (r * W / running_std) * X
        conv_bn_spt = self._conv_forward(input, scaled_weight, zero_bias)
        conv_bn = conv_bn_spt.features
        if self.bn.training:
            avg_dims = [0] + list(range(2, len(self.weight.shape)))
            batch_mean = conv_out.mean(avg_dims)
            batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
                avg_dims
            )
            batch_std = torch.sqrt(batch_var + self.bn.eps)

            # scale to use batch std in training mode
            # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
            unscale_factor = running_std / batch_std
            conv_bn *= unscale_factor.reshape(bias_shape)

            fused_mean = batch_mean
            fused_std = batch_std
        else:
            fused_mean = self.bn.running_mean - (self.bias if self.bias is not None else 0)
            fused_std = running_std

        # fused bias = beta - r * mean / std
        fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
        conv_bn += fused_bias.reshape(bias_shape)

        # HACK to let conv bias participate in loss to avoid DDP error (parameters
        # were not used in producing loss)
        if self.bias is not None:
            conv_bn += (self.bias - self.bias).reshape(bias_shape)
        conv_bn_spt = conv_bn_spt.replace_feature(conv_bn)
        return conv_bn_spt

    def extra_repr(self):
        """Return the parent's extra representation string."""
        # TODO(jerryzh): extend
        return super(_SparseConvBn, self).extra_repr()

    def forward(self, input):
        """Apply the fused convolution + batch norm to ``input``."""
        return self._forward(input)

    def train(self, mode=True):
        """
        Batchnorm's training behavior is using the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load state, renaming version-1 batch-norm keys to the version-2 layout."""
        version = local_metadata.get('version', None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module for v2
            v2_to_v1_names = {
                'bn.weight': 'gamma',
                'bn.bias': 'beta',
                'bn.running_mean': 'running_mean',
                'bn.running_var': 'running_var',
                'bn.num_batches_tracked': 'num_batches_tracked',
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif prefix + v2_name in state_dict:
                    # there was a brief period where forward compatibility
                    # for this module was broken (between
                    # https://github.com/pytorch/pytorch/pull/38478
                    # and https://github.com/pytorch/pytorch/pull/38820)
                    # and modules emitted the v2 state_dict format while
                    # specifying that version == 1. This patches the forward
                    # compatibility issue by allowing the v2 style entries to
                    # be used.
                    pass
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super(_SparseConvBn, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

            Args: `mod` a float module, either produced by torch.ao.quantization utilities
            or directly from user
        """
        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
        # has no __name__ (code is fine though)
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        qconfig = mod.qconfig
        # The float module is a container: element 0 is the conv, element 1 the bn.
        conv: SparseConvolution = mod[0]
        bn: nn.BatchNorm1d = mod[1]
        qat_convbn = cls(conv.ndim,
                         conv.in_channels,
                         conv.out_channels,
                         conv.kernel_size,
                         conv.stride,
                         conv.padding,
                         conv.dilation,
                         conv.groups,
                         conv.bias is not None,
                         subm=conv.subm,
                         output_padding=conv.output_padding,
                         transposed=conv.transposed,
                         inverse=conv.inverse,
                         indice_key=conv.indice_key,
                         algo=conv.algo,
                         fp32_accum=conv.fp32_accum,
                         record_voxel_count=conv.record_voxel_count,
                         act_type=conv.act_type,
                         act_alpha=conv.act_alpha,
                         act_beta=conv.act_beta,
                         name=conv.name,
                         eps=bn.eps,
                         momentum=bn.momentum,
                         freeze_bn=False,
                         qconfig=qconfig)
        # Share (not copy) the float module's parameters and buffers.
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        # mypy error: Cannot determine type of 'num_batches_tracked'
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
        return qat_convbn

    def to_float(self):
        """Convert back to a float module with the batch norm folded into the conv.

        Returns the float convolution, or ``_FUSED_FLOAT_MODULE(conv, relu)``
        when the class defines a ``_FLOAT_RELU_MODULE``. Relies on the
        subclass-provided ``_FLOAT_BN_MODULE`` / ``_FLOAT_RELU_MODULE`` /
        ``_FUSED_FLOAT_MODULE`` class attributes.
        """
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.ndim,
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            subm=self.subm,
            output_padding=self.output_padding,
            transposed=self.transposed,
            inverse=self.inverse,
            indice_key=self.indice_key,
            algo=self.algo,
            fp32_accum=self.fp32_accum,
            record_voxel_count=self.record_voxel_count,
            act_type=self.act_type,
            act_alpha=self.act_alpha,
            act_beta=self.act_beta,
            name=self.name)
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())

        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
            # fuse bn into conv
            conv.weight, conv.bias = fuse_conv_bn_weights(
                conv.weight,
                conv.bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias
            )

        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
            modules = []
            modules.append(conv)
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
            conv_relu.train(self.training)
            return conv_relu
        else:
            conv.train(self.training)
            return conv
class SparseConvBn(_SparseConvBn):
    r"""
    QAT module fused from a sparse convolution and a BatchNorm1d,
    with a FakeQuantize module attached for the weight.

    It is the QAT counterpart of :class:`snni.SpconvBnNd` (see
    ``_FLOAT_MODULE``) and inherits all behavior from ``_SparseConvBn``.

    Attributes:
        freeze_bn:
        weight_fake_quant: fake quant module for weight

    """
    # Float building blocks used by ``to_float``.
    _FLOAT_CONV_MODULE = SparseConvolution
    _FLOAT_BN_MODULE = nn.BatchNorm1d
    _FLOAT_RELU_MODULE = None
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE = snni.SpconvReLUNd
    # Float module type accepted by ``from_float`` (the base class only
    # holds a TypeVar placeholder here).
    _FLOAT_MODULE = snni.SpconvBnNd  # type: ignore[assignment]
class SparseConvBnReLU(_SparseConvBn):
    r"""
    QAT module fused from a sparse convolution, a BatchNorm1d and a ReLU,
    with a FakeQuantize module attached for the weight.

    It is the QAT counterpart of :class:`snni.SpconvBnReLUNd` (see
    ``_FLOAT_MODULE``); the ReLU is applied to the features produced by the
    fused conv + bn forward pass.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    # Float module type accepted by ``from_float`` (the base class only
    # holds a TypeVar placeholder here).
    _FLOAT_MODULE = snni.SpconvBnReLUNd  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = SparseConvolution
    _FLOAT_BN_MODULE = nn.BatchNorm1d
    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE = snni.SpconvReLUNd

    def forward(self, input):
        """Run fused conv + bn, then apply ReLU to the resulting features."""
        fused = _SparseConvBn._forward(self, input)
        activated = F.relu(fused.features)
        return fused.replace_feature(activated)

    @classmethod
    def from_float(cls, mod):
        """Create the QAT module from a float module; defers to the base class."""
        return super().from_float(mod)
spconv/pytorch/quantization/fake_q.py
0 → 100644
View file @
aa26c99e
from
torch.ao.quantization.fake_quantize
import
FusedMovingAvgObsFakeQuantize
,
fused_wt_fake_quant_range_neg_127_to_127
from
spconv.pytorch.core
import
SparseConvTensor
import
torch
from
torch.ao.quantization.qconfig
import
QConfig
from
torch.ao.quantization.observer
import
MovingAverageMinMaxObserver
class SparseFusedMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
    """FusedMovingAvgObsFakeQuantize that operates on a sparse tensor's features.

    The parent forward is run on ``input.features`` and the result is put
    back with ``input.replace_feature``.
    """

    def forward(self, input: SparseConvTensor):
        # add lines to support spconv
        quantized_features = super().forward(input.features)
        return input.replace_feature(quantized_features)
# Default QAT qconfig for spconv modules.
# - activation: SparseFusedMovingAvgObsFakeQuantize (fake-quantizes the
#   ``features`` of a sparse tensor) driven by a MovingAverageMinMaxObserver,
#   qint8 with the range [-128, 127] and reduce_range disabled.
# - weight: torch's ``fused_wt_fake_quant_range_neg_127_to_127``.
# NOTE(review): no ``qscheme`` is passed for the activation, so the observer's
# default scheme applies -- confirm this matches the "symmetric" in the name.
default_symmetric_spconv_qat_qconfig = QConfig(
    activation=SparseFusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                             quant_min=-128,
                                                             quant_max=127,
                                                             dtype=torch.qint8,
                                                             reduce_range=False,
                                                             eps=2 ** -12),
    weight=fused_wt_fake_quant_range_neg_127_to_127)
spconv/pytorch/quantization/fuse_mapping.py
0 → 100644
View file @
aa26c99e
from
typing
import
Union
,
Callable
,
Tuple
,
Dict
,
Optional
,
Type
,
Any
import
torch.nn
as
nn
import
spconv.pytorch
as
spconv
from
.utils
import
fuse_spconv_bn_eval
from
.
import
intrinsic
as
snni
from
.conv_fused
import
SparseConvBn
,
SparseConvBnReLU
def fuse_conv_bn(conv, bn):
    r"""Given the sparse conv and bn modules, fuses them and returns the fused module

    In training mode the pair is wrapped in ``snni.SpconvBnNd``; in eval mode
    the batch norm is folded into a copy of the conv by ``fuse_spconv_bn_eval``.

    Args:
        conv: sparse convolution instance (one of the SubMConv / SparseConv /
            SparseInverseConv 1d-3d types listed below)
        bn: BatchNorm1d instance that needs to be fused with the conv

    Raises:
        AssertionError: if the modules are in different modes, or (training
            only) the batch norm does not match the conv.
        NotImplementedError: if ``type(conv)`` is not supported in training mode.

    Examples::

        >>> m1 = spconv.SubMConv3d(10, 20, 3)
        >>> b1 = nn.BatchNorm1d(20)
        >>> m2 = fuse_conv_bn(m1, b1)
    """
    assert conv.training == bn.training, \
        "Conv and BN both must be in the same mode (train or eval)."

    if not conv.training:
        return fuse_spconv_bn_eval(conv, bn)

    # Validate before doing any other work. The messages use the same generic
    # wording as fuse_conv_bn_relu, since the modules handled here are sparse
    # convolutions with BatchNorm1d rather than Conv2d/BatchNorm2d.
    assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
    assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
    assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'

    # Exact-type lookup: subclasses of these conv types are not matched.
    fused_module_class_map = {
        spconv.SubMConv1d: snni.SpconvBnNd,
        spconv.SparseConv1d: snni.SpconvBnNd,
        spconv.SparseInverseConv1d: snni.SpconvBnNd,
        spconv.SubMConv2d: snni.SpconvBnNd,
        spconv.SparseConv2d: snni.SpconvBnNd,
        spconv.SparseInverseConv2d: snni.SpconvBnNd,
        spconv.SubMConv3d: snni.SpconvBnNd,
        spconv.SparseConv3d: snni.SpconvBnNd,
        spconv.SparseInverseConv3d: snni.SpconvBnNd,
    }
    fused_module_class = fused_module_class_map.get(type(conv), None)
    if fused_module_class is None:
        raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
    return fused_module_class(conv, bn)
def fuse_conv_bn_relu(conv, bn, relu):
    r"""Given the sparse conv, bn and relu modules, fuses them and returns the fused module

    In training mode the three modules are wrapped in ``snni.SpconvBnReLUNd``.
    In eval mode the batch norm is first folded into a copy of the conv by
    ``fuse_spconv_bn_eval`` and the result is wrapped with the relu in
    ``snni.SpconvReLUNd``.

    Args:
        conv: sparse convolution instance (one of the SubMConv / SparseConv /
            SparseInverseConv 1d-3d types listed below)
        bn: BatchNorm1d instance that needs to be fused with the conv
        relu: ReLU instance applied after the batch norm

    Raises:
        AssertionError: if the modules are in different modes, or (training
            only) the batch norm does not match the conv.
        NotImplementedError: if ``type(conv)`` is not supported.

    Examples::

        >>> m1 = spconv.SubMConv3d(10, 20, 3)
        >>> b1 = nn.BatchNorm1d(20)
        >>> r1 = nn.ReLU()
        >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
    """
    # All three modules are compared, so the message names all three.
    assert conv.training == bn.training == relu.training, \
        "Conv, BN and ReLU must all be in the same mode (train or eval)."

    if conv.training:
        # Validate before doing any other work.
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
        # Exact-type lookup: subclasses of these conv types are not matched.
        map_to_fused_module_train = {
            spconv.SubMConv1d: snni.SpconvBnReLUNd,
            spconv.SparseConv1d: snni.SpconvBnReLUNd,
            spconv.SparseInverseConv1d: snni.SpconvBnReLUNd,
            spconv.SubMConv2d: snni.SpconvBnReLUNd,
            spconv.SparseConv2d: snni.SpconvBnReLUNd,
            spconv.SparseInverseConv2d: snni.SpconvBnReLUNd,
            spconv.SubMConv3d: snni.SpconvBnReLUNd,
            spconv.SparseConv3d: snni.SpconvBnReLUNd,
            spconv.SparseInverseConv3d: snni.SpconvBnReLUNd,
        }
        fused_module = map_to_fused_module_train.get(type(conv), None)
        if fused_module is None:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
        return fused_module(conv, bn, relu)

    map_to_fused_module_eval = {
        spconv.SubMConv1d: snni.SpconvReLUNd,
        spconv.SparseConv1d: snni.SpconvReLUNd,
        spconv.SparseInverseConv1d: snni.SpconvReLUNd,
        spconv.SubMConv2d: snni.SpconvReLUNd,
        spconv.SparseConv2d: snni.SpconvReLUNd,
        spconv.SparseInverseConv2d: snni.SpconvReLUNd,
        spconv.SubMConv3d: snni.SpconvReLUNd,
        spconv.SparseConv3d: snni.SpconvReLUNd,
        spconv.SparseInverseConv3d: snni.SpconvReLUNd,
    }
    fused_module = map_to_fused_module_eval.get(type(conv), None)
    if fused_module is None:
        raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
    # Fold the batch norm into the conv first, then pair it with the relu.
    fused_conv = fuse_spconv_bn_eval(conv, bn)
    return fused_module(fused_conv, relu)
# Maps a tuple of module types (the sequence to fuse) to the function that
# fuses it: (conv, BatchNorm1d) -> fuse_conv_bn and
# (conv, BatchNorm1d, ReLU) -> fuse_conv_bn_relu, for every SubMConv /
# SparseConv / SparseInverseConv type in 1, 2 and 3 dimensions.
DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (spconv.SubMConv1d, nn.BatchNorm1d): fuse_conv_bn,
    (spconv.SubMConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (spconv.SparseConv1d, nn.BatchNorm1d): fuse_conv_bn,
    (spconv.SparseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (spconv.SparseInverseConv1d, nn.BatchNorm1d): fuse_conv_bn,
    (spconv.SparseInverseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (spconv.SubMConv2d, nn.BatchNorm1d): fuse_conv_bn,
    (spconv.SubMConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (spconv.SparseConv2d, nn.BatchNorm1d): fuse_conv_bn,
    (spconv.SparseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (spconv.SparseInverseConv2d, nn.BatchNorm1d): fuse_conv_bn,
    (spconv.SparseInverseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (spconv.SubMConv3d, nn.BatchNorm1d): fuse_conv_bn,
    (spconv.SubMConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (spconv.SparseConv3d, nn.BatchNorm1d): fuse_conv_bn,
    (spconv.SparseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (spconv.SparseInverseConv3d, nn.BatchNorm1d): fuse_conv_bn,
    (spconv.SparseInverseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
}
# Default map for swapping float modules to QAT modules: each fused float
# container from the intrinsic module is mapped to its QAT counterpart
# defined in conv_fused.
DEFAULT_SPCONV_QAT_MODULE_MAPPINGS: Dict[Callable, Any] = {
    # nn.Conv2d: nnqat.Conv2d,
    # Intrinsic modules:
    snni.SpconvBnNd: SparseConvBn,
    snni.SpconvBnReLUNd: SparseConvBnReLU,
}
spconv/pytorch/quantization/intrinsic.py
0 → 100644
View file @
aa26c99e
import
torch
from
torch.nn
import
Conv1d
,
Conv2d
,
Conv3d
,
ReLU
,
Linear
,
BatchNorm1d
,
BatchNorm2d
,
BatchNorm3d
from
torch.nn.utils.parametrize
import
type_before_parametrizations
import
torch.ao.nn.intrinsic
as
nni
from
spconv.pytorch.conv
import
SparseConvolution
class SpconvReLUNd(nni._FusedModule):
    r"""This is a sequential container which holds a sparse convolution and a ReLU module.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, relu):
        # Both checks must pass; the conv is checked first.
        types_ok = isinstance(conv, SparseConvolution) and isinstance(relu, ReLU)
        assert types_ok, 'Incorrect types for input modules{}{}'.format(
            type(conv), type(relu))
        super().__init__(conv, relu)
class SpconvBnNd(nni._FusedModule):
    r"""This is a sequential container which holds a sparse convolution and a BatchNorm1d module.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, bn):
        # Both checks must pass; the conv is checked first.
        types_ok = isinstance(conv, SparseConvolution) and isinstance(bn, BatchNorm1d)
        assert types_ok, 'Incorrect types for input modules{}{}'.format(
            type(conv), type(bn))
        super().__init__(conv, bn)
class SpconvBnReLUNd(nni._FusedModule):
    r"""This is a sequential container which holds a sparse convolution, a BatchNorm1d, and a ReLU module.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, bn, relu):
        # All three checks must pass; they are evaluated in conv, bn, relu order.
        types_ok = (isinstance(conv, SparseConvolution)
                    and isinstance(bn, BatchNorm1d)
                    and isinstance(relu, ReLU))
        assert types_ok, 'Incorrect types for input modules{}{}{}'.format(
            type(conv), type(bn), type(relu))
        super().__init__(conv, bn, relu)
spconv/pytorch/quantization/modules.py
0 → 100644
View file @
aa26c99e
from
spconv.pytorch.modules
import
SparseModule
from
spconv.pytorch.conv
import
SparseConvolution
from
spconv.pytorch.core
import
SparseConvTensor
import
torch
class ConvBatchNormAddAct(torch.nn.Module):
    """for simple int8 residual op fusion, we can use this module to handle add.

    Computes ``act(bn(conv(x).features) + x_add.features)`` and returns it as
    the features of the convolution output tensor.
    """

    def __init__(self, conv: SparseConvolution, bn: torch.nn.BatchNorm1d, act: torch.nn.ReLU) -> None:
        """Store the three sub-modules.

        Args:
            conv: sparse convolution applied to the input tensor.
            bn: batch norm applied to the convolution output features.
            act: activation applied to the summed features.
        """
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.act = act

    def forward(self, x: SparseConvTensor, x_add: SparseConvTensor):
        """Apply conv and bn to ``x``, add ``x_add`` and apply the activation.

        Args:
            x: input sparse tensor.
            x_add: residual sparse tensor; its features are added to the
                batch-normed features of the conv output.

        Returns:
            The conv output tensor with the activated sum as its features.
        """
        x = self.conv(x)
        x = x.replace_feature(self.bn(x.features))
        # The activation is a plain torch module, so -- like the batch norm
        # above -- it must be applied to the feature tensor, not to the
        # sparse tensor object.
        summed = x.features + x_add.features
        return x.replace_feature(self.act(summed))
spconv/pytorch/quantization/utils.py
0 → 100644
View file @
aa26c99e
import
torch
import
copy
from
cumm
import
tensorview
as
tv
def fuse_spconv_bn_weights(conv_w_OKI, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    """Fold batch-norm statistics and affine parameters into conv weight and bias.

    The weight is scaled along dim 0 by ``bn_w / sqrt(bn_rv + bn_eps)`` and
    the bias becomes ``(conv_b - bn_rm) * rsqrt(bn_rv + bn_eps) * bn_w + bn_b``.

    Args:
        conv_w_OKI: conv weight; dim 0 is treated as the per-channel axis
            matching the batch-norm vectors (the names suggest an
            out / kernel / in layout).
        conv_b: conv bias or None (treated as zeros).
        bn_rm: batch-norm running mean.
        bn_rv: batch-norm running variance.
        bn_eps: batch-norm epsilon.
        bn_w: batch-norm weight or None (treated as ones).
        bn_b: batch-norm bias or None (treated as zeros).

    Returns:
        ``(weight, bias)`` as two new ``torch.nn.Parameter`` objects; the
        weight keeps the dimension order of ``conv_w_OKI`` and is contiguous.
    """
    # Fill in defaults for the optional tensors.
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    inv_std = torch.rsqrt(bn_rv + bn_eps)

    # Move the last dim next to dim 0, scale, then restore the original order.
    n_spatial = conv_w_OKI.ndim - 2
    to_oik = [0, n_spatial + 1, *range(1, n_spatial + 1)]
    to_oki = [0, *range(2, n_spatial + 2), 1]
    w_oik = conv_w_OKI.permute(*to_oik)  # OIDHW
    per_channel = (bn_w * inv_std).reshape([-1] + [1] * (len(w_oik.shape) - 1))
    fused_w = (w_oik * per_channel).permute(*to_oki).contiguous()

    fused_b = (conv_b - bn_rm) * inv_std * bn_w + bn_b
    return torch.nn.Parameter(fused_w), torch.nn.Parameter(fused_b)
def fuse_spconv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.

    `C` is a deep copy of `A` whose weight and bias are replaced by the
    result of `fuse_spconv_bn_weights`; `A` itself is left untouched.
    Both modules must be in eval mode.
    """
    in_train_mode = conv.training or bn.training
    assert not in_train_mode, "Fusion only for eval!"

    fused_conv = copy.deepcopy(conv)
    new_weight, new_bias = fuse_spconv_bn_weights(
        fused_conv.weight, fused_conv.bias,
        bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)
    fused_conv.weight = new_weight
    fused_conv.bias = new_bias
    return fused_conv
def fuse_spconv_act_eval(conv, act):
    """
    Given a conv Module `A` in eval mode and an activation module `B`,
    returns a deep copy of `A` with the activation recorded on it.

    For `torch.nn.ReLU` the copy's `act_type` is set to ReLU; for
    `torch.nn.LeakyReLU` `act_type` is set to LeakyReLU and `act_alpha` to
    the module's `negative_slope`. Any other activation raises
    NotImplementedError.
    """
    assert not conv.training, "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)
    if isinstance(act, torch.nn.ReLU):
        fused_conv.act_type = tv.gemm.Activation.ReLU
        return fused_conv
    if isinstance(act, torch.nn.LeakyReLU):
        fused_conv.act_type = tv.gemm.Activation.LeakyReLU
        fused_conv.act_alpha = act.negative_slope
        return fused_conv
    raise NotImplementedError
spconv/test_utils.py
View file @
aa26c99e
...
@@ -53,7 +53,7 @@ class TestCase(unittest.TestCase):
...
@@ -53,7 +53,7 @@ class TestCase(unittest.TestCase):
print
(
"not equal rhs = "
,
y
)
print
(
"not equal rhs = "
,
y
)
np
.
testing
.
assert_array_equal
(
a
,
b
)
np
.
testing
.
assert_array_equal
(
a
,
b
)
def
assertAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
):
def
assertAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
,
msg
:
str
=
""
):
"""Asserts that two numpy arrays, or dicts of same, have near values.
"""Asserts that two numpy arrays, or dicts of same, have near values.
This does not support nested dicts.
This does not support nested dicts.
Args:
Args:
...
@@ -68,22 +68,22 @@ class TestCase(unittest.TestCase):
...
@@ -68,22 +68,22 @@ class TestCase(unittest.TestCase):
"""
"""
is_a_dict
=
isinstance
(
a
,
dict
)
is_a_dict
=
isinstance
(
a
,
dict
)
if
is_a_dict
!=
isinstance
(
b
,
dict
):
if
is_a_dict
!=
isinstance
(
b
,
dict
):
raise
ValueError
(
"Can't compare dict to non-dict, %s vs %s."
%
raise
ValueError
(
f
"Can't compare dict to non-dict, %s vs %s.
{
msg
}
"
%
(
a
,
b
))
(
a
,
b
))
if
is_a_dict
:
if
is_a_dict
:
self
.
assertCountEqual
(
a
.
keys
(),
self
.
assertCountEqual
(
a
.
keys
(),
b
.
keys
(),
b
.
keys
(),
msg
=
"mismatched keys, expected %s, got %s"
%
msg
=
f
"mismatched keys, expected %s, got %s
.
{
msg
}
"
%
(
a
.
keys
(),
b
.
keys
()))
(
a
.
keys
(),
b
.
keys
()))
for
k
in
a
:
for
k
in
a
:
self
.
_assertArrayLikeAllClose
(
a
[
k
],
self
.
_assertArrayLikeAllClose
(
a
[
k
],
b
[
k
],
b
[
k
],
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
,
atol
=
atol
,
msg
=
"%s: expected %s, got %s."
%
msg
=
f
"%s: expected %s, got %s.
{
msg
}
"
%
(
k
,
a
,
b
))
(
k
,
a
,
b
))
else
:
else
:
self
.
_assertArrayLikeAllClose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
self
.
_assertArrayLikeAllClose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
,
msg
=
msg
)
def
_assertArrayLikeAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
,
msg
=
None
):
def
_assertArrayLikeAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
,
msg
=
None
):
a
=
self
.
_GetNdArray
(
a
)
a
=
self
.
_GetNdArray
(
a
)
...
...
test/test_all_algo.py
View file @
aa26c99e
...
@@ -37,6 +37,8 @@ from cumm import tensorview as tv
...
@@ -37,6 +37,8 @@ from cumm import tensorview as tv
from
spconv.constants
import
SPCONV_ALLOW_TF32
from
spconv.constants
import
SPCONV_ALLOW_TF32
from
cumm.conv.bases
import
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvOpType
from
cumm.conv.bases
import
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvOpType
import
os
import
os
from
cumm.dtypes
import
get_npdtype_from_tvdtype
from
cumm.gemm.codeops
import
div_up
from
cumm.gemm.codeops
import
div_up
from
spconv.core
import
AlgoHint
,
ConvAlgo
from
spconv.core
import
AlgoHint
,
ConvAlgo
from
spconv.pytorch.conv
import
expand_nd
from
spconv.pytorch.conv
import
expand_nd
...
@@ -63,14 +65,18 @@ NUMPY_DTYPE_TO_TORCH = {
...
@@ -63,14 +65,18 @@ NUMPY_DTYPE_TO_TORCH = {
}
}
class
SparseConvTester
:
class
SparseConvTester
:
def
__init__
(
self
,
algo
:
ConvAlgo
,
subm
:
bool
,
shape
:
List
[
int
],
bs
:
int
,
dtype
:
np
.
dtype
,
N
:
int
,
K
:
int
,
C
:
int
,
def
__init__
(
self
,
algo
:
ConvAlgo
,
subm
:
bool
,
shape
:
List
[
int
],
bs
:
int
,
dtype
:
np
.
dtype
,
out_dtype
:
np
.
dtype
,
N
:
int
,
K
:
int
,
C
:
int
,
ksize
:
int
,
stride
:
int
,
padding
:
int
,
dilation
:
int
,
check_bias
:
bool
=
False
,
check_act
:
bool
=
False
)
->
None
:
ksize
:
int
,
stride
:
int
,
padding
:
int
,
dilation
:
int
,
check_bias
:
bool
=
False
,
check_act
:
bool
=
False
,
check_int8_infer
:
bool
=
False
,
dtype_comp
:
np
.
dtype
=
np
.
dtype
(
np
.
float32
))
->
None
:
ndim
=
3
ndim
=
3
transpose
=
False
transpose
=
False
self
.
shape
=
shape
self
.
shape
=
shape
self
.
bs
=
bs
self
.
bs
=
bs
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
out_dtype
=
out_dtype
self
.
dtype_th
=
NUMPY_DTYPE_TO_TORCH
[
dtype
]
self
.
dtype_th
=
NUMPY_DTYPE_TO_TORCH
[
dtype
]
self
.
out_dtype_th
=
NUMPY_DTYPE_TO_TORCH
[
out_dtype
]
self
.
K
=
K
self
.
K
=
K
self
.
C
=
C
self
.
C
=
C
self
.
ksize
=
expand_nd
(
ndim
,
ksize
)
self
.
ksize
=
expand_nd
(
ndim
,
ksize
)
...
@@ -82,6 +88,12 @@ class SparseConvTester:
...
@@ -82,6 +88,12 @@ class SparseConvTester:
op
=
expand_nd
(
ndim
,
0
)
op
=
expand_nd
(
ndim
,
0
)
self
.
kv
:
int
=
np
.
prod
(
self
.
ksize
)
self
.
kv
:
int
=
np
.
prod
(
self
.
ksize
)
self
.
num_split
=
1
if
algo
==
ConvAlgo
.
MaskImplicitGemm
else
2
self
.
num_split
=
1
if
algo
==
ConvAlgo
.
MaskImplicitGemm
else
2
self
.
output_scale
:
float
=
1.0
self
.
check_int8_infer
=
check_int8_infer
if
check_int8_infer
:
assert
check_bias
and
self
.
dtype
==
np
.
int8
self
.
dtype_comp
=
dtype_comp
if
not
subm
:
if
not
subm
:
if
transpose
:
if
transpose
:
out_shape
=
ops
.
get_deconv_output_size
(
shape
,
self
.
ksize
,
self
.
stride
,
out_shape
=
ops
.
get_deconv_output_size
(
shape
,
self
.
ksize
,
self
.
stride
,
...
@@ -91,6 +103,7 @@ class SparseConvTester:
...
@@ -91,6 +103,7 @@ class SparseConvTester:
self
.
padding
,
self
.
dilation
)
self
.
padding
,
self
.
dilation
)
else
:
else
:
out_shape
=
shape
out_shape
=
shape
self
.
scales
=
np
.
random
.
uniform
(
0.5
,
1.5
,
size
=
K
).
astype
(
dtype_comp
)
sparse_dict
=
generate_sparse_data
(
shape
,
[
N
]
*
bs
,
C
)
sparse_dict
=
generate_sparse_data
(
shape
,
[
N
]
*
bs
,
C
)
...
@@ -109,7 +122,7 @@ class SparseConvTester:
...
@@ -109,7 +122,7 @@ class SparseConvTester:
self
.
pair_native
=
pair_ref
self
.
pair_native
=
pair_ref
self
.
indice_num_per_loc
=
indice_num_per_loc
self
.
indice_num_per_loc
=
indice_num_per_loc
self
.
use_direct_table
=
True
self
.
use_direct_table
=
True
self
.
mask_int_count
=
div_up
(
self
.
kv
,
32
)
self
.
out_shape
=
out_shape
self
.
out_shape
=
out_shape
if
algo
==
ConvAlgo
.
Native
:
if
algo
==
ConvAlgo
.
Native
:
self
.
out_inds
:
torch
.
Tensor
=
out_inds
self
.
out_inds
:
torch
.
Tensor
=
out_inds
...
@@ -135,7 +148,6 @@ class SparseConvTester:
...
@@ -135,7 +148,6 @@ class SparseConvTester:
self
.
mask_argsort_fwd_splits
=
res
[
6
]
self
.
mask_argsort_fwd_splits
=
res
[
6
]
self
.
mask_argsort_bwd_splits
=
res
[
7
]
self
.
mask_argsort_bwd_splits
=
res
[
7
]
self
.
masks
=
res
[
8
]
self
.
masks
=
res
[
8
]
self
.
mask_int_count
=
res
[
9
]
self
.
out_inds_scalar
=
Fsp
.
_indice_to_scalar
(
self
.
out_inds
.
long
(),
[
bs
,
*
out_shape
])
self
.
out_inds_scalar
=
Fsp
.
_indice_to_scalar
(
self
.
out_inds
.
long
(),
[
bs
,
*
out_shape
])
...
@@ -159,18 +171,28 @@ class SparseConvTester:
...
@@ -159,18 +171,28 @@ class SparseConvTester:
self
.
check_act
=
check_act
self
.
check_act
=
check_act
self
.
subm
=
subm
self
.
subm
=
subm
self
.
output_add_scale
=
1.0
if
dtype
==
np
.
int8
:
if
dtype
==
np
.
int8
:
self
.
inp
=
np
.
random
.
randint
(
-
2
,
2
,
size
=
[
voxels_np
.
shape
[
0
],
self
.
inp
=
np
.
random
.
randint
(
-
1
,
1
,
size
=
[
voxels_np
.
shape
[
0
],
C
]).
astype
(
np
.
int8
)
C
]).
astype
(
np
.
int8
)
self
.
weight
=
np
.
random
.
randint
(
-
2
,
2
,
size
=
[
K
,
*
self
.
ksize
,
self
.
weight
=
np
.
random
.
randint
(
-
1
,
1
,
size
=
[
K
,
*
self
.
ksize
,
C
]).
astype
(
np
.
int8
)
C
]).
astype
(
np
.
int8
)
self
.
output
=
np
.
random
.
randint
(
-
2
,
2
,
size
=
[
self
.
output
=
np
.
random
.
randint
(
-
1
,
1
,
size
=
[
self
.
out_inds
.
shape
[
0
],
K
self
.
out_inds
.
shape
[
0
],
K
]).
astype
(
dtype
)
]).
astype
(
out_dtype
)
self
.
bias
=
np
.
random
.
randint
(
-
2
,
2
,
size
=
[
self
.
output_add
=
np
.
random
.
randint
(
-
1
,
1
,
size
=
[
K
self
.
out_inds
.
shape
[
0
],
K
]).
astype
(
dtype
)
]).
astype
(
out_dtype
)
self
.
output_add_scale
=
14.2
if
check_int8_infer
:
self
.
bias
=
np
.
random
.
uniform
(
-
5
,
5
,
size
=
[
K
]).
astype
(
dtype_comp
)
else
:
self
.
bias
=
np
.
random
.
randint
(
-
4
,
4
,
size
=
[
K
]).
astype
(
dtype
)
else
:
else
:
self
.
inp
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
self
.
inp
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
voxels_np
.
shape
[
0
],
C
voxels_np
.
shape
[
0
],
C
...
@@ -178,28 +200,31 @@ class SparseConvTester:
...
@@ -178,28 +200,31 @@ class SparseConvTester:
self
.
weight
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
K
,
*
self
.
ksize
,
C
]).
astype
(
dtype
)
self
.
weight
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
K
,
*
self
.
ksize
,
C
]).
astype
(
dtype
)
self
.
output
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
self
.
output
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
self
.
out_inds
.
shape
[
0
],
K
self
.
out_inds
.
shape
[
0
],
K
]).
astype
(
dtype
)
]).
astype
(
out_dtype
)
self
.
output_add
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
self
.
out_inds
.
shape
[
0
],
K
]).
astype
(
out_dtype
)
self
.
bias
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
self
.
bias
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
K
K
]).
astype
(
dtype
)
]).
astype
(
dtype
)
# self.bias[:] = 0
# self.scales[:] = 1
self
.
weight_ref
=
self
.
weight
.
transpose
(
1
,
2
,
3
,
0
,
4
)
self
.
weight_ref
=
self
.
weight
.
transpose
(
1
,
2
,
3
,
0
,
4
)
self
.
weight_ref
=
np
.
ascontiguousarray
(
self
.
weight_ref
).
reshape
(
-
1
,
K
,
C
)
self
.
weight_ref
=
np
.
ascontiguousarray
(
self
.
weight_ref
).
reshape
(
-
1
,
K
,
C
)
self
.
out_ref
,
self
.
din_ref
,
self
.
dw_ref
=
self
.
_get_ref_output
()
self
.
out_ref
,
self
.
din_ref
,
self
.
dw_ref
=
self
.
_get_ref_output
()
if
check_bias
:
self
.
out_ref
+=
self
.
bias
# relu
if
check_act
:
self
.
out_ref
=
np
.
maximum
(
self
.
out_ref
,
0
)
self
.
dw_ref
=
np
.
ascontiguousarray
(
self
.
dw_ref
.
transpose
(
1
,
0
,
2
).
reshape
(
K
,
*
self
.
ksize
,
C
))
self
.
dw_ref
=
np
.
ascontiguousarray
(
self
.
dw_ref
.
transpose
(
1
,
0
,
2
).
reshape
(
K
,
*
self
.
ksize
,
C
))
self
.
arch
=
tv
.
get_compute_capability
()
self
.
arch
=
tv
.
get_compute_capability
()
def
get_output_ref_spt
(
self
):
def
get_output_ref_spt
(
self
):
return
SparseConvTensor
(
torch
.
from_numpy
(
self
.
out_ref
).
cuda
(),
self
.
ref_out_inds
,
self
.
out_shape
,
self
.
bs
)
return
SparseConvTensor
(
torch
.
from_numpy
(
self
.
out_ref
).
cuda
(),
self
.
ref_out_inds
,
self
.
out_shape
,
self
.
bs
)
def
_get_ref_output
(
self
):
def
_get_ref_output
(
self
):
output_ref
=
np
.
zeros_like
(
self
.
output
,
dtype
=
np
.
float32
)
out_dtype
=
np
.
float32
if
self
.
dtype
==
np
.
int8
:
out_dtype
=
np
.
int32
output_ref
=
np
.
zeros_like
(
self
.
output
,
dtype
=
out_dtype
)
dinput_ref
=
np
.
zeros_like
(
self
.
inp
,
dtype
=
np
.
float32
)
dinput_ref
=
np
.
zeros_like
(
self
.
inp
,
dtype
=
np
.
float32
)
dw_ref
=
np
.
zeros_like
(
self
.
weight_ref
,
dw_ref
=
np
.
zeros_like
(
self
.
weight_ref
,
dtype
=
np
.
float32
)
# KV, K, C
dtype
=
np
.
float32
)
# KV, K, C
...
@@ -215,9 +240,14 @@ class SparseConvTester:
...
@@ -215,9 +240,14 @@ class SparseConvTester:
i_inds
=
self
.
indice_pairs_np
[
0
][
filter_offset
][:
nhot
]
i_inds
=
self
.
indice_pairs_np
[
0
][
filter_offset
][:
nhot
]
o_inds
=
self
.
indice_pairs_np
[
1
][
filter_offset
][:
nhot
]
o_inds
=
self
.
indice_pairs_np
[
1
][
filter_offset
][:
nhot
]
a
=
self
.
inp
[
i_inds
]
a
=
self
.
inp
[
i_inds
]
cc
=
a
.
astype
(
if
self
.
dtype
==
np
.
int8
:
np
.
float32
)
@
self
.
weight_ref
[
filter_offset
].
T
.
astype
(
cc
=
a
.
astype
(
np
.
float32
)
np
.
int32
)
@
self
.
weight_ref
[
filter_offset
].
T
.
astype
(
np
.
int32
)
else
:
cc
=
a
.
astype
(
np
.
float32
)
@
self
.
weight_ref
[
filter_offset
].
T
.
astype
(
np
.
float32
)
output_ref
[
o_inds
]
+=
cc
output_ref
[
o_inds
]
+=
cc
# we use random output as dout here
# we use random output as dout here
a
=
self
.
output
[
self
.
out_order
][
o_inds
]
a
=
self
.
output
[
self
.
out_order
][
o_inds
]
...
@@ -233,8 +263,25 @@ class SparseConvTester:
...
@@ -233,8 +263,25 @@ class SparseConvTester:
dw_res
=
out_gather
.
astype
(
dw_res
=
out_gather
.
astype
(
np
.
float32
).
T
@
inp_gather
.
astype
(
np
.
float32
)
np
.
float32
).
T
@
inp_gather
.
astype
(
np
.
float32
)
dw_ref
[
filter_offset
]
=
dw_res
dw_ref
[
filter_offset
]
=
dw_res
if
not
self
.
check_int8_infer
:
if
self
.
check_bias
:
output_ref
+=
self
.
bias
# relu
if
self
.
check_act
:
output_ref
=
np
.
maximum
(
output_ref
,
0
)
if
self
.
dtype
==
np
.
int8
:
if
self
.
dtype
==
np
.
int8
:
output_ref
=
np
.
clip
(
output_ref
,
-
127
,
127
)
if
self
.
check_int8_infer
:
rescaled
=
output_ref
.
astype
(
self
.
dtype_comp
)
*
self
.
scales
.
astype
(
self
.
dtype_comp
)
rescaled
+=
self
.
bias
.
astype
(
self
.
dtype_comp
)
rescaled
+=
self
.
output_add
[
self
.
out_order
].
astype
(
self
.
dtype_comp
)
*
self
.
output_add_scale
if
self
.
check_act
:
rescaled
=
np
.
maximum
(
rescaled
,
0
)
if
self
.
out_dtype
==
np
.
int8
:
output_ref
=
np
.
clip
(
np
.
round
(
rescaled
),
-
128
,
127
).
astype
(
np
.
int8
)
else
:
output_ref
=
rescaled
.
astype
(
self
.
out_dtype
)
else
:
output_ref
=
np
.
clip
(
output_ref
,
-
127
,
127
)
return
output_ref
,
dinput_ref
,
dw_ref
return
output_ref
,
dinput_ref
,
dw_ref
def
get_operands
(
self
,
op_type
:
ConvOpType
):
def
get_operands
(
self
,
op_type
:
ConvOpType
):
...
@@ -248,7 +295,7 @@ class SparseConvTester:
...
@@ -248,7 +295,7 @@ class SparseConvTester:
else
:
else
:
weight_tv
=
tv
.
from_numpy
(
self
.
weight
).
cuda
()
weight_tv
=
tv
.
from_numpy
(
self
.
weight
).
cuda
()
if
op_type
==
ConvOpType
.
kForward
:
if
op_type
==
ConvOpType
.
kForward
:
output_tv
=
zeros_func
(
list
(
self
.
output
.
shape
),
self
.
dtype
,
0
)
output_tv
=
zeros_func
(
list
(
self
.
output
.
shape
),
self
.
out_
dtype
,
0
)
else
:
else
:
output_tv
=
tv
.
from_numpy
(
self
.
output
).
cuda
()
output_tv
=
tv
.
from_numpy
(
self
.
output
).
cuda
()
return
inp_tv
,
weight_tv
,
output_tv
return
inp_tv
,
weight_tv
,
output_tv
...
@@ -280,26 +327,31 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -280,26 +327,31 @@ def _test_impgemm_conv_cuda(subm: bool):
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
shapes
=
[[
19
,
18
,
17
]]
shapes
=
[[
19
,
18
,
17
]]
batchsizes
=
[
1
]
batchsizes
=
[
1
]
dtypes
=
[
np
.
float32
,
np
.
float
16
]
#
dtypes = [
(
np.float32, np.float
32), (np.float16, np.float16)
]
# dtypes = [np.float16]
# dtypes = [np.float16]
dtypes
=
[(
np
.
int8
,
np
.
int8
),
(
np
.
int8
,
np
.
float32
),
(
np
.
int8
,
np
.
float16
)]
dtypes
=
[(
np
.
int8
,
np
.
int8
)]
# dtypes = [(np.float16, np.float16)]
# dtypes = [np.int8]
test_case
=
TestCase
()
test_case
=
TestCase
()
# in_channels = [32]
# in_channels = [32]
# out_channels = [32, 48, 64]
# out_channels = [32, 48, 64]
in_channels
=
[
32
,
47
]
in_channels
=
[
32
,
47
]
out_channels
=
[
32
,
48
,
62
]
out_channels
=
[
32
,
48
,
62
]
#
in_channels = [
32
]
in_channels
=
[
16
]
#
out_channels = [
32
]
out_channels
=
[
16
]
multiple_base
=
16
multiple_base
=
16
if
subm
:
if
subm
:
ksizes
=
[
3
,
(
3
,
3
,
5
),
(
3
,
5
,
5
),
5
]
# ksizes = [3, (3, 3, 5), (3, 5, 5), 5]
ksizes
=
[
3
]
strides
=
[
1
]
strides
=
[
1
]
paddings
=
[
0
]
paddings
=
[
0
]
dilations
=
[
1
]
dilations
=
[
1
]
else
:
else
:
ksizes
=
[
2
,
3
,
(
3
,
3
,
4
),
4
,
(
4
,
5
,
5
),
5
]
ksizes
=
[
2
,
3
,
(
3
,
3
,
4
),
4
,
(
4
,
5
,
5
),
5
]
ksizes
=
[
2
,
3
]
strides
=
[
1
,
2
,
3
]
strides
=
[
1
,
2
,
3
]
paddings
=
[
0
,
1
]
paddings
=
[
0
,
1
]
dilations
=
[
1
,
2
]
dilations
=
[
1
,
2
]
...
@@ -310,9 +362,16 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -310,9 +362,16 @@ def _test_impgemm_conv_cuda(subm: bool):
]
]
arch
=
torch
.
cuda
.
get_device_capability
()
arch
=
torch
.
cuda
.
get_device_capability
()
force_nvrtc
=
False
force_nvrtc
=
False
for
shape
,
bs
,
C
,
K
,
k
,
s
,
p
,
d
,
algo
,
dtype
in
tqdm
.
tqdm
(
params_grid
(
for
shape
,
bs
,
C
,
K
,
k
,
s
,
p
,
d
,
algo
,
dtype
_outdtype
in
tqdm
.
tqdm
(
params_grid
(
shapes
,
batchsizes
,
in_channels
,
out_channels
,
ksizes
,
shapes
,
batchsizes
,
in_channels
,
out_channels
,
ksizes
,
strides
,
paddings
,
dilations
,
algos
,
dtypes
)):
strides
,
paddings
,
dilations
,
algos
,
dtypes
)):
dtype
,
out_dtype
=
dtype_outdtype
if
(
C
%
16
!=
0
or
K
%
16
!=
0
)
and
dtype
==
np
.
int8
:
continue
dcomp
=
np
.
float32
check_int8_infer
=
True
if
dtype
!=
np
.
int8
:
check_int8_infer
=
False
shape_prod
=
np
.
prod
(
shape
)
shape_prod
=
np
.
prod
(
shape
)
num_batch
=
np
.
random
.
randint
(
int
(
0.2
*
shape_prod
),
int
(
0.7
*
shape_prod
))
num_batch
=
np
.
random
.
randint
(
int
(
0.2
*
shape_prod
),
int
(
0.7
*
shape_prod
))
# C = np.random.randint(int(0.3 * C), int(0.7 * C))
# C = np.random.randint(int(0.3 * C), int(0.7 * C))
...
@@ -320,32 +379,51 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -320,32 +379,51 @@ def _test_impgemm_conv_cuda(subm: bool):
multipler
=
max
(
C
,
K
)
/
multiple_base
multipler
=
max
(
C
,
K
)
/
multiple_base
multipler
=
max
(
multipler
,
1.0
)
multipler
=
max
(
multipler
,
1.0
)
# print(num_batch)
# print(num_batch)
tester
=
SparseConvTester
(
algo
,
subm
,
shape
,
bs
,
dtype
,
num_batch
,
K
,
C
,
k
,
s
,
p
,
d
,
check_bias
=
True
,
check_act
=
True
)
tester
=
SparseConvTester
(
algo
,
subm
,
shape
,
bs
,
dtype
,
out_dtype
,
num_batch
,
K
,
C
,
k
,
s
,
p
,
d
,
check_bias
=
True
,
check_act
=
True
,
check_int8_infer
=
check_int8_infer
,
dtype_comp
=
np
.
float32
)
enable_dy_mask
=
tester
.
kv
>
32
output_add_cuda
=
tv
.
from_numpy
(
tester
.
output_add
).
cuda
()
bias
=
None
bias
=
None
scales
=
None
act
=
tv
.
gemm
.
Activation
.
None_
act
=
tv
.
gemm
.
Activation
.
None_
if
tester
.
check_bias
:
if
tester
.
check_bias
:
bias
=
tv
.
from_numpy
(
tester
.
bias
).
cuda
()
if
check_int8_infer
:
bias
=
tv
.
from_numpy
(
tester
.
bias
.
astype
(
dcomp
)).
cuda
()
else
:
bias
=
tv
.
from_numpy
(
tester
.
bias
).
cuda
()
if
check_int8_infer
:
scales
=
tv
.
from_numpy
(
tester
.
scales
.
astype
(
dcomp
)).
cuda
()
atol
,
rtol
=
dtype_to_tol
[
dtype
]
atol
,
rtol
=
dtype_to_tol
[
dtype
]
mask_width_to_mask_out_fwd
:
Dict
[
int
,
torch
.
Tensor
]
=
{}
mask_width_to_mask_out_fwd
:
Dict
[
int
,
torch
.
Tensor
]
=
{}
mask_width_to_mask_out_bwd
:
Dict
[
int
,
torch
.
Tensor
]
=
{}
mask_width_to_mask_out_bwd
:
Dict
[
int
,
torch
.
Tensor
]
=
{}
op_types
=
[
ConvOpType
.
kForward
,
ConvOpType
.
kBackwardInput
]
op_types
=
[
ConvOpType
.
kForward
,
ConvOpType
.
kBackwardInput
]
spk
=
1
spk
=
1
for
op_type
in
op_types
:
for
op_type
in
op_types
:
if
tester
.
dtype
==
np
.
int8
and
op_type
!=
ConvOpType
.
kForward
:
continue
inp_tv
,
weight_tv
,
output_tv
=
tester
.
get_operands
(
op_type
)
inp_tv
,
weight_tv
,
output_tv
=
tester
.
get_operands
(
op_type
)
if
SPCONV_CPP_GEMM
:
if
SPCONV_CPP_GEMM
:
avail_desps
=
CONV_CPP
.
get_all_available
(
inp_tv
,
weight_tv
,
output_tv
,
avail_desps
=
CONV_CPP
.
get_all_available
(
inp_tv
,
weight_tv
,
output_tv
,
NHWC
.
layout_type
.
value
,
NHWC
.
layout_type
.
value
,
NHWC
.
layout_type
.
value
,
NHWC
.
layout_type
.
value
,
NHWC
.
layout_type
.
value
,
NHWC
.
interleave
,
NHWC
.
interleave
,
NHWC
.
interleave
,
arch
,
op_type
.
value
,
-
1
,
True
,
False
,
NHWC
.
layout_type
.
value
,
NHWC
.
interleave
,
NHWC
.
interleave
,
NHWC
.
interleave
,
arch
,
op_type
.
value
,
-
1
,
True
,
False
,
use_tf32
=
SPCONV_ALLOW_TF32
)
use_tf32
=
SPCONV_ALLOW_TF32
,
bias
=
bias
if
bias
is
not
None
else
tv
.
Tensor
(),
scale
=
scales
if
scales
is
not
None
else
tv
.
Tensor
())
else
:
else
:
avail_desps
=
CONV
.
get_all_available
(
inp_tv
,
weight_tv
,
output_tv
,
NHWC
,
NHWC
,
NHWC
,
arch
,
op_type
,
-
1
,
avail_desps
=
CONV
.
get_all_available
(
inp_tv
,
weight_tv
,
output_tv
,
NHWC
,
NHWC
,
NHWC
,
arch
,
op_type
,
-
1
,
use_tf32
=
SPCONV_ALLOW_TF32
)
use_tf32
=
SPCONV_ALLOW_TF32
,
bias
=
bias
if
bias
is
not
None
else
tv
.
Tensor
(),
scale
=
scales
if
scales
is
not
None
else
tv
.
Tensor
())
if
op_type
==
ConvOpType
.
kForward
and
tester
.
check_act
:
if
op_type
==
ConvOpType
.
kForward
and
tester
.
check_act
:
act
=
tv
.
gemm
.
Activation
.
ReLU
act
=
tv
.
gemm
.
Activation
.
ReLU
else
:
else
:
act
=
tv
.
gemm
.
Activation
.
None_
act
=
tv
.
gemm
.
Activation
.
None_
assert
avail_desps
assert
avail_desps
for
desp
in
avail_desps
:
for
desp
in
avail_desps
:
dcomp
=
get_npdtype_from_tvdtype
(
desp
.
dcomp
)
if
enable_dy_mask
and
not
desp
.
dynamic_mask
:
continue
if
tester
.
check_int8_infer
and
not
desp
.
is_int8_inference
:
continue
if
not
subm
:
if
not
subm
:
if
op_type
==
ConvOpType
.
kForward
:
if
op_type
==
ConvOpType
.
kForward
:
output_tv
.
zero_
()
output_tv
.
zero_
()
...
@@ -353,11 +431,13 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -353,11 +431,13 @@ def _test_impgemm_conv_cuda(subm: bool):
inp_tv
.
zero_
()
inp_tv
.
zero_
()
# this algo must success
# this algo must success
mask_width
=
desp
.
tile_shape
[
0
]
mask_width
=
desp
.
tile_shape
[
0
]
alpha
=
1.0
if
tester
.
check_int8_infer
:
alpha
=
tester
.
output_scale
# if mask_width != 32:
# if mask_width != 32:
# continue
# continue
if
mask_width
not
in
mask_width_to_mask_out_fwd
:
if
mask_width
not
in
mask_width_to_mask_out_fwd
:
mask_width_to_mask_out_fwd
[
mask_width
]
=
torch
.
zeros
([
2
,
tester
.
mask_int_count
*
div_up
(
tester
.
out_inds
.
shape
[
0
],
mask_width
)],
mask_width_to_mask_out_fwd
[
mask_width
]
=
torch
.
zeros
([
2
,
div_up
(
tester
.
out_inds
.
shape
[
0
],
mask_width
)
,
tester
.
mask_int_count
],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
tester
.
device
)
device
=
tester
.
device
)
mask_output_fwd
=
mask_width_to_mask_out_fwd
[
mask_width
]
mask_output_fwd
=
mask_width_to_mask_out_fwd
[
mask_width
]
...
@@ -365,6 +445,11 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -365,6 +445,11 @@ def _test_impgemm_conv_cuda(subm: bool):
bias_cur
=
bias
bias_cur
=
bias
if
op_type
!=
ConvOpType
.
kForward
:
if
op_type
!=
ConvOpType
.
kForward
:
bias_cur
=
None
bias_cur
=
None
output_add_cur_tv
=
tv
.
Tensor
()
output_add_cur
=
None
if
is_fwd
and
tester
.
check_int8_infer
:
output_add_cur
=
output_add_cuda
output_add_cur_tv
=
output_add_cur
if
subm
:
if
subm
:
if
desp
.
op_type
.
value
==
ConvOpType
.
kForward
.
value
:
if
desp
.
op_type
.
value
==
ConvOpType
.
kForward
.
value
:
indice_pairs
=
tester
.
pair_fwd
indice_pairs
=
tester
.
pair_fwd
...
@@ -376,10 +461,13 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -376,10 +461,13 @@ def _test_impgemm_conv_cuda(subm: bool):
# print([bin(x.item()) for x in masks])
# print([bin(x.item()) for x in masks])
for
j
in
range
(
tester
.
num_split
):
for
j
in
range
(
tester
.
num_split
):
beta
=
1
if
j
>
0
else
0
beta
=
1
if
j
>
0
else
0
if
bias_cur
is
not
None
:
if
bias_cur
is
not
None
and
not
tester
.
check_int8_infer
:
# this beta is used for C-beta (use C as bias, not standalone bias)
beta
=
1
beta
=
1
if
j
>
0
:
if
j
>
0
:
bias_cur
=
None
bias_cur
=
None
if
output_add_cur
is
not
None
and
tester
.
check_int8_infer
:
beta
=
tester
.
output_add_scale
mask_filter
=
tester
.
masks
[
j
].
item
()
mask_filter
=
tester
.
masks
[
j
].
item
()
reverse_mask
=
False
reverse_mask
=
False
if
desp
.
op_type
.
value
==
ConvOpType
.
kBackwardWeight
.
value
:
if
desp
.
op_type
.
value
==
ConvOpType
.
kBackwardWeight
.
value
:
...
@@ -396,7 +484,6 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -396,7 +484,6 @@ def _test_impgemm_conv_cuda(subm: bool):
# desp.is_nvrtc = True
# desp.is_nvrtc = True
# print(force_nvrtc, desp.op_type, op_type)
# print(force_nvrtc, desp.op_type, op_type)
if
SPCONV_CPP_GEMM
:
if
SPCONV_CPP_GEMM
:
CONV_CPP
.
run_with_tuned_result
(
CONV_CPP
.
run_with_tuned_result
(
ConvTuneResult
(
desp
,
tester
.
arch
,
spk
),
ConvTuneResult
(
desp
,
tester
.
arch
,
spk
),
desp
.
op_type
.
value
,
desp
.
op_type
.
value
,
...
@@ -410,13 +497,14 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -410,13 +497,14 @@ def _test_impgemm_conv_cuda(subm: bool):
reverse_mask
,
reverse_mask
,
mask_filter
=
mask_filter
,
mask_filter
=
mask_filter
,
mask_width
=
mask_width
,
mask_width
=
mask_width
,
alpha
=
alpha
,
beta
=
beta
,
beta
=
beta
,
verbose
=
False
,
verbose
=
False
,
force_nvrtc
=
force_nvrtc
,
force_nvrtc
=
force_nvrtc
,
bias
=
bias_cur
if
is_fwd
and
bias_cur
is
not
None
else
tv
.
Tensor
(),
bias
=
bias_cur
if
is_fwd
and
bias_cur
is
not
None
else
tv
.
Tensor
(),
scale
=
scales
if
is_fwd
and
scales
is
not
None
else
tv
.
Tensor
(),
act_type
=
act
,
act_type
=
act
,
mask_int_count
=
tester
.
mask_int_count
,
output_add
=
output_add_cur_tv
)
)
else
:
else
:
CONV
.
run_with_tuned_result
(
CONV
.
run_with_tuned_result
(
BestConvAlgoByProfile
(
desp
,
tester
.
arch
,
spk
),
BestConvAlgoByProfile
(
desp
,
tester
.
arch
,
spk
),
...
@@ -431,12 +519,14 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -431,12 +519,14 @@ def _test_impgemm_conv_cuda(subm: bool):
reverse_mask
,
reverse_mask
,
mask_filter
=
mask_filter
,
mask_filter
=
mask_filter
,
mask_width
=
mask_width
,
mask_width
=
mask_width
,
alpha
=
alpha
,
beta
=
beta
,
beta
=
beta
,
verbose
=
False
,
verbose
=
False
,
force_nvrtc
=
force_nvrtc
,
force_nvrtc
=
force_nvrtc
,
bias
=
bias_cur
if
is_fwd
else
None
,
bias
=
bias_cur
if
is_fwd
else
None
,
scale
=
scales
if
is_fwd
else
None
,
act_type
=
act
,
act_type
=
act
,
mask_int_count
=
tester
.
mask_int_count
output_add
=
output_add_cur
,
)
)
else
:
else
:
...
@@ -465,10 +555,13 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -465,10 +555,13 @@ def _test_impgemm_conv_cuda(subm: bool):
for
j
in
range
(
tester
.
num_split
):
for
j
in
range
(
tester
.
num_split
):
# beta = 1 if j == 1 else 0
# beta = 1 if j == 1 else 0
beta
=
1
if
j
>
0
else
0
beta
=
1
if
j
>
0
else
0
if
bias_cur
is
not
None
:
if
bias_cur
is
not
None
and
not
tester
.
check_int8_infer
:
# this beta is used for C-beta (use C as bias, not standalone bias)
beta
=
1
beta
=
1
if
j
>
0
:
if
j
>
0
:
bias_cur
=
None
bias_cur
=
None
if
output_add_cur
is
not
None
and
tester
.
check_int8_infer
:
beta
=
tester
.
output_add_scale
mask_filter
=
tester
.
masks
[
j
].
item
()
mask_filter
=
tester
.
masks
[
j
].
item
()
reverse_mask
=
False
reverse_mask
=
False
if
desp
.
op_type
.
value
==
ConvOpType
.
kBackwardWeight
.
value
:
if
desp
.
op_type
.
value
==
ConvOpType
.
kBackwardWeight
.
value
:
...
@@ -476,7 +569,6 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -476,7 +569,6 @@ def _test_impgemm_conv_cuda(subm: bool):
else
:
else
:
mask_op
=
mask_ops
[
j
]
mask_op
=
mask_ops
[
j
]
if
SPCONV_CPP_GEMM
:
if
SPCONV_CPP_GEMM
:
CONV_CPP
.
run_with_tuned_result
(
CONV_CPP
.
run_with_tuned_result
(
ConvTuneResult
(
desp
,
tester
.
arch
,
spk
),
ConvTuneResult
(
desp
,
tester
.
arch
,
spk
),
desp
.
op_type
.
value
,
desp
.
op_type
.
value
,
...
@@ -493,9 +585,10 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -493,9 +585,10 @@ def _test_impgemm_conv_cuda(subm: bool):
beta
=
beta
,
beta
=
beta
,
verbose
=
False
,
verbose
=
False
,
force_nvrtc
=
force_nvrtc
,
force_nvrtc
=
force_nvrtc
,
bias
=
bias
if
is_fwd
and
bias
is
not
None
else
tv
.
Tensor
(),
bias
=
bias_cur
if
is_fwd
and
bias_cur
is
not
None
else
tv
.
Tensor
(),
scale
=
scales
if
is_fwd
and
scales
is
not
None
else
tv
.
Tensor
(),
act_type
=
act
,
act_type
=
act
,
mask_int_count
=
tester
.
mask_int_count
,
output_add
=
output_add_cur_tv
,
)
)
else
:
else
:
CONV
.
run_with_tuned_result
(
CONV
.
run_with_tuned_result
(
...
@@ -514,9 +607,10 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -514,9 +607,10 @@ def _test_impgemm_conv_cuda(subm: bool):
beta
=
beta
,
beta
=
beta
,
verbose
=
False
,
verbose
=
False
,
force_nvrtc
=
force_nvrtc
,
force_nvrtc
=
force_nvrtc
,
bias
=
bias
if
is_fwd
else
None
,
bias
=
bias_cur
if
is_fwd
else
None
,
scale
=
scales
if
is_fwd
else
None
,
act_type
=
act
,
act_type
=
act
,
mask_int_count
=
tester
.
mask_int_count
,
output_add
=
output_add_cur
,
)
)
out_ref
=
tester
.
out_ref
out_ref
=
tester
.
out_ref
...
@@ -526,6 +620,8 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -526,6 +620,8 @@ def _test_impgemm_conv_cuda(subm: bool):
out_my
=
output_tv
.
cpu
().
numpy
()
out_my
=
output_tv
.
cpu
().
numpy
()
out_my
=
out_my
[
tester
.
out_order
]
out_my
=
out_my
[
tester
.
out_order
]
if
dtype
!=
np
.
float16
:
if
dtype
!=
np
.
float16
:
if
dtype
==
np
.
int8
:
print
(
"max int8 diff"
,
np
.
abs
(
out_ref
-
out_my
).
max
())
test_case
.
assertAllClose
(
out_ref
,
out_my
,
atol
=
atol
,
rtol
=
rtol
)
test_case
.
assertAllClose
(
out_ref
,
out_my
,
atol
=
atol
,
rtol
=
rtol
)
else
:
else
:
error_norm
=
np
.
linalg
.
norm
(
out_ref
.
reshape
(
-
1
)
-
out_my
.
reshape
(
-
1
))
error_norm
=
np
.
linalg
.
norm
(
out_ref
.
reshape
(
-
1
)
-
out_my
.
reshape
(
-
1
))
...
@@ -539,86 +635,86 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -539,86 +635,86 @@ def _test_impgemm_conv_cuda(subm: bool):
else
:
else
:
error_norm
=
np
.
linalg
.
norm
(
din_ref
.
reshape
(
-
1
)
-
din_my
.
reshape
(
-
1
))
error_norm
=
np
.
linalg
.
norm
(
din_ref
.
reshape
(
-
1
)
-
din_my
.
reshape
(
-
1
))
assert
error_norm
<
10
*
multipler
,
f
"
{
desp
}
,
{
error_norm
}
,
{
k
}
,
{
s
}
,
{
p
}
,
{
d
}
"
assert
error_norm
<
10
*
multipler
,
f
"
{
desp
}
,
{
error_norm
}
,
{
k
}
,
{
s
}
,
{
p
}
,
{
d
}
"
inp_tv
,
weight_tv
,
output_tv
=
tester
.
get_operands
(
ConvOpType
.
kBackwardWeight
)
if
not
tester
.
check_int8_infer
:
for
spk
in
[
1
,
4
,
16
,
64
]:
inp_tv
,
weight_tv
,
output_tv
=
tester
.
get_operands
(
ConvOpType
.
kBackwardWeight
)
for
mask_width
,
mask_output
in
mask_width_to_mask_out_fwd
.
items
():
for
spk
in
[
1
,
4
,
16
,
64
]:
if
SPCONV_CPP_GEMM
:
for
mask_width
,
mask_output
in
mask_width_to_mask_out_fwd
.
items
():
avail_desps
=
CONV_CPP
.
get_all_available
(
inp_tv
,
weight_tv
,
output_tv
,
if
SPCONV_CPP_GEMM
:
NHWC
.
layout_type
.
value
,
NHWC
.
layout_type
.
value
,
avail_desps
=
CONV_CPP
.
get_all_available
(
inp_tv
,
weight_tv
,
output_tv
,
NHWC
.
layout_type
.
value
,
NHWC
.
interleave
,
NHWC
.
interleave
,
NHWC
.
interleave
,
arch
,
NHWC
.
layout_type
.
value
,
NHWC
.
layout_type
.
value
,
ConvOpType
.
kBackwardWeight
.
value
,
mask_width
,
True
,
False
,
NHWC
.
layout_type
.
value
,
NHWC
.
interleave
,
NHWC
.
interleave
,
NHWC
.
interleave
,
arch
,
use_tf32
=
SPCONV_ALLOW_TF32
)
ConvOpType
.
kBackwardWeight
.
value
,
mask_width
,
True
,
False
,
else
:
use_tf32
=
SPCONV_ALLOW_TF32
)
avail_desps
=
CONV
.
get_all_available
(
inp_tv
,
weight_tv
,
output_tv
,
NHWC
,
NHWC
,
NHWC
,
arch
,
ConvOpType
.
kBackwardWeight
,
mask_width
,
use_tf32
=
SPCONV_ALLOW_TF32
)
for
desp
in
avail_desps
:
weight_tv
.
zero_
()
if
subm
:
indice_pairs
=
tester
.
pair_fwd
for
j
in
range
(
tester
.
num_split
):
beta
=
0
mask_filter
=
tester
.
masks
[
j
].
item
()
mask_op
=
mask_output
[
j
]
mask_op_tv
=
torch_tensor_to_tv
(
mask_op
,
dtype
=
tv
.
uint32
)
# mask_op_np = mask_op_tv.cpu().numpy()
# bit_ref = np.bitwise_or.reduce(mask_op_np, axis=0)
# bit_my = mask_filter
CONV
.
run_with_tuned_result
(
BestConvAlgoByProfile
(
desp
,
tester
.
arch
,
spk
),
desp
.
op_type
.
value
,
inp_tv
,
weight_tv
,
output_tv
,
mask_op_tv
,
torch_tensor_to_tv
(
tester
.
mask_argsort_fwd_splits
[
j
]),
tv
.
Tensor
(),
torch_tensor_to_tv
(
indice_pairs
),
reverse_mask
=
False
,
mask_filter
=
mask_filter
,
mask_width
=
mask_width
,
beta
=
beta
,
verbose
=
False
,
mask_int_count
=
tester
.
mask_int_count
,
)
else
:
else
:
indice_pairs
=
tester
.
pair_fwd
# inp -> out
avail_desps
=
CONV
.
get_all_available
(
inp_tv
,
weight_tv
,
output_tv
,
NHWC
,
NHWC
,
NHWC
,
arch
,
ConvOpType
.
kBackwardWeight
,
mask_width
,
mask_ops
=
tester
.
pair_mask_fwd_splits
use_tf32
=
SPCONV_ALLOW_TF32
)
mask_argsorts
=
tester
.
mask_argsort_fwd_splits
for
desp
in
avail_desps
:
for
j
in
range
(
tester
.
num_split
):
if
enable_dy_mask
and
not
desp
.
dynamic_mask
:
# beta = 1 if j == 1 else 0
continue
beta
=
0
weight_tv
.
zero_
()
mask_filter
=
tester
.
masks
[
j
].
item
()
if
subm
:
reverse_mask
=
False
indice_pairs
=
tester
.
pair_fwd
mask_op
=
mask_output
[
j
]
for
j
in
range
(
tester
.
num_split
):
beta
=
0
mask_filter
=
tester
.
masks
[
j
].
item
()
mask_op
=
mask_output
[
j
]
mask_op_tv
=
torch_tensor_to_tv
(
mask_op
,
dtype
=
tv
.
uint32
)
# mask_op_np = mask_op_tv.cpu().numpy()
# bit_ref = np.bitwise_or.reduce(mask_op_np, axis=0)
# bit_my = mask_filter
CONV
.
run_with_tuned_result
(
BestConvAlgoByProfile
(
desp
,
tester
.
arch
,
spk
),
desp
.
op_type
.
value
,
inp_tv
,
weight_tv
,
output_tv
,
mask_op_tv
,
torch_tensor_to_tv
(
tester
.
mask_argsort_fwd_splits
[
j
]),
tv
.
Tensor
(),
torch_tensor_to_tv
(
indice_pairs
),
reverse_mask
=
False
,
mask_filter
=
mask_filter
,
mask_width
=
mask_width
,
beta
=
beta
,
verbose
=
False
,
)
else
:
indice_pairs
=
tester
.
pair_fwd
# inp -> out
mask_ops
=
tester
.
pair_mask_fwd_splits
mask_argsorts
=
tester
.
mask_argsort_fwd_splits
for
j
in
range
(
tester
.
num_split
):
# beta = 1 if j == 1 else 0
beta
=
0
mask_filter
=
tester
.
masks
[
j
].
item
()
reverse_mask
=
False
mask_op
=
mask_output
[
j
]
CONV
.
run_with_tuned_result
(
CONV
.
run_with_tuned_result
(
BestConvAlgoByProfile
(
desp
,
tester
.
arch
,
spk
),
BestConvAlgoByProfile
(
desp
,
tester
.
arch
,
spk
),
desp
.
op_type
.
value
,
desp
.
op_type
.
value
,
inp_tv
,
inp_tv
,
weight_tv
,
weight_tv
,
output_tv
,
output_tv
,
torch_tensor_to_tv
(
mask_op
,
dtype
=
tv
.
uint32
),
torch_tensor_to_tv
(
mask_op
,
dtype
=
tv
.
uint32
),
torch_tensor_to_tv
(
mask_argsorts
[
j
]),
torch_tensor_to_tv
(
mask_argsorts
[
j
]),
torch_tensor_to_tv
(
mask_output
[
j
],
dtype
=
tv
.
uint32
),
torch_tensor_to_tv
(
mask_output
[
j
],
dtype
=
tv
.
uint32
),
torch_tensor_to_tv
(
indice_pairs
),
torch_tensor_to_tv
(
indice_pairs
),
reverse_mask
,
reverse_mask
,
mask_filter
=
mask_filter
,
mask_filter
=
mask_filter
,
mask_width
=
mask_width
,
mask_width
=
mask_width
,
beta
=
beta
,
beta
=
beta
,
verbose
=
False
,
verbose
=
False
,
mask_int_count
=
tester
.
mask_int_count
,
)
)
dw_ref
=
tester
.
dw_ref
dw_ref
=
tester
.
dw_ref
dw_my
=
weight_tv
.
cpu
().
numpy
()
dw_my
=
weight_tv
.
cpu
().
numpy
()
if
dtype
!=
np
.
float16
:
if
dtype
!=
np
.
float16
:
test_case
.
assertAllClose
(
dw_ref
,
dw_my
,
atol
=
atol
,
rtol
=
rtol
)
# print(desp, spk, K, C, mask_width, algo)
else
:
test_case
.
assertAllClose
(
dw_ref
,
dw_my
,
atol
=
atol
,
rtol
=
rtol
)
error_norm
=
np
.
linalg
.
norm
(
dw_ref
.
reshape
(
-
1
)
-
dw_my
.
reshape
(
-
1
))
else
:
# print(desp, error_norm)
error_norm
=
np
.
linalg
.
norm
(
dw_ref
.
reshape
(
-
1
)
-
dw_my
.
reshape
(
-
1
))
if
(
error_norm
>
5
):
# print(desp, error_norm)
print
(
f
"
{
desp
}
, Error=
{
error_norm
}
,
{
spk
}
"
)
if
(
error_norm
>
5
):
assert
error_norm
<
10
*
multipler
print
(
f
"
{
desp
}
, Error=
{
error_norm
}
,
{
spk
}
"
)
assert
error_norm
<
10
*
multipler
def
_test_native_conv_cuda
(
subm
:
bool
):
def
_test_native_conv_cuda
(
subm
:
bool
):
ndim
=
3
ndim
=
3
...
@@ -924,7 +1020,7 @@ def _test_native_conv_cuda(subm: bool):
...
@@ -924,7 +1020,7 @@ def _test_native_conv_cuda(subm: bool):
def
test_all_algo_unit
():
def
test_all_algo_unit
():
# for i in range(5):
# for i in range(5):
_test_impgemm_conv_cuda
(
True
)
#
_test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda
(
False
)
_test_impgemm_conv_cuda
(
False
)
# _test_native_conv_cuda(True)
# _test_native_conv_cuda(True)
# _test_native_conv_cuda(False)
# _test_native_conv_cuda(False)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment