Commit cfaa1a3a authored by yanyan
Browse files

add Minkowski conv kernel

parent 9ce18407
// Copyright 2019 Yan Yan // Copyright 2019-2020 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -23,9 +23,9 @@ ...@@ -23,9 +23,9 @@
namespace spconv { namespace spconv {
enum ConvAlgo { kNative = 0, kBatch, kBatchGemmGather, kSparseConvNet }; enum ConvAlgo { kNative = 0, kBatch, kBatchGemmGather, kSparseConvNet, kMinkowskiEngine };
using all_conv_algos_t = using all_conv_algos_t =
tv::mp_list_c<int, kNative, kBatch, kBatchGemmGather, kSparseConvNet>; tv::mp_list_c<int, kNative, kBatch, kBatchGemmGather, kSparseConvNet, kMinkowskiEngine>;
// torch.jit's doc says only support int64, so we need to convert to int32. // torch.jit's doc says only support int64, so we need to convert to int32.
std::vector<torch::Tensor> std::vector<torch::Tensor>
...@@ -37,12 +37,6 @@ getIndicePairs(torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize, ...@@ -37,12 +37,6 @@ getIndicePairs(torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
std::vector<int64_t> outPadding, int64_t _subM, std::vector<int64_t> outPadding, int64_t _subM,
int64_t _transpose, int64_t _useHash); int64_t _transpose, int64_t _useHash);
torch::Tensor indiceConvBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t numActOut,
int64_t _inverse, int64_t _subM,
bool batchScatter);
torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
torch::Tensor indicePairs, torch::Tensor indiceNum, torch::Tensor indicePairs, torch::Tensor indiceNum,
int64_t numActOut, int64_t _inverse, int64_t _subM, int64_t numActOut, int64_t _inverse, int64_t _subM,
...@@ -53,11 +47,6 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -53,11 +47,6 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
torch::Tensor indiceNum, int64_t _inverse, int64_t _subM, torch::Tensor indiceNum, int64_t _inverse, int64_t _subM,
int64_t algo); int64_t algo);
std::vector<torch::Tensor>
indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters,
torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM, bool batchScatter);
} // namespace spconv } // namespace spconv
#endif #endif
\ No newline at end of file
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#pragma once #pragma once
#include "mp_helper.h" #include "mp_helper.h"
#include <tensorview/tensorview.h> #include <tensorview/tensorview.h>
#include <tensorview/tensor.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <torch/script.h> #include <torch/script.h>
#ifdef TV_CUDA #ifdef TV_CUDA
......
// Copyright 2019 Yan Yan // Copyright 2019-2020 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
// Copyright 2019 Yan Yan // Copyright 2019-2020 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
# Copyright 2019 Yan Yan # Copyright 2019-2020 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2019 Yan Yan # Copyright 2019-2020 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2019 Yan Yan # Copyright 2019-2020 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2019 Yan Yan # Copyright 2019-2020 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2019 Yan Yan # Copyright 2019-2020 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -24,6 +24,7 @@ class ConvAlgo(Enum): ...@@ -24,6 +24,7 @@ class ConvAlgo(Enum):
Batch = 1 # high memory cost, faster when number of points is small (< 50000) Batch = 1 # high memory cost, faster when number of points is small (< 50000)
BatchGemmGather = 2 # high memory cost, faster when number of points medium BatchGemmGather = 2 # high memory cost, faster when number of points medium
SparseConvNet = 3 SparseConvNet = 3
Minkowski = 4 # https://github.com/StanfordVL/MinkowskiEngine/blob/master/src/convolution.cu
def get_conv_output_size(input_size, kernel_size, stride, padding, dilation): def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
ndim = len(input_size) ndim = len(input_size)
......
# Copyright 2019 Yan Yan # Copyright 2019-2020 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2019 Yan Yan # Copyright 2019-2020 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2019 Yan Yan # Copyright 2019-2020 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2019 Yan Yan # Copyright 2019-2020 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
// Copyright 2019 Yan Yan // Copyright 2019-2020 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <spconv/fused_conv.cu.h> #include <spconv/fused_conv.cu.h>
#include <spconv/fused_conv.h> #include <spconv/fused_conv.h>
#include <tensorview/torch_utils.h> #include <tensorview/torch_utils.h>
#include <spconv/minkowski.cu.h>
namespace spconv { namespace spconv {
void fused_conv_cuda(torch::Tensor output, torch::Tensor features, void fused_conv_cuda(torch::Tensor output, torch::Tensor features,
torch::Tensor filters, torch::Tensor indicesIn, torch::Tensor filters, torch::Tensor indicesIn,
...@@ -39,4 +55,100 @@ void fused_conv_backward_cuda(torch::Tensor features, torch::Tensor din, ...@@ -39,4 +55,100 @@ void fused_conv_backward_cuda(torch::Tensor features, torch::Tensor din,
}); });
} }
/// Forward fused convolution using the `matmul` kernel from
/// spconv/minkowski.cu.h, launched on the current CUDA stream.
///
/// The tile size (8/16/24/32) is picked from the input/output channel counts
/// and used both as the square thread-block edge and as the kernel's
/// compile-time `ShmemSize` template argument. When the number of index pairs
/// would need more than MAX_GRID blocks along grid.y, the work is split into
/// `num_div` chunks of at most `step` pairs, one launch per chunk.
///
/// @param output     tensor handed to the kernel as its output buffer;
///                   size(1) is used as the output channel count.
/// @param features   input features; size(1) is used as the input channel
///                   count.
/// @param filters    filter weights handed to the kernel
///                   (NOTE(review): presumably laid out as
///                   [in_nchannel, out_nchannel] — confirm against the kernel).
/// @param indicesIn  int32 index tensor (accessed via data_ptr<int32_t>).
/// @param indicesOut int32 index tensor (accessed via data_ptr<int32_t>).
/// @param nHot       number of index pairs to process; values <= 0 are a
///                   no-op.
void fused_conv_cuda_minkowski(torch::Tensor output, torch::Tensor features,
                               torch::Tensor filters, torch::Tensor indicesIn,
                               torch::Tensor indicesOut, int nHot) {
  // Nothing to launch. This guard also keeps `num_div` below from being zero,
  // which would make the `step` computation divide by zero.
  if (nHot <= 0) {
    return;
  }
  auto dtype = output.scalar_type();
  auto in_nchannel = features.size(1);
  auto out_nchannel = output.size(1);

  // Tile-size heuristic (see the MinkowskiEngine link next to
  // ConvAlgo.Minkowski in the Python enum).
  int shared_mem_size = -1;
  if ((in_nchannel > 16 && out_nchannel > 16 &&
       in_nchannel * out_nchannel >= 512) ||
      (in_nchannel > 24 && out_nchannel > 24))
    shared_mem_size = 32;
  else if (in_nchannel % 24 == 0 && out_nchannel % 24 == 0)
    shared_mem_size = 24;
  else if ((in_nchannel > 8 && out_nchannel > 8) ||
           (in_nchannel % 16 == 0 && out_nchannel % 16 == 0))
    shared_mem_size = 16;
  else
    shared_mem_size = 8;

  constexpr int MAX_GRID = 65535;
  auto stream = at::cuda::getCurrentCUDAStream();
  using shmem_sizes_t = tv::mp_list_c<int, 32, 24, 16, 8>;

  // Blocks needed along y for all pairs, then the number of launches required
  // to keep each launch at or below MAX_GRID blocks, then pairs per launch.
  int num_grid = (nHot + shared_mem_size - 1) / shared_mem_size;
  int num_div = (num_grid + MAX_GRID - 1) / MAX_GRID;
  int step = (nHot + num_div - 1) / num_div;
  dim3 threads(shared_mem_size, shared_mem_size);

  tv::dispatch_torch<float>(dtype, [&](auto I) {
    using T = decltype(I);
    tv::DispatchInt<shmem_sizes_t>()(shared_mem_size, [&](auto ShSizeValue) {
      constexpr int ShmemSize = decltype(ShSizeValue)::value;
      for (int s = 0; s < num_div; s++) {
        // First index pair handled by this chunk.
        int offset = step * s;
        int remainder = nHot - offset;
        int curr_num_active = remainder < step ? remainder : step;
        if (curr_num_active <= 0) {
          // No pairs left; a launch with grid.y == 0 would be invalid.
          break;
        }
        dim3 grid((out_nchannel + threads.x - 1) / threads.x,
                  (curr_num_active + threads.y - 1) / threads.y);
        // The index arrays are advanced by `offset` so each chunk consumes
        // its own slice of pairs; without it every launch would reprocess the
        // first `curr_num_active` pairs.
        matmul<T, int32_t, ShmemSize><<<grid, threads, 0, stream>>>(
            features.data_ptr<T>(), in_nchannel, curr_num_active,
            filters.data_ptr<T>(), out_nchannel, in_nchannel,
            output.data_ptr<T>(), indicesIn.data_ptr<int32_t>() + offset,
            indicesOut.data_ptr<int32_t>() + offset);
      }
    });
  });
}
/// Backward fused convolution using the `matmul2` kernel from
/// spconv/minkowski.cu.h, launched on the current CUDA stream.
///
/// The tile size (8/16/24/32) is picked from the input/output channel counts
/// and used both as the square thread-block edge and as the kernel's
/// compile-time `ShmemSize` template argument. When the number of index pairs
/// would need more than MAX_GRID blocks along grid.y, the work is split into
/// `num_div` chunks of at most `step` pairs, one launch per chunk.
///
/// @param features   input features (kernel argument "D"); size(1) is used as
///                   the input channel count.
/// @param din        kernel argument "C"
///                   (NOTE(review): presumably the input-feature gradient
///                   written by the kernel — confirm against matmul2).
/// @param dout       kernel argument "A"; size(1) is used as the output
///                   channel count.
/// @param filters    kernel argument "B".
/// @param dfilters   kernel argument "E"
///                   (NOTE(review): presumably the filter gradient written by
///                   the kernel — confirm against matmul2).
/// @param indicesIn  int32 index tensor (accessed via data_ptr<int32_t>).
/// @param indicesOut int32 index tensor (accessed via data_ptr<int32_t>).
/// @param nHot       number of index pairs to process; values <= 0 are a
///                   no-op.
void fused_conv_backward_cuda_minkowski(torch::Tensor features,
                                        torch::Tensor din, torch::Tensor dout,
                                        torch::Tensor filters,
                                        torch::Tensor dfilters,
                                        torch::Tensor indicesIn,
                                        torch::Tensor indicesOut, int nHot) {
  // Nothing to launch. This guard also keeps `num_div` below from being zero,
  // which would make the `step` computation divide by zero.
  if (nHot <= 0) {
    return;
  }
  auto dtype = features.scalar_type();
  auto in_nchannel = features.size(1);
  auto out_nchannel = dout.size(1);

  // Tile-size heuristic. Note the 32-wide case here tests divisibility by 32,
  // which differs from the forward function's "> 24" test.
  int shared_mem_size = -1;
  if ((in_nchannel > 16 && out_nchannel > 16 &&
       in_nchannel * out_nchannel >= 512) ||
      (in_nchannel % 32 == 0 && out_nchannel % 32 == 0))
    shared_mem_size = 32;
  else if (in_nchannel % 24 == 0 && out_nchannel % 24 == 0)
    shared_mem_size = 24;
  else if ((in_nchannel > 8 && out_nchannel > 8) ||
           (in_nchannel % 16 == 0 && out_nchannel % 16 == 0))
    shared_mem_size = 16;
  else
    shared_mem_size = 8;

  dim3 threads(shared_mem_size, shared_mem_size);
  constexpr int MAX_GRID = 65535;
  auto stream = at::cuda::getCurrentCUDAStream();
  using shmem_sizes_t = tv::mp_list_c<int, 32, 24, 16, 8>;

  // Blocks needed along y for all pairs, then the number of launches required
  // to keep each launch at or below MAX_GRID blocks, then pairs per launch.
  int num_grid = (nHot + shared_mem_size - 1) / shared_mem_size;
  int num_div = (num_grid + MAX_GRID - 1) / MAX_GRID;
  int step = (nHot + num_div - 1) / num_div;

  tv::dispatch_torch<float>(dtype, [&](auto I) {
    using T = decltype(I);
    tv::DispatchInt<shmem_sizes_t>()(shared_mem_size, [&](auto ShSizeValue) {
      constexpr int ShmemSize = decltype(ShSizeValue)::value;
      for (int s = 0; s < num_div; s++) {
        // First index pair handled by this chunk.
        int offset = step * s;
        int remainder = nHot - offset;
        int curr_num_active = remainder < step ? remainder : step;
        if (curr_num_active <= 0) {
          // No pairs left; a launch with grid.y == 0 would be invalid.
          break;
        }
        dim3 grid((in_nchannel + threads.x - 1) / threads.x,
                  (curr_num_active + threads.y - 1) / threads.y);
        // The index arrays are advanced by `offset` so each chunk consumes
        // its own slice of pairs; without it every launch would reprocess the
        // first `curr_num_active` pairs.
        matmul2<T, int32_t, ShmemSize><<<grid, threads, 0, stream>>>(
            dout.data_ptr<T>(), out_nchannel, curr_num_active,     // A
            filters.data_ptr<T>(), out_nchannel, in_nchannel,      // B
            features.data_ptr<T>(), in_nchannel, curr_num_active,  // D
            din.data_ptr<T>(),                                     // C
            dfilters.data_ptr<T>(),                                // E
            indicesIn.data_ptr<int32_t>() + offset,
            indicesOut.data_ptr<int32_t>() + offset);
      }
    });
  });
}
} // namespace spconv } // namespace spconv
\ No newline at end of file
// Copyright 2019 Yan Yan // Copyright 2019-2020 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
// Copyright 2019 Yan Yan // Copyright 2019-2020 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
// Copyright 2019 Yan Yan // Copyright 2019-2020 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
// Copyright 2019 Yan Yan // Copyright 2019-2020 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
// Copyright 2019 Yan Yan // Copyright 2019-2020 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment