Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
f8c25027
Commit
f8c25027
authored
Sep 06, 2022
by
yan.yan
Browse files
add act
parent
99c8a0bd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
242 additions
and
98 deletions
+242
-98
spconv/pytorch/conv.py
spconv/pytorch/conv.py
+60
-19
spconv/pytorch/functional.py
spconv/pytorch/functional.py
+126
-73
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+56
-6
No files found.
spconv/pytorch/conv.py
View file @
f8c25027
...
@@ -35,6 +35,7 @@ from spconv.pytorch.modules import SparseModule
...
@@ -35,6 +35,7 @@ from spconv.pytorch.modules import SparseModule
from
spconv.constants
import
SAVED_WEIGHT_LAYOUT
,
ALL_WEIGHT_IS_KRSC
,
SPCONV_DEBUG_WEIGHT
from
spconv.constants
import
SAVED_WEIGHT_LAYOUT
,
ALL_WEIGHT_IS_KRSC
,
SPCONV_DEBUG_WEIGHT
from
spconv.utils
import
nullcontext
from
spconv.utils
import
nullcontext
from
torch.nn.init
import
calculate_gain
from
torch.nn.init
import
calculate_gain
from
cumm
import
tensorview
as
tv
FILTER_HWIO
=
False
FILTER_HWIO
=
False
...
@@ -65,6 +66,9 @@ class SparseConvolution(SparseModule):
...
@@ -65,6 +66,9 @@ class SparseConvolution(SparseModule):
algo
:
Optional
[
ConvAlgo
]
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
fp32_accum
:
Optional
[
bool
]
=
None
,
fp32_accum
:
Optional
[
bool
]
=
None
,
record_voxel_count
:
bool
=
False
,
record_voxel_count
:
bool
=
False
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
act_alpha
:
float
=
0
,
act_beta
:
float
=
0
,
name
=
None
):
name
=
None
):
super
(
SparseConvolution
,
self
).
__init__
(
name
=
name
)
super
(
SparseConvolution
,
self
).
__init__
(
name
=
name
)
assert
groups
==
1
,
"don't support groups for now"
assert
groups
==
1
,
"don't support groups for now"
...
@@ -131,6 +135,12 @@ class SparseConvolution(SparseModule):
...
@@ -131,6 +135,12 @@ class SparseConvolution(SparseModule):
self
.
bias
=
Parameter
(
torch
.
Tensor
(
out_channels
))
self
.
bias
=
Parameter
(
torch
.
Tensor
(
out_channels
))
else
:
else
:
self
.
register_parameter
(
'bias'
,
None
)
self
.
register_parameter
(
'bias'
,
None
)
self
.
act_type
=
act_type
self
.
act_alpha
=
act_alpha
self
.
act_beta
=
act_beta
if
self
.
conv1x1
:
assert
act_type
==
tv
.
gemm
.
Activation
.
None_
,
"conv1x1 don't support fused act"
self
.
reset_parameters
()
self
.
reset_parameters
()
if
hasattr
(
self
,
"_register_load_state_dict_pre_hook"
):
if
hasattr
(
self
,
"_register_load_state_dict_pre_hook"
):
self
.
_register_load_state_dict_pre_hook
(
self
.
_register_load_state_dict_pre_hook
(
...
@@ -139,8 +149,7 @@ class SparseConvolution(SparseModule):
...
@@ -139,8 +149,7 @@ class SparseConvolution(SparseModule):
def
get_max_num_voxels
(
self
)
->
Optional
[
torch
.
Tensor
]:
def
get_max_num_voxels
(
self
)
->
Optional
[
torch
.
Tensor
]:
if
hasattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
):
if
hasattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
):
return
getattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
)
return
getattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
)
return
None
return
None
def
_load_weight_different_layout
(
self
,
state_dict
,
prefix
,
local_metadata
,
def
_load_weight_different_layout
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
strict
,
missing_keys
,
unexpected_keys
,
...
@@ -255,6 +264,12 @@ class SparseConvolution(SparseModule):
...
@@ -255,6 +264,12 @@ class SparseConvolution(SparseModule):
indices
=
input
.
indices
indices
=
input
.
indices
spatial_shape
=
input
.
spatial_shape
spatial_shape
=
input
.
spatial_shape
batch_size
=
input
.
batch_size
batch_size
=
input
.
batch_size
bias_for_training
=
self
.
bias
if
self
.
training
else
None
bias_for_infer
=
self
.
bias
if
not
self
.
training
else
None
if
self
.
training
:
msg
=
"act don't support backward, only used in inference"
assert
self
.
act_type
==
tv
.
gemm
.
Activation
.
None_
,
msg
if
not
self
.
subm
:
if
not
self
.
subm
:
if
self
.
transposed
:
if
self
.
transposed
:
out_spatial_shape
=
ops
.
get_deconv_output_size
(
out_spatial_shape
=
ops
.
get_deconv_output_size
(
...
@@ -393,19 +408,43 @@ class SparseConvolution(SparseModule):
...
@@ -393,19 +408,43 @@ class SparseConvolution(SparseModule):
indice_pairs_calc
=
indice_pairs
.
to
(
features
.
device
)
indice_pairs_calc
=
indice_pairs
.
to
(
features
.
device
)
if
self
.
subm
:
if
self
.
subm
:
out_features
=
Fsp
.
indice_subm_conv
(
out_features
=
Fsp
.
indice_subm_conv
(
features
,
self
.
weight
,
indice_pairs_calc
,
features
,
indice_pair_num
,
outids
.
shape
[
0
],
algo
,
input
.
_timer
)
self
.
weight
,
indice_pairs_calc
,
indice_pair_num
,
outids
.
shape
[
0
],
algo
,
input
.
_timer
,
bias_for_infer
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
)
else
:
else
:
if
self
.
inverse
:
if
self
.
inverse
:
out_features
=
Fsp
.
indice_inverse_conv
(
out_features
=
Fsp
.
indice_inverse_conv
(
features
,
self
.
weight
,
indice_pairs_calc
,
features
,
indice_pair_num
,
outids
.
shape
[
0
],
algo
)
self
.
weight
,
indice_pairs_calc
,
indice_pair_num
,
outids
.
shape
[
0
],
algo
,
bias_for_infer
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
)
else
:
else
:
out_features
=
Fsp
.
indice_conv
(
features
,
self
.
weight
,
out_features
=
Fsp
.
indice_conv
(
indice_pairs_calc
,
features
,
indice_pair_num
,
self
.
weight
,
outids
.
shape
[
0
],
algo
,
indice_pairs_calc
,
input
.
_timer
)
indice_pair_num
,
outids
.
shape
[
0
],
algo
,
input
.
_timer
,
bias_for_infer
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
)
else
:
else
:
datas
=
input
.
find_indice_pair
(
self
.
indice_key
)
datas
=
input
.
find_indice_pair
(
self
.
indice_key
)
...
@@ -507,9 +546,14 @@ class SparseConvolution(SparseModule):
...
@@ -507,9 +546,14 @@ class SparseConvolution(SparseModule):
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
num_activate_out
,
masks
,
self
.
training
,
self
.
subm
,
num_activate_out
,
masks
,
self
.
training
,
self
.
subm
,
input
.
_timer
,
self
.
fp32_accum
)
input
.
_timer
,
self
.
fp32_accum
,
if
self
.
bias
is
not
None
:
bias_for_infer
,
out_features
+=
self
.
bias
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
)
if
bias_for_training
is
not
None
:
out_features
+=
bias_for_training
if
input
.
benchmark
:
if
input
.
benchmark
:
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
interval
=
time
.
time
()
-
t
interval
=
time
.
time
()
-
t
...
@@ -519,12 +563,9 @@ class SparseConvolution(SparseModule):
...
@@ -519,12 +563,9 @@ class SparseConvolution(SparseModule):
out_tensor
.
benchmark_record
[
self
.
name
][
"num_out_points"
].
append
(
out_tensor
.
benchmark_record
[
self
.
name
][
"num_out_points"
].
append
(
out_features
.
shape
[
0
])
out_features
.
shape
[
0
])
if
not
self
.
subm
and
not
self
.
inverse
and
self
.
record_voxel_count
:
if
not
self
.
subm
and
not
self
.
inverse
and
self
.
record_voxel_count
:
if
hasattr
(
self
,
if
hasattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
):
_MAX_NUM_VOXELS_DURING_TRAINING
):
ops
.
maximum_value_int_
(
ops
.
maximum_value_int_
(
getattr
(
getattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
),
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
),
outids
.
shape
[
0
])
outids
.
shape
[
0
])
out_tensor
=
out_tensor
.
replace_feature
(
out_features
)
out_tensor
=
out_tensor
.
replace_feature
(
out_features
)
out_tensor
.
indices
=
outids
out_tensor
.
indices
=
outids
...
...
spconv/pytorch/functional.py
View file @
f8c25027
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
import
sys
import
sys
import
pickle
import
pickle
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -30,15 +30,18 @@ from pathlib import Path
...
@@ -30,15 +30,18 @@ from pathlib import Path
from
spconv.pytorch.hash
import
HashTable
from
spconv.pytorch.hash
import
HashTable
from
cumm.gemm.layout
import
to_stride
from
cumm.gemm.layout
import
to_stride
from
typing
import
List
from
typing
import
List
from
functools
import
reduce
from
functools
import
reduce
from
cumm
import
tensorview
as
tv
_MAX_INT32
=
2147483647
_MAX_INT32
=
2147483647
_T
=
TypeVar
(
"_T"
)
_T
=
TypeVar
(
"_T"
)
def
identity_decorator
(
func
:
_T
)
->
_T
:
def
identity_decorator
(
func
:
_T
)
->
_T
:
return
func
return
func
if
PYTORCH_VERSION
>=
[
1
,
6
,
0
]:
if
PYTORCH_VERSION
>=
[
1
,
6
,
0
]:
import
torch.cuda.amp
as
amp
import
torch.cuda.amp
as
amp
_TORCH_CUSTOM_FWD
=
amp
.
custom_fwd
(
cast_inputs
=
torch
.
float16
)
_TORCH_CUSTOM_FWD
=
amp
.
custom_fwd
(
cast_inputs
=
torch
.
float16
)
...
@@ -48,6 +51,7 @@ else:
...
@@ -48,6 +51,7 @@ else:
_TORCH_CUSTOM_FWD
=
identity_decorator
_TORCH_CUSTOM_FWD
=
identity_decorator
_TORCH_CUSTOM_BWD
=
identity_decorator
_TORCH_CUSTOM_BWD
=
identity_decorator
class
SparseConvFunction
(
Function
):
class
SparseConvFunction
(
Function
):
@
staticmethod
@
staticmethod
@
_TORCH_CUSTOM_FWD
@
_TORCH_CUSTOM_FWD
...
@@ -58,26 +62,34 @@ class SparseConvFunction(Function):
...
@@ -58,26 +62,34 @@ class SparseConvFunction(Function):
indice_pair_num
,
indice_pair_num
,
num_activate_out
,
num_activate_out
,
algo
,
algo
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
)):
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
act_alpha
:
float
=
0.0
,
act_beta
:
float
=
0.0
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
):
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
ctx
.
algo
=
algo
ctx
.
algo
=
algo
ctx
.
timer
=
timer
ctx
.
timer
=
timer
try
:
try
:
return
ops
.
indice_conv
(
features
,
return
ops
.
indice_conv
(
features
,
filters
,
filters
,
indice_pairs
,
indice_pairs
,
indice_pair_num
,
indice_pair_num
,
num_activate_out
,
num_activate_out
,
False
,
False
,
algo
=
algo
,
algo
=
algo
,
timer
=
timer
)
timer
=
timer
,
bias
=
bias
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
,
act_type
=
act_type
)
except
Exception
as
e
:
except
Exception
as
e
:
msg
=
"[Exception|indice_conv]"
msg
=
"[Exception|indice_conv]"
msg
+=
f
"feat=
{
features
.
shape
}
,w=
{
filters
.
shape
}
,pair=
{
indice_pairs
.
shape
}
,"
msg
+=
f
"feat=
{
features
.
shape
}
,w=
{
filters
.
shape
}
,pair=
{
indice_pairs
.
shape
}
,"
msg
+=
f
"pairnum=
{
indice_pair_num
}
,act=
{
num_activate_out
}
,algo=
{
algo
}
"
msg
+=
f
"pairnum=
{
indice_pair_num
}
,act=
{
num_activate_out
}
,algo=
{
algo
}
"
print
(
msg
,
file
=
sys
.
stderr
)
print
(
msg
,
file
=
sys
.
stderr
)
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
raise
e
raise
e
@
staticmethod
@
staticmethod
@
once_differentiable
@
once_differentiable
...
@@ -100,9 +112,9 @@ class SparseConvFunction(Function):
...
@@ -100,9 +112,9 @@ class SparseConvFunction(Function):
msg
+=
f
"pairnum=
{
indice_pair_num
}
,do=
{
grad_output
.
shape
}
"
msg
+=
f
"pairnum=
{
indice_pair_num
}
,do=
{
grad_output
.
shape
}
"
print
(
msg
,
file
=
sys
.
stderr
)
print
(
msg
,
file
=
sys
.
stderr
)
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
raise
e
raise
e
return
input_bp
,
filters_bp
,
None
,
None
,
None
,
None
,
None
return
input_bp
,
filters_bp
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
class
SparseInverseConvFunction
(
Function
):
class
SparseInverseConvFunction
(
Function
):
...
@@ -115,27 +127,35 @@ class SparseInverseConvFunction(Function):
...
@@ -115,27 +127,35 @@ class SparseInverseConvFunction(Function):
indice_pair_num
,
indice_pair_num
,
num_activate_out
,
num_activate_out
,
algo
,
algo
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
)):
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
act_alpha
:
float
=
0.0
,
act_beta
:
float
=
0.0
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
):
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
ctx
.
algo
=
algo
ctx
.
algo
=
algo
ctx
.
timer
=
timer
ctx
.
timer
=
timer
try
:
try
:
return
ops
.
indice_conv
(
features
,
return
ops
.
indice_conv
(
features
,
filters
,
filters
,
indice_pairs
,
indice_pairs
,
indice_pair_num
,
indice_pair_num
,
num_activate_out
,
num_activate_out
,
True
,
True
,
False
,
False
,
algo
=
algo
,
algo
=
algo
,
timer
=
timer
)
timer
=
timer
,
bias
=
bias
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
,
act_type
=
act_type
)
except
Exception
as
e
:
except
Exception
as
e
:
msg
=
"[Exception|indice_conv|inverse]"
msg
=
"[Exception|indice_conv|inverse]"
msg
+=
f
"feat=
{
features
.
shape
}
,w=
{
filters
.
shape
}
,pair=
{
indice_pairs
.
shape
}
,"
msg
+=
f
"feat=
{
features
.
shape
}
,w=
{
filters
.
shape
}
,pair=
{
indice_pairs
.
shape
}
,"
msg
+=
f
"pairnum=
{
indice_pair_num
}
,act=
{
num_activate_out
}
,algo=
{
algo
}
"
msg
+=
f
"pairnum=
{
indice_pair_num
}
,act=
{
num_activate_out
}
,algo=
{
algo
}
"
print
(
msg
,
file
=
sys
.
stderr
)
print
(
msg
,
file
=
sys
.
stderr
)
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
raise
e
raise
e
@
staticmethod
@
staticmethod
@
once_differentiable
@
once_differentiable
...
@@ -159,9 +179,9 @@ class SparseInverseConvFunction(Function):
...
@@ -159,9 +179,9 @@ class SparseInverseConvFunction(Function):
msg
+=
f
"pairnum=
{
indice_pair_num
}
,do=
{
grad_output
.
shape
}
"
msg
+=
f
"pairnum=
{
indice_pair_num
}
,do=
{
grad_output
.
shape
}
"
print
(
msg
,
file
=
sys
.
stderr
)
print
(
msg
,
file
=
sys
.
stderr
)
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
raise
e
raise
e
return
input_bp
,
filters_bp
,
None
,
None
,
None
,
None
,
None
return
input_bp
,
filters_bp
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
class
SparseImplicitGemmFunction
(
Function
):
class
SparseImplicitGemmFunction
(
Function
):
...
@@ -181,25 +201,28 @@ class SparseImplicitGemmFunction(Function):
...
@@ -181,25 +201,28 @@ class SparseImplicitGemmFunction(Function):
is_train
:
bool
,
is_train
:
bool
,
is_subm
:
bool
,
is_subm
:
bool
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
fp32_accum
:
Optional
[
bool
]
=
None
):
fp32_accum
:
Optional
[
bool
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
act_alpha
:
float
=
0.0
,
act_beta
:
float
=
0.0
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
):
try
:
try
:
out
,
mask_out
,
mask_width
=
ops
.
implicit_gemm
(
features
,
filters
,
out
,
mask_out
,
mask_width
=
ops
.
implicit_gemm
(
pair_fwd
,
features
,
filters
,
pair_fwd
,
pair_mask_fwd_splits
,
pair_mask_fwd_splits
,
mask_argsort_fwd_splits
,
num_activate_out
,
masks
,
is_train
,
mask_argsort_fwd_splits
,
is_subm
,
timer
,
fp32_accum
,
bias
,
act_alpha
,
act_beta
,
num_activate_out
,
masks
,
act_type
)
is_train
,
is_subm
,
timer
,
fp32_accum
)
except
Exception
as
e
:
except
Exception
as
e
:
msg
=
"[Exception|implicit_gemm]"
msg
=
"[Exception|implicit_gemm]"
msg
+=
f
"feat=
{
features
.
shape
}
,w=
{
filters
.
shape
}
,pair=
{
pair_fwd
.
shape
}
,"
msg
+=
f
"feat=
{
features
.
shape
}
,w=
{
filters
.
shape
}
,pair=
{
pair_fwd
.
shape
}
,"
msg
+=
f
"act=
{
num_activate_out
}
,issubm=
{
is_subm
}
,istrain=
{
is_train
}
"
msg
+=
f
"act=
{
num_activate_out
}
,issubm=
{
is_subm
}
,istrain=
{
is_train
}
"
print
(
msg
,
file
=
sys
.
stderr
)
print
(
msg
,
file
=
sys
.
stderr
)
spconv_save_debug_data
((
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
,
spconv_save_debug_data
(
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
(
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
,
masks
))
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
raise
e
mask_argsort_bwd_splits
,
masks
))
raise
e
ctx
.
save_for_backward
(
features
,
filters
,
pair_fwd
,
pair_bwd
)
ctx
.
save_for_backward
(
features
,
filters
,
pair_fwd
,
pair_bwd
)
ctx
.
mask_width
=
mask_width
ctx
.
mask_width
=
mask_width
ctx
.
mask_out
=
mask_out
ctx
.
mask_out
=
mask_out
...
@@ -253,12 +276,13 @@ class SparseImplicitGemmFunction(Function):
...
@@ -253,12 +276,13 @@ class SparseImplicitGemmFunction(Function):
msg
+=
f
"feat=
{
features
.
shape
}
,w=
{
filters
.
shape
}
,pair=
{
pair_fwd
.
shape
}
,"
msg
+=
f
"feat=
{
features
.
shape
}
,w=
{
filters
.
shape
}
,pair=
{
pair_fwd
.
shape
}
,"
msg
+=
f
"issubm=
{
is_subm
}
,do=
{
grad_output
.
shape
}
"
msg
+=
f
"issubm=
{
is_subm
}
,do=
{
grad_output
.
shape
}
"
print
(
msg
,
file
=
sys
.
stderr
)
print
(
msg
,
file
=
sys
.
stderr
)
spconv_save_debug_data
((
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
,
spconv_save_debug_data
(
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
(
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
,
masks
))
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
raise
e
mask_argsort_bwd_splits
,
masks
))
raise
e
None_9
=
[
None
]
*
1
2
None_9
=
[
None
]
*
1
6
return
(
input_bp
,
filters_bp
,
*
None_9
)
return
(
input_bp
,
filters_bp
,
*
None_9
)
...
@@ -272,27 +296,35 @@ class SubMConvFunction(Function):
...
@@ -272,27 +296,35 @@ class SubMConvFunction(Function):
indice_pair_num
,
indice_pair_num
,
num_activate_out
,
num_activate_out
,
algo
,
algo
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
)):
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
act_alpha
:
float
=
0.0
,
act_beta
:
float
=
0.0
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
):
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
ctx
.
algo
=
algo
ctx
.
algo
=
algo
ctx
.
timer
=
timer
ctx
.
timer
=
timer
try
:
try
:
return
ops
.
indice_conv
(
features
,
return
ops
.
indice_conv
(
features
,
filters
,
filters
,
indice_pairs
,
indice_pairs
,
indice_pair_num
,
indice_pair_num
,
num_activate_out
,
num_activate_out
,
False
,
False
,
True
,
True
,
algo
=
algo
,
algo
=
algo
,
timer
=
timer
)
timer
=
timer
,
bias
=
bias
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
,
act_type
=
act_type
)
except
Exception
as
e
:
except
Exception
as
e
:
msg
=
"[Exception|indice_conv|subm]"
msg
=
"[Exception|indice_conv|subm]"
msg
+=
f
"feat=
{
features
.
shape
}
,w=
{
filters
.
shape
}
,pair=
{
indice_pairs
.
shape
}
,"
msg
+=
f
"feat=
{
features
.
shape
}
,w=
{
filters
.
shape
}
,pair=
{
indice_pairs
.
shape
}
,"
msg
+=
f
"pairnum=
{
indice_pair_num
}
,act=
{
num_activate_out
}
,algo=
{
algo
}
"
msg
+=
f
"pairnum=
{
indice_pair_num
}
,act=
{
num_activate_out
}
,algo=
{
algo
}
"
print
(
msg
,
file
=
sys
.
stderr
)
print
(
msg
,
file
=
sys
.
stderr
)
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
raise
e
raise
e
@
staticmethod
@
staticmethod
@
once_differentiable
@
once_differentiable
...
@@ -316,10 +348,9 @@ class SubMConvFunction(Function):
...
@@ -316,10 +348,9 @@ class SubMConvFunction(Function):
msg
+=
f
"pairnum=
{
indice_pair_num
}
,do=
{
grad_output
.
shape
}
"
msg
+=
f
"pairnum=
{
indice_pair_num
}
,do=
{
grad_output
.
shape
}
"
print
(
msg
,
file
=
sys
.
stderr
)
print
(
msg
,
file
=
sys
.
stderr
)
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
spconv_save_debug_data
((
indice_pairs
,
indice_pair_num
))
raise
e
raise
e
return
input_bp
,
filters_bp
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
return
input_bp
,
filters_bp
,
None
,
None
,
None
,
None
,
None
class
SparseMaxPoolFunction
(
Function
):
class
SparseMaxPoolFunction
(
Function
):
...
@@ -361,13 +392,17 @@ class SparseMaxPoolImplicitGemmFunction(Function):
...
@@ -361,13 +392,17 @@ class SparseMaxPoolImplicitGemmFunction(Function):
features
,
out
,
grad_output
,
indice_pairs_bwd
)
features
,
out
,
grad_output
,
indice_pairs_bwd
)
return
input_bp
,
None
,
None
,
None
return
input_bp
,
None
,
None
,
None
class
SparseAvgPoolImplicitGemmFunction
(
Function
):
class
SparseAvgPoolImplicitGemmFunction
(
Function
):
@
staticmethod
@
staticmethod
@
_TORCH_CUSTOM_FWD
@
_TORCH_CUSTOM_FWD
def
forward
(
ctx
,
features
:
torch
.
Tensor
,
indice_pairs_fwd
:
torch
.
Tensor
,
def
forward
(
ctx
,
features
:
torch
.
Tensor
,
indice_pairs_fwd
:
torch
.
Tensor
,
indice_pairs_bwd
:
torch
.
Tensor
,
num_activate_out
:
int
,
calc_count
):
indice_pairs_bwd
:
torch
.
Tensor
,
num_activate_out
:
int
,
out
,
count
=
ops
.
indice_avgpool_implicit_gemm
(
features
,
indice_pairs_fwd
,
calc_count
):
num_activate_out
,
calc_count
)
out
,
count
=
ops
.
indice_avgpool_implicit_gemm
(
features
,
indice_pairs_fwd
,
num_activate_out
,
calc_count
)
ctx
.
save_for_backward
(
indice_pairs_bwd
,
features
,
out
,
count
)
ctx
.
save_for_backward
(
indice_pairs_bwd
,
features
,
out
,
count
)
return
out
return
out
...
@@ -398,6 +433,7 @@ def _indice_to_scalar(indices: torch.Tensor, shape: List[int]):
...
@@ -398,6 +433,7 @@ def _indice_to_scalar(indices: torch.Tensor, shape: List[int]):
scalar_inds
+=
stride
[
i
]
*
indices
[:,
i
]
scalar_inds
+=
stride
[
i
]
*
indices
[:,
i
]
return
scalar_inds
.
contiguous
()
return
scalar_inds
.
contiguous
()
def
sparse_add_hash_based
(
*
tens
:
SparseConvTensor
):
def
sparse_add_hash_based
(
*
tens
:
SparseConvTensor
):
""" sparse add with misaligned indices.
""" sparse add with misaligned indices.
if you use sparse add, the indice_dict will be dropped and impossible
if you use sparse add, the indice_dict will be dropped and impossible
...
@@ -417,7 +453,7 @@ def sparse_add_hash_based(*tens: SparseConvTensor):
...
@@ -417,7 +453,7 @@ def sparse_add_hash_based(*tens: SparseConvTensor):
if
max_num_indices
<
ten
.
features
.
shape
[
0
]:
if
max_num_indices
<
ten
.
features
.
shape
[
0
]:
max_num_indices_idx
=
i
max_num_indices_idx
=
i
max_num_indices
=
ten
.
features
.
shape
[
0
]
max_num_indices
=
ten
.
features
.
shape
[
0
]
first
=
tens
[
0
]
first
=
tens
[
0
]
feat
=
first
.
features
feat
=
first
.
features
shape
=
[
first
.
batch_size
,
*
first
.
spatial_shape
]
shape
=
[
first
.
batch_size
,
*
first
.
spatial_shape
]
...
@@ -438,21 +474,29 @@ def sparse_add_hash_based(*tens: SparseConvTensor):
...
@@ -438,21 +474,29 @@ def sparse_add_hash_based(*tens: SparseConvTensor):
# assign arange to values of hash table
# assign arange to values of hash table
count
=
table
.
assign_arange_
()
count
=
table
.
assign_arange_
()
count_val
=
count
.
item
()
count_val
=
count
.
item
()
out_features
=
torch
.
zeros
([
int
(
count_val
),
feat
.
shape
[
1
]],
dtype
=
feat
.
dtype
,
device
=
feat
.
device
)
out_features
=
torch
.
zeros
([
int
(
count_val
),
feat
.
shape
[
1
]],
out_indices
=
torch
.
zeros
([
int
(
count_val
),
first
.
indices
.
shape
[
1
]],
dtype
=
first
.
indices
.
dtype
,
device
=
first
.
indices
.
device
)
dtype
=
feat
.
dtype
,
device
=
feat
.
device
)
out_indices
=
torch
.
zeros
([
int
(
count_val
),
first
.
indices
.
shape
[
1
]],
dtype
=
first
.
indices
.
dtype
,
device
=
first
.
indices
.
device
)
for
ten
,
scalar
in
zip
(
tens
,
scalars
):
for
ten
,
scalar
in
zip
(
tens
,
scalars
):
out_inds
,
_
=
table
.
query
(
scalar
)
out_inds
,
_
=
table
.
query
(
scalar
)
out_inds
=
out_inds
.
long
()
out_inds
=
out_inds
.
long
()
out_features
[
out_inds
]
+=
ten
.
features
out_features
[
out_inds
]
+=
ten
.
features
out_indices
[
out_inds
]
=
ten
.
indices
out_indices
[
out_inds
]
=
ten
.
indices
res
=
SparseConvTensor
(
out_features
,
out_indices
,
first
.
spatial_shape
,
first
.
batch_size
,
res
=
SparseConvTensor
(
out_features
,
benchmark
=
first
.
benchmark
)
out_indices
,
first
.
spatial_shape
,
first
.
batch_size
,
benchmark
=
first
.
benchmark
)
if
count_val
==
max_num_indices
:
if
count_val
==
max_num_indices
:
res
.
indice_dict
=
tens
[
max_num_indices_idx
].
indice_dict
res
.
indice_dict
=
tens
[
max_num_indices_idx
].
indice_dict
res
.
benchmark_record
=
first
.
benchmark_record
res
.
benchmark_record
=
first
.
benchmark_record
res
.
_timer
=
first
.
_timer
res
.
_timer
=
first
.
_timer
res
.
thrust_allocator
=
first
.
thrust_allocator
res
.
thrust_allocator
=
first
.
thrust_allocator
return
res
return
res
def
sparse_add
(
*
tens
:
SparseConvTensor
):
def
sparse_add
(
*
tens
:
SparseConvTensor
):
"""reuse torch.sparse. the internal is sort + unique
"""reuse torch.sparse. the internal is sort + unique
...
@@ -461,7 +505,9 @@ def sparse_add(*tens: SparseConvTensor):
...
@@ -461,7 +505,9 @@ def sparse_add(*tens: SparseConvTensor):
max_num_indices_idx
=
0
max_num_indices_idx
=
0
ten_ths
:
List
[
torch
.
Tensor
]
=
[]
ten_ths
:
List
[
torch
.
Tensor
]
=
[]
first
=
tens
[
0
]
first
=
tens
[
0
]
res_shape
=
[
first
.
batch_size
,
*
first
.
spatial_shape
,
first
.
features
.
shape
[
1
]]
res_shape
=
[
first
.
batch_size
,
*
first
.
spatial_shape
,
first
.
features
.
shape
[
1
]
]
for
i
,
ten
in
enumerate
(
tens
):
for
i
,
ten
in
enumerate
(
tens
):
assert
ten
.
spatial_shape
==
tens
[
0
].
spatial_shape
assert
ten
.
spatial_shape
==
tens
[
0
].
spatial_shape
...
@@ -470,18 +516,25 @@ def sparse_add(*tens: SparseConvTensor):
...
@@ -470,18 +516,25 @@ def sparse_add(*tens: SparseConvTensor):
if
max_num_indices
<
ten
.
features
.
shape
[
0
]:
if
max_num_indices
<
ten
.
features
.
shape
[
0
]:
max_num_indices_idx
=
i
max_num_indices_idx
=
i
max_num_indices
=
ten
.
features
.
shape
[
0
]
max_num_indices
=
ten
.
features
.
shape
[
0
]
ten_ths
.
append
(
torch
.
sparse_coo_tensor
(
ten
.
indices
.
T
,
ten
.
features
,
res_shape
,
requires_grad
=
True
))
ten_ths
.
append
(
torch
.
sparse_coo_tensor
(
ten
.
indices
.
T
,
ten
.
features
,
res_shape
,
requires_grad
=
True
))
c_th
=
reduce
(
lambda
x
,
y
:
x
+
y
,
ten_ths
).
coalesce
()
c_th
=
reduce
(
lambda
x
,
y
:
x
+
y
,
ten_ths
).
coalesce
()
c_th_inds
=
c_th
.
indices
().
T
.
contiguous
().
int
()
c_th_inds
=
c_th
.
indices
().
T
.
contiguous
().
int
()
c_th_values
=
c_th
.
values
()
c_th_values
=
c_th
.
values
()
assert
c_th_values
.
is_contiguous
()
assert
c_th_values
.
is_contiguous
()
res
=
SparseConvTensor
(
c_th_values
,
c_th_inds
,
first
.
spatial_shape
,
first
.
batch_size
,
res
=
SparseConvTensor
(
c_th_values
,
benchmark
=
first
.
benchmark
)
c_th_inds
,
first
.
spatial_shape
,
first
.
batch_size
,
benchmark
=
first
.
benchmark
)
if
c_th_values
.
shape
[
0
]
==
max_num_indices
:
if
c_th_values
.
shape
[
0
]
==
max_num_indices
:
res
.
indice_dict
=
tens
[
max_num_indices_idx
].
indice_dict
res
.
indice_dict
=
tens
[
max_num_indices_idx
].
indice_dict
res
.
benchmark_record
=
first
.
benchmark_record
res
.
benchmark_record
=
first
.
benchmark_record
res
.
_timer
=
first
.
_timer
res
.
_timer
=
first
.
_timer
res
.
thrust_allocator
=
first
.
thrust_allocator
res
.
thrust_allocator
=
first
.
thrust_allocator
return
res
return
res
spconv/pytorch/ops.py
View file @
f8c25027
...
@@ -29,6 +29,8 @@ from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
...
@@ -29,6 +29,8 @@ from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from
spconv.constants
import
SPCONV_CPP_INDICE_PAIRS
,
SPCONV_CPP_INDICE_PAIRS_IGEMM
,
SPCONV_CPP_GEMM
,
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
from
spconv.constants
import
SPCONV_CPP_INDICE_PAIRS
,
SPCONV_CPP_INDICE_PAIRS_IGEMM
,
SPCONV_CPP_GEMM
,
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
import
spconv.core_cc
as
_ext
import
spconv.core_cc
as
_ext
from
spconv.core_cc.csrc.sparse.convops.spops
import
ConvGemmOps
from
spconv.core_cc.csrc.sparse.convops.spops
import
ConvGemmOps
from
spconv.core_cc.csrc.sparse.inference
import
InferenceOps
from
spconv.utils
import
nullcontext
from
spconv.utils
import
nullcontext
if
hasattr
(
_ext
,
"cumm"
):
if
hasattr
(
_ext
,
"cumm"
):
...
@@ -784,7 +786,11 @@ def indice_conv(features: torch.Tensor,
...
@@ -784,7 +786,11 @@ def indice_conv(features: torch.Tensor,
inverse
:
bool
=
False
,
inverse
:
bool
=
False
,
subm
:
bool
=
False
,
subm
:
bool
=
False
,
algo
:
ConvAlgo
=
ConvAlgo
.
Native
,
algo
:
ConvAlgo
=
ConvAlgo
.
Native
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
)):
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
act_alpha
:
float
=
0.0
,
act_beta
:
float
=
0.0
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
):
# filters: RSKC
# filters: RSKC
# stream = get_current_stream()
# stream = get_current_stream()
# CONV.stream_synchronize(stream)
# CONV.stream_synchronize(stream)
...
@@ -793,6 +799,9 @@ def indice_conv(features: torch.Tensor,
...
@@ -793,6 +799,9 @@ def indice_conv(features: torch.Tensor,
features
=
features
.
contiguous
()
features
=
features
.
contiguous
()
if
features
.
dtype
==
torch
.
int8
or
features
.
dtype
==
torch
.
qint8
:
if
features
.
dtype
==
torch
.
int8
or
features
.
dtype
==
torch
.
qint8
:
raise
NotImplementedError
(
"work in progress"
)
raise
NotImplementedError
(
"work in progress"
)
bias_tv
=
tv
.
Tensor
()
if
bias
is
not
None
:
bias_tv
=
torch_tensor_to_tv
(
bias
)
if
SPCONV_CPP_GEMM
and
GEMM_CPP
is
not
None
:
if
SPCONV_CPP_GEMM
and
GEMM_CPP
is
not
None
:
# print("CPPPPPP!!!", features.device)
# print("CPPPPPP!!!", features.device)
...
@@ -822,10 +831,18 @@ def indice_conv(features: torch.Tensor,
...
@@ -822,10 +831,18 @@ def indice_conv(features: torch.Tensor,
FILTER_HWIO
,
features_tv
,
filters_tv
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
arch
,
indice_pairs_tv
,
indice_pair_num_tv
,
arch
,
num_activate_out
,
inverse
,
subm
,
algo
.
value
,
num_activate_out
,
inverse
,
subm
,
algo
.
value
,
stream
)
stream
,
bias_tv
,
act_alpha
,
act_beta
,
act_type
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
return
out_features
return
out_features
if
not
features
.
is_cuda
:
stream
=
0
else
:
stream
=
get_current_stream
()
has_bias
=
bias
is
not
None
has_act
=
act_type
!=
tv
.
gemm
.
Activation
.
None_
if
has_bias
or
has_act
:
assert
features
.
is_cuda
,
"cpu don't support act and bias"
if
not
ALL_WEIGHT_IS_KRSC
:
if
not
ALL_WEIGHT_IS_KRSC
:
kv_dim
=
0
kv_dim
=
0
is_KC_not_CK
=
not
FILTER_HWIO
is_KC_not_CK
=
not
FILTER_HWIO
...
@@ -875,7 +892,17 @@ def indice_conv(features: torch.Tensor,
...
@@ -875,7 +892,17 @@ def indice_conv(features: torch.Tensor,
out_features
=
torch
.
zeros
((
num_activate_out
,
out_channel
),
out_features
=
torch
.
zeros
((
num_activate_out
,
out_channel
),
dtype
=
features
.
dtype
,
dtype
=
features
.
dtype
,
device
=
features
.
device
)
device
=
features
.
device
)
c
=
torch_tensor_to_tv
(
out_features
)
if
kv
==
1
and
subm
:
if
kv
==
1
and
subm
:
if
(
has_act
and
has_bias
):
InferenceOps
.
bias_add_act_inplace
(
c
,
bias_tv
,
act_type
,
act_alpha
,
act_beta
,
stream
)
else
:
if
has_act
:
InferenceOps
.
activation_inplace
(
c
,
act_type
,
act_alpha
,
act_beta
,
stream
)
if
has_bias
:
InferenceOps
.
bias_add_inplace
(
c
,
bias_tv
,
stream
)
return
out_features
return
out_features
indice_pair_num_cpu
=
indice_pair_num
.
cpu
().
tolist
()
indice_pair_num_cpu
=
indice_pair_num
.
cpu
().
tolist
()
...
@@ -928,7 +955,6 @@ def indice_conv(features: torch.Tensor,
...
@@ -928,7 +955,6 @@ def indice_conv(features: torch.Tensor,
SpconvOps
.
scatter_add_cpu
(
c
,
out_buffer_tv
,
out_indices
)
SpconvOps
.
scatter_add_cpu
(
c
,
out_buffer_tv
,
out_indices
)
return
out_features
return
out_features
stream
=
get_current_stream
()
profile_idx
=
kv_center
profile_idx
=
kv_center
if
subm
:
if
subm
:
...
@@ -1020,6 +1046,14 @@ def indice_conv(features: torch.Tensor,
...
@@ -1020,6 +1046,14 @@ def indice_conv(features: torch.Tensor,
# gather_times += gather_time
# gather_times += gather_time
inited
=
True
inited
=
True
if
(
has_act
and
has_bias
):
InferenceOps
.
bias_add_act_inplace
(
c
,
bias_tv
,
act_type
,
act_alpha
,
act_beta
,
stream
)
else
:
if
has_act
:
InferenceOps
.
activation_inplace
(
c
,
act_type
,
act_alpha
,
act_beta
,
stream
)
if
has_bias
:
InferenceOps
.
bias_add_inplace
(
c
,
bias_tv
,
stream
)
# CONV.stream_synchronize(stream)
# CONV.stream_synchronize(stream)
# print(out_features.mean(), out_features.max(), out_features.min())
# print(out_features.mean(), out_features.max(), out_features.min())
...
@@ -1391,8 +1425,16 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1391,8 +1425,16 @@ def implicit_gemm(features: torch.Tensor,
is_train
:
bool
,
is_train
:
bool
,
is_subm
:
bool
,
is_subm
:
bool
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
fp32_accum
:
Optional
[
bool
]
=
None
):
fp32_accum
:
Optional
[
bool
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
act_alpha
:
float
=
0.0
,
act_beta
:
float
=
0.0
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
):
stream
=
get_current_stream
()
stream
=
get_current_stream
()
bias_tv
=
tv
.
Tensor
()
if
bias
is
not
None
:
bias_tv
=
torch_tensor_to_tv
(
bias
)
if
SPCONV_CPP_GEMM
and
CONV_CPP
is
not
None
:
if
SPCONV_CPP_GEMM
and
CONV_CPP
is
not
None
:
alloc
=
TorchAllocator
(
features
.
device
)
alloc
=
TorchAllocator
(
features
.
device
)
...
@@ -1420,7 +1462,7 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1420,7 +1462,7 @@ def implicit_gemm(features: torch.Tensor,
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
pair_fwd_tv
,
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
pair_fwd_tv
,
pair_mask_fwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
pair_mask_fwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
num_activate_out
,
mask_tv
,
arch
,
is_train
,
is_subm
,
stream
,
num_activate_out
,
mask_tv
,
arch
,
is_train
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
timer_cpp
,
auto_fp32_accum
,
fp32_accum
,
bias_tv
,
act_alpha
,
act_beta
,
act_type
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
mask_output_fwd
=
alloc
.
allocated
.
get
(
AllocKeys
.
MaskOutputFwd
,
None
)
mask_output_fwd
=
alloc
.
allocated
.
get
(
AllocKeys
.
MaskOutputFwd
,
None
)
if
is_train
:
if
is_train
:
...
@@ -1512,6 +1554,10 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1512,6 +1554,10 @@ def implicit_gemm(features: torch.Tensor,
# t = time.time()
# t = time.time()
# print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# with tv.measure_and_print("f16 time"):
# with tv.measure_and_print("f16 time"):
bias_tv
=
tv
.
Tensor
()
if
bias
is
not
None
:
bias_tv
=
torch_tensor_to_tv
(
bias
)
with
timer
.
record
(
"implicit_gemm"
,
stream
):
with
timer
.
record
(
"implicit_gemm"
,
stream
):
for
j
in
range
(
num_split
):
for
j
in
range
(
num_split
):
beta
=
0
if
j
==
0
else
1
beta
=
0
if
j
==
0
else
1
...
@@ -1530,7 +1576,11 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1530,7 +1576,11 @@ def implicit_gemm(features: torch.Tensor,
mask_width
=-
1
,
mask_width
=-
1
,
beta
=
beta
,
beta
=
beta
,
stream
=
stream
,
stream
=
stream
,
verbose
=
False
)
verbose
=
False
,
bias
=
bias_tv
,
act_type
=
act_type
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
)
# INT8_TEST = True
# INT8_TEST = True
# if INT8_TEST:
# if INT8_TEST:
# if features.shape[1] % 32 != 0:
# if features.shape[1] % 32 != 0:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment