Commit 05cbea11 authored by Josh Fromm's avatar Josh Fromm Committed by Facebook GitHub Bot
Browse files

Hipify Pytorch3D (#1851)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/133343

X-link: https://github.com/fairinternal/pytorch3d/pull/45

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1851

Enables pytorch3d to build on AMD. An important part of enabling this was not compiling the Pulsar backend when the target is AMD. There are simply too many kernel incompatibilites to make it work (I tried haha). Fortunately, it doesnt seem like most modern applications of pytorch3d rely on Pulsar. We should be able to unlock most of pytorch3d's goodness on AMD without it.

Reviewed By: bottler, houseroad

Differential Revision: D61171993

fbshipit-source-id: fd4aee378a3568b22676c5bf2b727c135ff710af
parent 38afdcfc
......@@ -7,11 +7,15 @@
*/
// clang-format off
#if !defined(USE_ROCM)
#include "./pulsar/global.h" // Include before <torch/extension.h>.
#endif
#include <torch/extension.h>
// clang-format on
#if !defined(USE_ROCM)
#include "./pulsar/pytorch/renderer.h"
#include "./pulsar/pytorch/tensor_util.h"
#endif
#include "ball_query/ball_query.h"
#include "blending/sigmoid_alpha_blend.h"
#include "compositing/alpha_composite.h"
......@@ -99,6 +103,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("marching_cubes", &MarchingCubes);
// Pulsar.
// Pulsar not enabled on AMD.
#if !defined(USE_ROCM)
#ifdef PULSAR_LOGGING_ENABLED
c10::ShowLogInfoToStderr();
#endif
......@@ -183,4 +189,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.attr("MAX_UINT") = py::int_(MAX_UINT);
m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
#endif
}
......@@ -144,7 +144,7 @@ __device__ void CheckPixelInsideFace(
const bool zero_face_area =
(face_area <= kEpsilon && face_area >= -1.0f * kEpsilon);
if (zmax < 0 || cull_backfaces && back_face || outside_bbox ||
if (zmax < 0 || (cull_backfaces && back_face) || outside_bbox ||
zero_face_area) {
return;
}
......
......@@ -18,6 +18,8 @@ const auto vEpsilon = 1e-8;
// Common functions and operators for float2.
// Complex arithmetic is already defined for AMD.
#if !defined(USE_ROCM)
__device__ inline float2 operator-(const float2& a, const float2& b) {
return make_float2(a.x - b.x, a.y - b.y);
}
......@@ -41,6 +43,7 @@ __device__ inline float2 operator*(const float2& a, const float2& b) {
__device__ inline float2 operator*(const float a, const float2& b) {
return make_float2(a * b.x, a * b.y);
}
#endif
__device__ inline float FloatMin3(const float a, const float b, const float c) {
return fminf(a, fminf(b, c));
......
......@@ -23,37 +23,51 @@ WarpReduceMin(scalar_t* min_dists, int64_t* min_idxs, const size_t tid) {
min_idxs[tid] = min_idxs[tid + 32];
min_dists[tid] = min_dists[tid + 32];
}
// AMD does not use explicit syncwarp and instead automatically inserts memory
// fences during compilation.
#if !defined(USE_ROCM)
__syncwarp();
#endif
// s = 16
if (min_dists[tid] > min_dists[tid + 16]) {
min_idxs[tid] = min_idxs[tid + 16];
min_dists[tid] = min_dists[tid + 16];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
// s = 8
if (min_dists[tid] > min_dists[tid + 8]) {
min_idxs[tid] = min_idxs[tid + 8];
min_dists[tid] = min_dists[tid + 8];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
// s = 4
if (min_dists[tid] > min_dists[tid + 4]) {
min_idxs[tid] = min_idxs[tid + 4];
min_dists[tid] = min_dists[tid + 4];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
// s = 2
if (min_dists[tid] > min_dists[tid + 2]) {
min_idxs[tid] = min_idxs[tid + 2];
min_dists[tid] = min_dists[tid + 2];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
// s = 1
if (min_dists[tid] > min_dists[tid + 1]) {
min_idxs[tid] = min_idxs[tid + 1];
min_dists[tid] = min_dists[tid + 1];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
}
template <typename scalar_t>
......@@ -65,30 +79,42 @@ __device__ void WarpReduceMax(
dists[tid] = dists[tid + 32];
dists_idx[tid] = dists_idx[tid + 32];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
if (dists[tid] < dists[tid + 16]) {
dists[tid] = dists[tid + 16];
dists_idx[tid] = dists_idx[tid + 16];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
if (dists[tid] < dists[tid + 8]) {
dists[tid] = dists[tid + 8];
dists_idx[tid] = dists_idx[tid + 8];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
if (dists[tid] < dists[tid + 4]) {
dists[tid] = dists[tid + 4];
dists_idx[tid] = dists_idx[tid + 4];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
if (dists[tid] < dists[tid + 2]) {
dists[tid] = dists[tid + 2];
dists_idx[tid] = dists_idx[tid + 2];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
if (dists[tid] < dists[tid + 1]) {
dists[tid] = dists[tid + 1];
dists_idx[tid] = dists_idx[tid + 1];
}
#if !defined(USE_ROCM)
__syncwarp();
#endif
}
......@@ -6,6 +6,8 @@
# pyre-unsafe
import torch
from .blending import (
BlendParams,
hard_rgb_blend,
......@@ -74,9 +76,13 @@ from .points import (
PointsRasterizationSettings,
PointsRasterizer,
PointsRenderer,
PulsarPointsRenderer,
rasterize_points,
)
# Pulsar is not enabled on amd.
if not torch.version.hip:
from .points import PulsarPointsRenderer
from .splatter_blend import SplatterBlender
from .utils import (
convert_to_tensors_and_broadcast,
......
......@@ -6,8 +6,13 @@
# pyre-unsafe
import torch
from .compositor import AlphaCompositor, NormWeightedCompositor
from .pulsar.unified import PulsarPointsRenderer
# Pulsar not enabled on amd.
if not torch.version.hip:
from .pulsar.unified import PulsarPointsRenderer
from .rasterize_points import rasterize_points
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment