Commit 3bda1b64 authored by yan.yan
Browse files

v2.1.11: fix #385, fix volta (V100) wgrad kernel

parent b0ff62f3
# Changelog # Changelog
## [2.1.11] - 2021-11-22
### Fixed
- Fixed a bug in Volta kernels (TITAN V, Tesla V100): backward weight kernels used f16 as the accumulator, but should use f32.
- Fixed a corner case where the user uses kernel size = 1x1 but stride != 1.
- Fixed a corner case where the input features are non-contiguous during maxpool.
## [2.1.10] - 2021-11-19 ## [2.1.10] - 2021-11-19
### Fixed ### Fixed
- Fixed a bug in utils.PointToVoxel, shouldn't get cuda stream in cpu code - Fixed a bug in utils.PointToVoxel, shouldn't get cuda stream in cpu code
......
...@@ -449,7 +449,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -449,7 +449,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -461,7 +461,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -461,7 +461,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -473,7 +473,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -473,7 +473,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvBwdWeight, (128, 128, 32), (32, 64, 32), *gen_conv_params(ConvBwdWeight, (128, 128, 32), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
......
...@@ -108,7 +108,11 @@ class SparseConvolution(SparseModule): ...@@ -108,7 +108,11 @@ class SparseConvolution(SparseModule):
self.out_channels = out_channels self.out_channels = out_channels
self.kernel_size = kernel_size self.kernel_size = kernel_size
kv = int(np.prod(kernel_size)) kv = int(np.prod(kernel_size))
kv_stride = int(np.prod(stride))
self.conv1x1 = kv == 1 self.conv1x1 = kv == 1
# TODO we should deprecate support for ksize == 1 but stride != 1.
if not subm:
self.conv1x1 &= kv_stride == 1
self.stride = stride self.stride = stride
self.padding = padding self.padding = padding
self.dilation = dilation self.dilation = dilation
...@@ -247,6 +251,8 @@ class SparseConvolution(SparseModule): ...@@ -247,6 +251,8 @@ class SparseConvolution(SparseModule):
if self.bias is not None: if self.bias is not None:
features += self.bias features += self.bias
out_tensor = out_tensor.replace_feature(features) out_tensor = out_tensor.replace_feature(features)
# padding may change spatial shape of conv 1x1.
out_tensor.spatial_shape = out_spatial_shape
return out_tensor return out_tensor
indice_dict = input.indice_dict.copy() indice_dict = input.indice_dict.copy()
......
...@@ -789,6 +789,9 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -789,6 +789,9 @@ def indice_conv_backward(features: torch.Tensor,
filters = filters.reshape(-1, *filters.shape[-2:]) filters = filters.reshape(-1, *filters.shape[-2:])
kv = filters.shape[0] kv = filters.shape[0]
kv_center = kv // 2 kv_center = kv // 2
# TODO handle this in nn.Module to make sure features in backward is contiguous
if not features.is_contiguous():
features = features.contiguous()
if not out_bp.is_contiguous(): if not out_bp.is_contiguous():
out_bp = out_bp.contiguous() out_bp = out_bp.contiguous()
assert out_bp.is_contiguous() assert out_bp.is_contiguous()
...@@ -1200,6 +1203,9 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1200,6 +1203,9 @@ def implicit_gemm_backward(features: torch.Tensor,
raise NotImplementedError("work in progress") raise NotImplementedError("work in progress")
if not out_bp.is_contiguous(): if not out_bp.is_contiguous():
out_bp = out_bp.contiguous() out_bp = out_bp.contiguous()
if not features.is_contiguous():
features = features.contiguous()
assert out_bp.is_contiguous() assert out_bp.is_contiguous()
assert filters.is_contiguous() assert filters.is_contiguous()
assert features.is_contiguous() assert features.is_contiguous()
...@@ -1357,6 +1363,8 @@ def indice_maxpool(features: torch.Tensor, indice_pairs: torch.Tensor, ...@@ -1357,6 +1363,8 @@ def indice_maxpool(features: torch.Tensor, indice_pairs: torch.Tensor,
# stream = get_current_stream() # stream = get_current_stream()
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
if not features.is_contiguous():
features = features.contiguous()
out_channel = features.shape[-1] out_channel = features.shape[-1]
out_features = torch.zeros((num_activate_out, out_channel), out_features = torch.zeros((num_activate_out, out_channel),
...@@ -1399,6 +1407,9 @@ def indice_maxpool_backward(features, out_features, out_bp, indice_pairs, ...@@ -1399,6 +1407,9 @@ def indice_maxpool_backward(features, out_features, out_bp, indice_pairs,
indice_pair_num_cpu = indice_pair_num.cpu().tolist() indice_pair_num_cpu = indice_pair_num.cpu().tolist()
if not out_bp.is_contiguous(): if not out_bp.is_contiguous():
out_bp = out_bp.contiguous() out_bp = out_bp.contiguous()
if not features.is_contiguous():
features = features.contiguous()
out_features_tv = torch_tensor_to_tv(out_features) out_features_tv = torch_tensor_to_tv(out_features)
features_tv = torch_tensor_to_tv(features) features_tv = torch_tensor_to_tv(features)
out_bp_tv = torch_tensor_to_tv(out_bp) out_bp_tv = torch_tensor_to_tv(out_bp)
...@@ -1428,6 +1439,8 @@ def indice_maxpool_implicit_gemm(features: torch.Tensor, ...@@ -1428,6 +1439,8 @@ def indice_maxpool_implicit_gemm(features: torch.Tensor,
stream = get_current_stream() stream = get_current_stream()
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
if not features.is_contiguous():
features = features.contiguous()
out_channel = features.shape[-1] out_channel = features.shape[-1]
out_features = torch.empty((num_activate_out, out_channel), out_features = torch.empty((num_activate_out, out_channel),
...@@ -1456,6 +1469,9 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp, ...@@ -1456,6 +1469,9 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp,
assert features.is_cuda assert features.is_cuda
if not out_bp.is_contiguous(): if not out_bp.is_contiguous():
out_bp = out_bp.contiguous() out_bp = out_bp.contiguous()
if not features.is_contiguous():
features = features.contiguous()
stream = get_current_stream() stream = get_current_stream()
out_features_tv = torch_tensor_to_tv(out_features) out_features_tv = torch_tensor_to_tv(out_features)
features_tv = torch_tensor_to_tv(features) features_tv = torch_tensor_to_tv(features)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment