Commit 70067da8 authored by yan.yan
Browse files

small changes

parent 0c07559f
......@@ -56,7 +56,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> Tuple[int, Any]:
"""
Args:
allocator:
......
......@@ -2013,9 +2013,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
// tv::ssprint(tune_res.algo_desp.__repr__(), "WTF", exists,
// features.shape(), filters.shape(), out_features.shape(), tv::CUDAEvent::sync_and_duration(start_ev, end_ev));
return mask_width;
return std::make_tuple(mask_width, tune_res);
""")
return code.ret("int")
return code.ret("std::tuple<int, ConvTuneResult>")
@pccm.pybind.mark
@pccm.static_function
......
......@@ -136,6 +136,12 @@ class SparseConvolution(SparseModule):
self._register_load_state_dict_pre_hook(
self._load_weight_different_layout)
def get_max_num_voxels(self) -> Optional[torch.Tensor]:
    """Return the value stored under ``_MAX_NUM_VOXELS_DURING_TRAINING``.

    Returns:
        The attribute whose name is given by
        ``_MAX_NUM_VOXELS_DURING_TRAINING`` when this instance has it,
        otherwise ``None``.
    """
    # One lookup with a ``None`` fallback handles the "attribute absent" case.
    return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING, None)
def _load_weight_different_layout(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs):
......
......@@ -1366,7 +1366,7 @@ def implicit_gemm(features: torch.Tensor,
fp32_accum = False
arch = get_arch()
mask_width = ConvGemmOps.implicit_gemm(
mask_width, tune_res_cpp = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, arch, is_train, is_subm, stream, timer_cpp,
......@@ -1460,7 +1460,7 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream)
# t = time.time()
print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# with tv.measure_and_print("f16 time"):
with timer.record("implicit_gemm", stream):
for j in range(num_split):
......
......@@ -90,6 +90,12 @@ class SparseMaxPool(SparseModule):
s += f', algo={self.algo}'
return s.format(**self.__dict__)
def get_max_num_voxels(self) -> Optional[torch.Tensor]:
    """Return the value stored under ``_MAX_NUM_VOXELS_DURING_TRAINING``.

    Returns:
        The attribute whose name is given by
        ``_MAX_NUM_VOXELS_DURING_TRAINING`` when this instance has it,
        otherwise ``None``.
    """
    # One lookup with a ``None`` fallback handles the "attribute absent" case.
    return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING, None)
def forward(self, input):
assert isinstance(input, spconv.SparseConvTensor)
features = input.features
......@@ -282,6 +288,11 @@ class SparseAvgPool(SparseModule):
if self.algo is not None:
s += f', algo={self.algo}'
return s.format(**self.__dict__)
def get_max_num_voxels(self) -> Optional[torch.Tensor]:
    """Return the value stored under ``_MAX_NUM_VOXELS_DURING_TRAINING``.

    Returns:
        The attribute whose name is given by
        ``_MAX_NUM_VOXELS_DURING_TRAINING`` when this instance has it,
        otherwise ``None``.
    """
    # One lookup with a ``None`` fallback handles the "attribute absent" case.
    return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING, None)
def forward(self, input):
assert isinstance(input, spconv.SparseConvTensor)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment