"server/vscode:/vscode.git/clone" did not exist on "92c1ecd0089d329d4b5a5e2b9327da828b888d34"
Unverified commit 4a7e5609, authored by maskjp and committed via GitHub
Browse files

fix bug to allow pointnet to train in fp16 (#1207)



* fix bug to allow pointnet to train in fp16

* remove unused import

* fix lint

* fix lint for gather_points_cuda.cu
Co-authored-by: peng <maskjp@tamu.edu>
parent dabe3ff4
......@@ -27,7 +27,7 @@ class GatherPoints(Function):
B, npoint = indices.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, npoint)
output = features.new_zeros((B, C, npoint))
gather_points_ext.gather_points_wrapper(B, C, N, npoint, features,
indices, output)
......@@ -41,7 +41,7 @@ class GatherPoints(Function):
idx, C, N = ctx.for_backwards
B, npoint = idx.size()
grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
grad_features = grad_out.new_zeros((B, C, N))
grad_out_data = grad_out.data.contiguous()
gather_points_ext.gather_points_grad_wrapper(B, C, N, npoint,
grad_out_data, idx,
......
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
extern THCState *state;
int gather_points_wrapper(int b, int c, int n, int npoints,
at::Tensor points_tensor, at::Tensor idx_tensor,
at::Tensor out_tensor);
at::Tensor& points_tensor, at::Tensor& idx_tensor,
at::Tensor& out_tensor);
void gather_points_kernel_launcher(int b, int c, int n, int npoints,
const float *points, const int *idx,
float *out, cudaStream_t stream);
const at::Tensor& points_tensor,
const at::Tensor& idx_tensor,
at::Tensor& out_tensor);
int gather_points_grad_wrapper(int b, int c, int n, int npoints,
at::Tensor grad_out_tensor,
at::Tensor idx_tensor,
at::Tensor grad_points_tensor);
at::Tensor& grad_out_tensor,
at::Tensor& idx_tensor,
at::Tensor& grad_points_tensor);
void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
const float *grad_out, const int *idx,
float *grad_points,
cudaStream_t stream);
const at::Tensor& grad_out_tensor,
const at::Tensor& idx_tensor,
at::Tensor& grad_points_tensor);
int gather_points_wrapper(int b, int c, int n, int npoints,
at::Tensor points_tensor, at::Tensor idx_tensor,
at::Tensor out_tensor) {
const float *points = points_tensor.data_ptr<float>();
const int *idx = idx_tensor.data_ptr<int>();
float *out = out_tensor.data_ptr<float>();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
gather_points_kernel_launcher(b, c, n, npoints, points, idx, out, stream);
at::Tensor& points_tensor, at::Tensor& idx_tensor,
at::Tensor& out_tensor)
{
gather_points_kernel_launcher(b, c, n, npoints, points_tensor, idx_tensor, out_tensor);
return 1;
}
int gather_points_grad_wrapper(int b, int c, int n, int npoints,
at::Tensor grad_out_tensor,
at::Tensor idx_tensor,
at::Tensor grad_points_tensor) {
const float *grad_out = grad_out_tensor.data_ptr<float>();
const int *idx = idx_tensor.data_ptr<int>();
float *grad_points = grad_points_tensor.data_ptr<float>();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out, idx,
grad_points, stream);
at::Tensor& grad_out_tensor,
at::Tensor& idx_tensor,
at::Tensor& grad_points_tensor)
{
gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out_tensor, idx_tensor,
grad_points_tensor);
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("gather_points_wrapper", &gather_points_wrapper,
"gather_points_wrapper");
m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper,
......
#include <stdio.h>
#include <stdlib.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/types.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
template <typename scalar_t>
__global__ void gather_points_kernel(int b, int c, int n, int m,
const float *__restrict__ points,
const scalar_t *__restrict__ points,
const int *__restrict__ idx,
float *__restrict__ out) {
scalar_t *__restrict__ out) {
// points: (B, C, N)
// idx: (B, M)
// output:
......@@ -26,8 +33,10 @@ __global__ void gather_points_kernel(int b, int c, int n, int m,
}
void gather_points_kernel_launcher(int b, int c, int n, int npoints,
const float *points, const int *idx,
float *out, cudaStream_t stream) {
const at::Tensor& points_tensor,
const at::Tensor& idx_tensor,
at::Tensor& out_tensor)
{
// points: (B, C, N)
// idx: (B, npoints)
// output:
......@@ -37,21 +46,31 @@ void gather_points_kernel_launcher(int b, int c, int n, int npoints,
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
out_tensor.scalar_type(), "gather_points_kernel",
[&]
{
const scalar_t *points = points_tensor.data_ptr<scalar_t>();
const int *idx = idx_tensor.data_ptr<int>();
scalar_t *out = out_tensor.data_ptr<scalar_t>();
gather_points_kernel<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points,
idx, out);
});
err = cudaGetLastError();
if (cudaSuccess != err) {
if (cudaSuccess != err)
{
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
template <typename scalar_t>
__global__ void gather_points_grad_kernel(int b, int c, int n, int m,
const float *__restrict__ grad_out,
const scalar_t *__restrict__ grad_out,
const int *__restrict__ idx,
float *__restrict__ grad_points) {
scalar_t *__restrict__ grad_points) {
// grad_out: (B, C, M)
// idx: (B, M)
// output:
......@@ -70,9 +89,10 @@ __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
}
void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
const float *grad_out, const int *idx,
float *grad_points,
cudaStream_t stream) {
const at::Tensor& grad_out_tensor,
const at::Tensor& idx_tensor,
at::Tensor& grad_points_tensor)
{
// grad_out: (B, C, npoints)
// idx: (B, npoints)
// output:
......@@ -83,11 +103,21 @@ void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
gather_points_grad_kernel<<<blocks, threads, 0, stream>>>(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_points_tensor.scalar_type(), "gather_points_grad_kernel",
[&]
{
const scalar_t *grad_out = grad_out_tensor.data_ptr<scalar_t>();
const int *idx = idx_tensor.data_ptr<int>();
scalar_t *grad_points = grad_points_tensor.data_ptr<scalar_t>();
gather_points_grad_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
b, c, n, npoints, grad_out, idx, grad_points);
});
err = cudaGetLastError();
if (cudaSuccess != err) {
if (cudaSuccess != err)
{
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
......
from typing import Tuple
import torch
from mmcv.runner import force_fp32
from torch import nn as nn
from torch.autograd import Function
from typing import Tuple
from ..ball_query import ball_query
from ..knn import knn
......@@ -60,7 +62,9 @@ class QueryAndGroup(nn.Module):
if self.max_radius is None:
assert not self.normalize_xyz, \
'can not normalize grouped xyz when max_radius is None'
self.fp16_enabled = False
@force_fp32()
def forward(self, points_xyz, center_xyz, features=None):
"""forward.
......@@ -141,7 +145,9 @@ class GroupAll(nn.Module):
def __init__(self, use_xyz: bool = True):
super().__init__()
self.use_xyz = use_xyz
self.fp16_enabled = False
@force_fp32()
def forward(self,
xyz: torch.Tensor,
new_xyz: torch.Tensor,
......@@ -183,7 +189,7 @@ class GroupingOperation(Function):
Args:
features (Tensor): (B, C, N) tensor of features to group.
indices (Tensor): (B, npoint, nsample) the indicies of
indices (Tensor): (B, npoint, nsample) the indices of
features to group with.
Returns:
......
......@@ -2,9 +2,16 @@
import pytest
import torch
from mmdet3d.ops import (ball_query, furthest_point_sample,
furthest_point_sample_with_dist, gather_points,
grouping_operation, knn, three_interpolate, three_nn)
from mmdet3d.ops import (
ball_query,
furthest_point_sample,
furthest_point_sample_with_dist,
gather_points,
grouping_operation,
knn,
three_interpolate,
three_nn,
)
def test_fps():
......@@ -236,6 +243,8 @@ def test_gather_points():
[-0.7172, 0.0462, -0.6227, -0.7172, -0.7172, -0.7172]]]).cuda()
assert torch.allclose(output, expected_output)
output_half = gather_points(features.half(), idx)
assert torch.allclose(output_half, expected_output.half())
def test_three_interpolate():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment