"docs/vscode:/vscode.git/clone" did not exist on "9e5ce968549835c2ca1db2ca725687f381321b6a"
Unverified commit 4a7e5609 authored by maskjp, committed by GitHub

fix bug to allow pointnet to train in fp16 (#1207)

* fix bug to allow pointnet to train in fp16
* remove unused import
* fix lint
* fix lint for gather_points_cuda.cu
Co-authored-by: peng <maskjp@tamu.edu>
parent dabe3ff4
@@ -27,7 +27,7 @@ class GatherPoints(Function):
         B, npoint = indices.size()
         _, C, N = features.size()
-        output = torch.cuda.FloatTensor(B, C, npoint)
+        output = features.new_zeros((B, C, npoint))

         gather_points_ext.gather_points_wrapper(B, C, N, npoint, features,
                                                 indices, output)
@@ -41,7 +41,7 @@ class GatherPoints(Function):
         idx, C, N = ctx.for_backwards
         B, npoint = idx.size()

-        grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
+        grad_features = grad_out.new_zeros((B, C, N))
         grad_out_data = grad_out.data.contiguous()
         gather_points_ext.gather_points_grad_wrapper(B, C, N, npoint,
                                                      grad_out_data, idx,
...
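The root cause on the Python side: `torch.cuda.FloatTensor(...)` always allocates float32, regardless of the input dtype, so half-precision features were routed through a mismatched buffer; `new_zeros` instead inherits dtype and device from its source tensor. A minimal sketch of the difference (assuming a CUDA-enabled PyTorch; shapes are arbitrary):

```python
import torch

feats = torch.randn(2, 3, 16, device='cuda', dtype=torch.half)

# Old allocation: fixed to float32 no matter what dtype the input has.
out_old = torch.cuda.FloatTensor(2, 3, 6)
assert out_old.dtype == torch.float32

# New allocation: follows the input's dtype and device, so fp16 stays fp16.
out_new = feats.new_zeros((2, 3, 6))
assert out_new.dtype == torch.half and out_new.device == feats.device
```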
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/TensorUtils.h>
 #include <THC/THC.h>
 #include <torch/extension.h>
 #include <torch/serialize/tensor.h>
 #include <vector>

 extern THCState *state;

 int gather_points_wrapper(int b, int c, int n, int npoints,
-                          at::Tensor points_tensor, at::Tensor idx_tensor,
-                          at::Tensor out_tensor);
+                          at::Tensor& points_tensor, at::Tensor& idx_tensor,
+                          at::Tensor& out_tensor);

 void gather_points_kernel_launcher(int b, int c, int n, int npoints,
-                                   const float *points, const int *idx,
-                                   float *out, cudaStream_t stream);
+                                   const at::Tensor& points_tensor,
+                                   const at::Tensor& idx_tensor,
+                                   at::Tensor& out_tensor);

 int gather_points_grad_wrapper(int b, int c, int n, int npoints,
-                               at::Tensor grad_out_tensor,
-                               at::Tensor idx_tensor,
-                               at::Tensor grad_points_tensor);
+                               at::Tensor& grad_out_tensor,
+                               at::Tensor& idx_tensor,
+                               at::Tensor& grad_points_tensor);

 void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
-                                        const float *grad_out, const int *idx,
-                                        float *grad_points,
-                                        cudaStream_t stream);
+                                        const at::Tensor& grad_out_tensor,
+                                        const at::Tensor& idx_tensor,
+                                        at::Tensor& grad_points_tensor);

 int gather_points_wrapper(int b, int c, int n, int npoints,
-                          at::Tensor points_tensor, at::Tensor idx_tensor,
-                          at::Tensor out_tensor) {
-  const float *points = points_tensor.data_ptr<float>();
-  const int *idx = idx_tensor.data_ptr<int>();
-  float *out = out_tensor.data_ptr<float>();
-
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
-  gather_points_kernel_launcher(b, c, n, npoints, points, idx, out, stream);
+                          at::Tensor& points_tensor, at::Tensor& idx_tensor,
+                          at::Tensor& out_tensor)
+{
+  gather_points_kernel_launcher(b, c, n, npoints, points_tensor, idx_tensor,
+                                out_tensor);
   return 1;
 }

 int gather_points_grad_wrapper(int b, int c, int n, int npoints,
-                               at::Tensor grad_out_tensor,
-                               at::Tensor idx_tensor,
-                               at::Tensor grad_points_tensor) {
-  const float *grad_out = grad_out_tensor.data_ptr<float>();
-  const int *idx = idx_tensor.data_ptr<int>();
-  float *grad_points = grad_points_tensor.data_ptr<float>();
-
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
-  gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out, idx,
-                                     grad_points, stream);
+                               at::Tensor& grad_out_tensor,
+                               at::Tensor& idx_tensor,
+                               at::Tensor& grad_points_tensor)
+{
+  gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out_tensor,
+                                     idx_tensor, grad_points_tensor);
   return 1;
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
   m.def("gather_points_wrapper", &gather_points_wrapper,
         "gather_points_wrapper");
   m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper,
...
 #include <stdio.h>
 #include <stdlib.h>
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/types.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>

 #define TOTAL_THREADS 1024
 #define THREADS_PER_BLOCK 256
 #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

+template <typename scalar_t>
 __global__ void gather_points_kernel(int b, int c, int n, int m,
-                                     const float *__restrict__ points,
+                                     const scalar_t *__restrict__ points,
                                      const int *__restrict__ idx,
-                                     float *__restrict__ out) {
+                                     scalar_t *__restrict__ out) {
   // points: (B, C, N)
   // idx: (B, M)
   // output:
@@ -26,8 +33,10 @@ __global__ void gather_points_kernel(int b, int c, int n, int m,
 }

 void gather_points_kernel_launcher(int b, int c, int n, int npoints,
-                                   const float *points, const int *idx,
-                                   float *out, cudaStream_t stream) {
+                                   const at::Tensor& points_tensor,
+                                   const at::Tensor& idx_tensor,
+                                   at::Tensor& out_tensor)
+{
   // points: (B, C, N)
   // idx: (B, npoints)
   // output:
@@ -35,23 +44,33 @@ void gather_points_kernel_launcher(int b, int c, int n, int npoints,
   cudaError_t err;
   dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
               b);  // blockIdx.x(col), blockIdx.y(row)
   dim3 threads(THREADS_PER_BLOCK);

-  gather_points_kernel<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points,
-                                                       idx, out);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      out_tensor.scalar_type(), "gather_points_kernel",
+      [&]
+      {
+        const scalar_t *points = points_tensor.data_ptr<scalar_t>();
+        const int *idx = idx_tensor.data_ptr<int>();
+        scalar_t *out = out_tensor.data_ptr<scalar_t>();
+        gather_points_kernel<<<blocks, threads, 0, stream>>>(b, c, n, npoints,
+                                                             points, idx, out);
+      });

   err = cudaGetLastError();
-  if (cudaSuccess != err) {
+  if (cudaSuccess != err)
+  {
     fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
     exit(-1);
   }
 }

+template <typename scalar_t>
 __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
-                                          const float *__restrict__ grad_out,
+                                          const scalar_t *__restrict__ grad_out,
                                           const int *__restrict__ idx,
-                                          float *__restrict__ grad_points) {
+                                          scalar_t *__restrict__ grad_points) {
   // grad_out: (B, C, M)
   // idx: (B, M)
   // output:
@@ -70,9 +89,10 @@ __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
 }

 void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
-                                        const float *grad_out, const int *idx,
-                                        float *grad_points,
-                                        cudaStream_t stream) {
+                                        const at::Tensor& grad_out_tensor,
+                                        const at::Tensor& idx_tensor,
+                                        at::Tensor& grad_points_tensor)
+{
   // grad_out: (B, C, npoints)
   // idx: (B, npoints)
   // output:
@@ -80,14 +100,24 @@ void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
   cudaError_t err;
   dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
               b);  // blockIdx.x(col), blockIdx.y(row)
   dim3 threads(THREADS_PER_BLOCK);

-  gather_points_grad_kernel<<<blocks, threads, 0, stream>>>(
-      b, c, n, npoints, grad_out, idx, grad_points);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_points_tensor.scalar_type(), "gather_points_grad_kernel",
+      [&]
+      {
+        const scalar_t *grad_out = grad_out_tensor.data_ptr<scalar_t>();
+        const int *idx = idx_tensor.data_ptr<int>();
+        scalar_t *grad_points = grad_points_tensor.data_ptr<scalar_t>();
+        gather_points_grad_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+            b, c, n, npoints, grad_out, idx, grad_points);
+      });

   err = cudaGetLastError();
-  if (cudaSuccess != err) {
+  if (cudaSuccess != err)
+  {
     fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
     exit(-1);
   }
...
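The key change in the CUDA file: the launchers now take `at::Tensor` references, fetch the current stream themselves, and wrap the launch in `AT_DISPATCH_FLOATING_TYPES_AND_HALF`, which switches on the tensor's scalar type and binds `scalar_t` inside the lambda, so one templated kernel serves fp16, fp32, and fp64. Roughly, in Python terms (a loose analogue with hypothetical names, not a real API):

```python
import torch

def dispatch_floating_types_and_half(tensor: torch.Tensor, kernels: dict):
    """Loose analogue of AT_DISPATCH_FLOATING_TYPES_AND_HALF: pick the
    kernel instantiation compiled for this tensor's scalar type."""
    try:
        return kernels[tensor.dtype]
    except KeyError:
        raise TypeError(
            f'"gather_points_kernel" not implemented for {tensor.dtype}')

# `kernels` would map each supported dtype to a compiled template instance,
# e.g. {torch.half: k_half, torch.float32: k_float, torch.float64: k_double}.
```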
+from typing import Tuple
+
 import torch
+from mmcv.runner import force_fp32
 from torch import nn as nn
 from torch.autograd import Function
-from typing import Tuple

 from ..ball_query import ball_query
 from ..knn import knn
@@ -60,7 +62,9 @@ class QueryAndGroup(nn.Module):
         if self.max_radius is None:
             assert not self.normalize_xyz, \
                 'can not normalize grouped xyz when max_radius is None'
+        self.fp16_enabled = False

+    @force_fp32()
     def forward(self, points_xyz, center_xyz, features=None):
         """forward.
@@ -141,7 +145,9 @@ class GroupAll(nn.Module):
     def __init__(self, use_xyz: bool = True):
         super().__init__()
         self.use_xyz = use_xyz
+        self.fp16_enabled = False

+    @force_fp32()
     def forward(self,
                 xyz: torch.Tensor,
                 new_xyz: torch.Tensor,
@@ -183,7 +189,7 @@ class GroupingOperation(Function):
         Args:
             features (Tensor): (B, C, N) tensor of features to group.
-            indices (Tensor): (B, npoint, nsample) the indicies of
+            indices (Tensor): (B, npoint, nsample) the indices of
                 features to group with.

         Returns:
...
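On the module side, `QueryAndGroup` and `GroupAll` opt out of mixed precision with mmcv's fp16 utilities: `fp16_enabled` is the flag that mmcv's model wrapping flips on, and `@force_fp32()` then casts floating-point arguments of `forward` back to fp32 before the body runs. A minimal sketch of the pattern (the `Demo` module is hypothetical, not from this PR):

```python
import torch
from mmcv.runner import force_fp32
from torch import nn


class Demo(nn.Module):
    """Hypothetical module showing the decorator pattern used above."""

    def __init__(self):
        super().__init__()
        # Set to True by mmcv when fp16 training is enabled on the model.
        self.fp16_enabled = False

    @force_fp32()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With fp16_enabled True, a half-precision x arrives here as
        # float32, keeping the grouping math numerically stable.
        return x * 2
```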
@@ -2,9 +2,16 @@
 import pytest
 import torch

-from mmdet3d.ops import (ball_query, furthest_point_sample,
-                         furthest_point_sample_with_dist, gather_points,
-                         grouping_operation, knn, three_interpolate, three_nn)
+from mmdet3d.ops import (
+    ball_query,
+    furthest_point_sample,
+    furthest_point_sample_with_dist,
+    gather_points,
+    grouping_operation,
+    knn,
+    three_interpolate,
+    three_nn,
+)

 def test_fps():
@@ -236,6 +243,8 @@ def test_gather_points():
         [-0.7172, 0.0462, -0.6227, -0.7172, -0.7172, -0.7172]]]).cuda()
     assert torch.allclose(output, expected_output)

+    output_half = gather_points(features.half(), idx)
+    assert torch.allclose(output_half, expected_output.half())

 def test_three_interpolate():
...
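With dtype dispatch in place, the op accepts half tensors directly. A short usage sketch mirroring the new test (assumes a CUDA build of mmdet3d; shapes are arbitrary):

```python
import torch
from mmdet3d.ops import gather_points

features = torch.randn(1, 4, 16, device='cuda').half()  # (B, C, N)
idx = torch.randint(0, 16, (1, 6), device='cuda',
                    dtype=torch.int32)                   # (B, npoint)

out = gather_points(features, idx)                       # (B, C, npoint)
assert out.dtype == torch.half
```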