Commit 2b195e43 authored by yan.yan's avatar yan.yan
Browse files

Merge branch 'master' of https://github.com/traveller59/spconv

parents 73a5ce7d 70067da8
...@@ -56,7 +56,7 @@ class ConvGemmOps: ...@@ -56,7 +56,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int: def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> Tuple[int, Any]:
""" """
Args: Args:
allocator: allocator:
......
...@@ -2013,9 +2013,9 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2013,9 +2013,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
// tv::ssprint(tune_res.algo_desp.__repr__(), "WTF", exists, // tv::ssprint(tune_res.algo_desp.__repr__(), "WTF", exists,
// features.shape(), filters.shape(), out_features.shape(), tv::CUDAEvent::sync_and_duration(start_ev, end_ev)); // features.shape(), filters.shape(), out_features.shape(), tv::CUDAEvent::sync_and_duration(start_ev, end_ev));
return mask_width; return std::make_tuple(mask_width, tune_res);
""") """)
return code.ret("int") return code.ret("std::tuple<int, ConvTuneResult>")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.static_function @pccm.static_function
......
...@@ -136,6 +136,12 @@ class SparseConvolution(SparseModule): ...@@ -136,6 +136,12 @@ class SparseConvolution(SparseModule):
self._register_load_state_dict_pre_hook( self._register_load_state_dict_pre_hook(
self._load_weight_different_layout) self._load_weight_different_layout)
def get_max_num_voxels(self) -> Optional[torch.Tensor]:
if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING)
return None
def _load_weight_different_layout(self, state_dict, prefix, local_metadata, def _load_weight_different_layout(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, strict, missing_keys, unexpected_keys,
error_msgs): error_msgs):
......
...@@ -1413,7 +1413,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1413,7 +1413,7 @@ def implicit_gemm(features: torch.Tensor,
fp32_accum = False fp32_accum = False
arch = get_arch() arch = get_arch()
mask_width = ConvGemmOps.implicit_gemm( mask_width, tune_res_cpp = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv, alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv, pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, arch, is_train, is_subm, stream, num_activate_out, mask_tv, arch, is_train, is_subm, stream,
......
...@@ -90,6 +90,12 @@ class SparseMaxPool(SparseModule): ...@@ -90,6 +90,12 @@ class SparseMaxPool(SparseModule):
s += f', algo={self.algo}' s += f', algo={self.algo}'
return s.format(**self.__dict__) return s.format(**self.__dict__)
def get_max_num_voxels(self) -> Optional[torch.Tensor]:
if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING)
return None
def forward(self, input): def forward(self, input):
assert isinstance(input, spconv.SparseConvTensor) assert isinstance(input, spconv.SparseConvTensor)
features = input.features features = input.features
...@@ -283,6 +289,11 @@ class SparseAvgPool(SparseModule): ...@@ -283,6 +289,11 @@ class SparseAvgPool(SparseModule):
s += f', algo={self.algo}' s += f', algo={self.algo}'
return s.format(**self.__dict__) return s.format(**self.__dict__)
def get_max_num_voxels(self) -> Optional[torch.Tensor]:
if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING)
return None
def forward(self, input): def forward(self, input):
assert isinstance(input, spconv.SparseConvTensor) assert isinstance(input, spconv.SparseConvTensor)
features = input.features features = input.features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment