Unverified Commit 468b5713 authored by Yan Yan's avatar Yan Yan Committed by GitHub
Browse files

Merge pull request #153 from xmyqsh/reduce_subM_indice_conv_buffer

reduce subM indiceConv(Backward) bufferSize
parents 11bcbbf6 b728dac2
...@@ -165,26 +165,37 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -165,26 +165,37 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
auto numInPlanes = features.size(1); auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto indicePairMaxSizeIter =
std::max_element(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + kernelVolume);
int indicePairMaxOffset =
indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>();
int indicePairMaxSize = *indicePairMaxSizeIter;
auto options = auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device()); torch::TensorOptions().dtype(features.dtype()).device(features.device());
torch::Tensor output = torch::zeros({numActOut, numOutPlanes}, options); torch::Tensor output = torch::zeros({numActOut, numOutPlanes}, options);
torch::Tensor inputBuffer =
torch::empty({indicePairMaxSize, numInPlanes}, options);
torch::Tensor outputBuffer =
torch::empty({indicePairMaxSize, numOutPlanes}, options);
filters = filters.view({-1, numInPlanes, numOutPlanes}); filters = filters.view({-1, numInPlanes, numOutPlanes});
// init for subM
int indicePairMaxOffset = kernelVolume / 2;
int indicePairMaxSize = numActOut;
if (subM) { // the center index of subm conv don't need gather and scatter if (subM) { // the center index of subm conv don't need gather and scatter
// add. // add.
torch::mm_out(output, features, filters[indicePairMaxOffset]); torch::mm_out(output, features, filters[indicePairMaxOffset]);
// get indice pair second max size based on subM symmetric property
indicePairMaxSize =
*std::max_element(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + indicePairMaxOffset);
if (indicePairMaxSize == 0) {
return output;
}
} else {
indicePairMaxSize =
*std::max_element(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + kernelVolume);
} }
torch::Tensor inputBuffer =
torch::empty({indicePairMaxSize, numInPlanes}, options);
torch::Tensor outputBuffer =
torch::empty({indicePairMaxSize, numOutPlanes}, options);
double totalGatherTime = 0; double totalGatherTime = 0;
double totalGEMMTime = 0; double totalGEMMTime = 0;
double totalSAddTime = 0; double totalSAddTime = 0;
...@@ -399,29 +410,41 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -399,29 +410,41 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
auto numInPlanes = features.size(1); auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto indicePairMaxSizeIter =
std::max_element(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + kernelVolume);
int indicePairMaxOffset =
indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>();
int indicePairMaxSize = *indicePairMaxSizeIter;
auto options = auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device()); torch::TensorOptions().dtype(features.dtype()).device(features.device());
auto filterShape = filters.sizes(); auto filterShape = filters.sizes();
torch::Tensor inputGrad = torch::zeros(features.sizes(), options); torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
torch::Tensor filtersGrad = torch::empty(filterShape, options); torch::Tensor filtersGrad = torch::empty(filterShape, options);
torch::Tensor inputBuffer =
torch::empty({indicePairMaxSize, numInPlanes}, options);
torch::Tensor outputBuffer =
torch::empty({indicePairMaxSize, numOutPlanes}, options);
filters = filters.view({-1, numInPlanes, numOutPlanes}); filters = filters.view({-1, numInPlanes, numOutPlanes});
filtersGrad = filtersGrad.view({-1, numInPlanes, numOutPlanes}); filtersGrad = filtersGrad.view({-1, numInPlanes, numOutPlanes});
// init for subM
int indicePairMaxOffset = kernelVolume / 2;
int indicePairMaxSize = indicePairNumCpu.data_ptr<int>()[indicePairMaxOffset];
if (subM) { if (subM) {
auto filterGradSub = filtersGrad[indicePairMaxOffset]; auto filterGradSub = filtersGrad[indicePairMaxOffset];
torch::mm_out(filterGradSub, features.t(), outGrad); torch::mm_out(filterGradSub, features.t(), outGrad);
torch::mm_out(inputGrad, outGrad, filters[indicePairMaxOffset].t()); torch::mm_out(inputGrad, outGrad, filters[indicePairMaxOffset].t());
// get indice pair second max size based on subM symmetric property
indicePairMaxSize =
*std::max_element(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + indicePairMaxOffset);
if (indicePairMaxSize == 0) {
return {inputGrad, filtersGrad.view(filterShape)};
}
} else {
indicePairMaxSize =
*std::max_element(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + kernelVolume);
} }
torch::Tensor inputBuffer =
torch::empty({indicePairMaxSize, numInPlanes}, options);
torch::Tensor outputBuffer =
torch::empty({indicePairMaxSize, numOutPlanes}, options);
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i]; auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0 || (subM && i == indicePairMaxOffset)) { if (nHot <= 0 || (subM && i == indicePairMaxOffset)) {
...@@ -594,4 +617,4 @@ indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters, ...@@ -594,4 +617,4 @@ indiceConvBackwardBatch(torch::Tensor features, torch::Tensor filters,
return {inputGrad, filtersGrad.view(filterShape)}; return {inputGrad, filtersGrad.view(filterShape)};
} }
} // namespace spconv } // namespace spconv
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment