Unverified commit 03ea1c9d authored by q.yao, committed by GitHub

[Fix] Skip filtered_lrelu op when gcc is less than 6.0 or cuda is less than 10.2 (#2671)

* disable filtered_lrelu_op

* fix lint

* add cuda version check

* warn if disabled
parent 91ed30dd
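The change itself is a compile-time gate around the filtered_lrelu CUDA code. For readability, here is a minimal, self-contained sketch of the same pattern (the macro name BUILD_FILTERED_LRELU_OP and the message text are taken from the diff below; CUDA_VERSION from <cuda.h> encodes the version as major * 1000 + minor * 10, so CUDA 10.2 is 10020):

//------------------------------------------------------------------------
// Illustrative sketch only (not part of the diff): build the op only with
// GCC >= 6 and CUDA >= 10.2, otherwise emit a build-time notice instead.
#include <cuda.h>  // provides CUDA_VERSION, e.g. 10020 for CUDA 10.2

#define BUILD_FILTERED_LRELU_OP 1

#if defined(__GNUC__) && __GNUC__ < 6
#undef BUILD_FILTERED_LRELU_OP
#define BUILD_FILTERED_LRELU_OP 0
#endif

#if CUDA_VERSION < 10020
#undef BUILD_FILTERED_LRELU_OP
#define BUILD_FILTERED_LRELU_OP 0
#endif

#if BUILD_FILTERED_LRELU_OP == 1
// ... define and register the CUDA implementation here ...
#else
#pragma message("filtered_lrelu_op is not available. Please update your compiler and cuda version.")
#endif
#undef BUILD_FILTERED_LRELU_OP
//------------------------------------------------------------------------

With GCC older than 6 or CUDA older than 10.2, the op and its device registration are simply not compiled, and the #pragma message notice is printed at build time instead.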
@@ -538,20 +538,6 @@ torch::Tensor bias_act_op(const torch::Tensor &input, const torch::Tensor &bias,
REGISTER_DEVICE_IMPL(bias_act_op_impl, CUDA, bias_act_op);
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op_impl(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns);
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns);
REGISTER_DEVICE_IMPL(filtered_lrelu_op_impl, CUDA, filtered_lrelu_op);
torch::Tensor filtered_lrelu_act_op_impl(torch::Tensor x, torch::Tensor si,
int sx, int sy, float gain,
float slope, float clamp,
......
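The hunk above removes the filtered_lrelu declarations and their REGISTER_DEVICE_IMPL call from the shared CUDA binding file; they reappear further down inside the new #if BUILD_FILTERED_LRELU_OP block, so nothing references the op when it is not built. For background, the registration macro binds a device-agnostic dispatcher to its CUDA implementation at static-initialization time. A simplified, hypothetical registry illustrating the idea (this is not mmcv's actual pytorch_device_registry.hpp):

//------------------------------------------------------------------------
// Hypothetical mini-registry, for illustration only.
#include <functional>
#include <map>
#include <stdexcept>
#include <string>

using OpFn = std::function<int(int)>;  // placeholder op signature

inline std::map<std::string, OpFn> &op_registry() {
  static std::map<std::string, OpFn> registry;
  return registry;
}

// Instantiated as a static object so registration runs before main().
struct OpRegistrar {
  OpRegistrar(const std::string &key, OpFn fn) {
    op_registry()[key] = std::move(fn);
  }
};

#define REGISTER_OP_IMPL(key, device, fn) \
  static OpRegistrar registrar_##key##_##device{#key "/" #device, fn}

// The dispatcher looks the implementation up at call time and fails
// loudly if the op was never compiled in.
inline int dispatch(const std::string &key, const std::string &device, int x) {
  auto it = op_registry().find(key + "/" + device);
  if (it == op_registry().end())
    throw std::runtime_error("no implementation registered for " + key);
  return it->second(x);
}
//------------------------------------------------------------------------

Because registration happens through a static object, guarding the whole block with the new macro is enough to keep an unavailable op out of the registry.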
@@ -12,6 +12,7 @@
#include <cstdint>
#include "pytorch_cuda_helper.hpp"
#include "pytorch_device_registry.hpp"
//------------------------------------------------------------------------
// CUDA kernel parameters.
@@ -26,12 +27,12 @@ struct filtered_lrelu_kernel_params {
int _dummy; // Alignment.
// Rest of the parameters.
const void* x; // Input tensor.
void* y; // Output tensor.
const void* b; // Bias tensor.
unsigned char* s; // Sign tensor in/out. NULL if unused.
const float* fu; // Upsampling filter.
const float* fd; // Downsampling filter.
const void *x; // Input tensor.
void *y; // Output tensor.
const void *b; // Bias tensor.
unsigned char *s; // Sign tensor in/out. NULL if unused.
const float *fu; // Upsampling filter.
const float *fd; // Downsampling filter.
int2 pad0; // Left/top padding.
float gain; // Additional gain factor.
@@ -60,8 +61,8 @@ struct filtered_lrelu_kernel_params {
};
struct filtered_lrelu_act_kernel_params {
void* x; // Input/output, modified in-place.
unsigned char* s; // Sign tensor in/out. NULL if unused.
void *x; // Input/output, modified in-place.
unsigned char *s; // Sign tensor in/out. NULL if unused.
float gain; // Additional gain factor.
float slope; // Leaky ReLU slope on negative side.
@@ -78,8 +79,8 @@ struct filtered_lrelu_act_kernel_params {
// CUDA kernel specialization.
struct filtered_lrelu_kernel_spec {
void* setup; // Function for filter kernel setup.
void* exec; // Function for main operation.
void *setup; // Function for filter kernel setup.
void *exec; // Function for main operation.
int2 tileOut; // Width/height of launch tile.
int numWarps; // Number of warps per thread block, determines launch block
// size.
@@ -92,9 +93,9 @@ struct filtered_lrelu_kernel_spec {
template <class T, class index_t, bool signWrite, bool signRead>
filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
const filtered_lrelu_kernel_params& p, int sharedKB);
const filtered_lrelu_kernel_params &p, int sharedKB);
template <class T, bool signWrite, bool signRead>
void* choose_filtered_lrelu_act_kernel(void);
void *choose_filtered_lrelu_act_kernel(void);
//------------------------------------------------------------------------
// Helpers.
@@ -166,7 +167,7 @@ struct InternalType<c10::Half> {
// This works only up to blocks of size 256 x 256 and for all N that are powers
// of two.
template <int N>
__device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i) {
__device__ __forceinline__ void fast_div_mod(int &x, int &y, unsigned int i) {
if ((N & (N - 1)) && N <= 256)
y = (i * ((1 << 24) / N + 1)) >> 24; // Assumes N <= 256, i < N*256.
else
@@ -177,8 +178,8 @@ __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i) {
// Type cast stride before reading it.
template <class T>
__device__ __forceinline__ T get_stride(const int64_t& x) {
return *reinterpret_cast<const T*>(&x);
__device__ __forceinline__ T get_stride(const int64_t &x) {
return *reinterpret_cast<const T *>(&x);
}
//------------------------------------------------------------------------
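The fast_div_mod helper shown above avoids a hardware integer division for non-power-of-two divisors: (1 << 24) / N + 1 is a fixed-point approximation of 1/N scaled by 2^24, and the multiply-and-shift reproduces i / N exactly under the stated assumptions (N <= 256, i < N * 256); powers of two fall through to the plain division. A small host-side check of that claim, illustrative only:

//------------------------------------------------------------------------
// Illustrative host-side check (not part of the diff): the multiply-and-
// shift in fast_div_mod matches i / N and i % N for every non-power-of-two
// N <= 256 and every i < N * 256. The 32-bit product never overflows under
// these bounds.
#include <cassert>
#include <cstdio>

int main() {
  for (unsigned int N = 2; N <= 256; ++N) {
    if ((N & (N - 1)) == 0) continue;     // powers of two use plain i / N
    unsigned int m = (1u << 24) / N + 1;  // fixed-point reciprocal of N
    for (unsigned int i = 0; i < N * 256; ++i) {
      unsigned int y = (i * m) >> 24;     // same arithmetic as the kernel
      unsigned int x = i - y * N;
      assert(y == i / N && x == i % N);
    }
  }
  std::printf("fast_div_mod approximation verified for all N <= 256\n");
  return 0;
}
//------------------------------------------------------------------------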
@@ -233,7 +234,7 @@ static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) {
// Host function to copy filters written by setup kernel into constant buffer
// for main kernel.
static cudaError_t copy_filters(cudaStream_t stream) {
void* src = 0;
void *src = 0;
cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
if (err) return err;
return cudaMemcpyToSymbolAsync(
@@ -359,8 +360,8 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
"shared memory overflow");
// Declare shared memory arrays.
scalar_t* s_buf0;
scalar_t* s_buf1;
scalar_t *s_buf0;
scalar_t *s_buf1;
if (sharedKB <= 48) {
// Allocate shared memory arrays here.
__shared__ scalar_t
@@ -373,18 +374,18 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
s_buf1 = s_buf0 + s_buf0_size;
} else {
// Use the dynamically allocated shared memory array.
s_buf0 = (scalar_t*)s_buf_raw;
s_buf0 = (scalar_t *)s_buf_raw;
s_buf1 = s_buf0 + s_buf0_size;
}
// Pointers to the buffers.
scalar_t*
scalar_t *
s_tileIn; // Input tile: [relInX * tileInH + relInY]
scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW +
scalar_t *s_tileUpX; // After horizontal upsampling: [relInY * tileUpW +
// relUpX]
scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW +
scalar_t *s_tileUpXY; // After upsampling: [relUpY * tileUpW +
// relUpX]
scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW
scalar_t *s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW
// + relOutX]
if (filterMode == MODE_SUSD) {
s_tileIn = s_buf0;
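One host-side detail implied by the sharedKB branch above: a statically declared __shared__ array is capped at 48 KB per block, so the larger configurations use the dynamically allocated s_buf_raw buffer, and the launching code has to raise the dynamic shared memory limit for that kernel beforehand. A hedged sketch of that opt-in (the helper name is a placeholder, not something from this file):

//------------------------------------------------------------------------
// Illustrative sketch (not part of the diff): kernels requesting more than
// the default 48 KB of dynamic shared memory must opt in once per kernel
// before launch.
#include <cuda_runtime.h>

cudaError_t allow_large_dynamic_smem(const void *kernel_func, int shared_kb) {
  if (shared_kb <= 48) return cudaSuccess;  // default limit already suffices
  return cudaFuncSetAttribute(kernel_func,
                              cudaFuncAttributeMaxDynamicSharedMemorySize,
                              shared_kb << 10);  // limit is given in bytes
}
//------------------------------------------------------------------------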
@@ -444,8 +445,8 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
// Load input tile & apply bias. Unrolled.
scalar_t b =
(scalar_t) * (const T*)((const char*)p.b +
(channelIdx * get_stride<index_t>(p.bStride)));
(scalar_t) * (const T *)((const char *)p.b +
(channelIdx * get_stride<index_t>(p.bStride)));
index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) +
batchIdx * get_stride<index_t>(p.xStride.w);
int idx = threadIdx.x;
@@ -459,10 +460,10 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
scalar_t v = 0;
if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
v = (scalar_t) * ((const T*)((const char*)p.x +
(inX * get_stride<index_t>(p.xStride.x) +
inY * get_stride<index_t>(p.xStride.y) +
mapOfsIn))) +
v = (scalar_t) * ((const T *)((const char *)p.x +
(inX * get_stride<index_t>(p.xStride.x) +
inY * get_stride<index_t>(p.xStride.y) +
mapOfsIn))) +
b;
bool skip = (loop == loopCountIN - 1) && (idx >= tileInW * tileInH);
@@ -932,9 +933,9 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
index_t ofs = x * get_stride<index_t>(p.yStride.x) +
y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
if ((uint32_t)y + 0 < p.yShape.y)
*((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
*((T *)((char *)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
if ((uint32_t)y + 1 < ymax)
*((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) =
*((T *)((char *)p.y + ofs + get_stride<index_t>(p.yStride.y))) =
(T)(v.y * (scalar_t)c_fd[0]);
}
}
@@ -1216,9 +1217,9 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
else if ((uint32_t)x < p.yShape.x &&
(uint32_t)y <
p.yShape.y) // Write directly into output buffer
*((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) +
y * get_stride<index_t>(p.yStride.y) +
mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
*((T *)((char *)p.y + (x * get_stride<index_t>(p.yStride.x) +
y * get_stride<index_t>(p.yStride.y) +
mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
}
}
}
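A small idiom that recurs in these bounds checks: casting a possibly negative index to uint32_t folds the x >= 0 and x < limit tests into a single comparison, because negative values wrap around to very large unsigned numbers. A tiny illustration, not part of the diff:

//------------------------------------------------------------------------
// Illustrative sketch: one unsigned comparison covers both bounds,
// assuming the limit itself is non-negative (tensor shapes are).
#include <cassert>
#include <cstdint>

inline bool in_range(int x, int limit) {
  return (uint32_t)x < (uint32_t)limit;  // negative x wraps to a huge value
}

int main() {
  assert(in_range(0, 10));
  assert(in_range(9, 10));
  assert(!in_range(10, 10));
  assert(!in_range(-1, 10));  // (uint32_t)-1 == 4294967295u
  return 0;
}
//------------------------------------------------------------------------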
@@ -1298,9 +1299,9 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
int outY = tileOutY + relOutY0;
if (outX < p.yShape.x & outY < p.yShape.y)
*((T*)((char*)p.y +
(outX * get_stride<index_t>(p.yStride.x) +
outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
*((T *)((char *)p.y + (outX * get_stride<index_t>(p.yStride.x) +
outY * get_stride<index_t>(p.yStride.y) +
mapOfsOut))) = (T)v;
}
} else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) {
// Full downsampling filter.
@@ -1330,9 +1331,9 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
if ((uint32_t)outY < p.yShape.y) {
index_t ofs = outX * get_stride<index_t>(p.yStride.x) +
outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
if (outX + 0 < p.yShape.x) *((T *)((char *)p.y + ofs)) = (T)v.x;
if (outX + 1 < p.yShape.x)
*((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) =
*((T *)((char *)p.y + ofs + get_stride<index_t>(p.yStride.x))) =
(T)v.y;
}
}
@@ -1348,9 +1349,9 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
int outX = tileOutX + relOutX0;
int outY = tileOutY + relOutY0;
if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
*((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) +
outY * get_stride<index_t>(p.yStride.y) +
mapOfsOut))) = (T)v;
*((T *)((char *)p.y + (outX * get_stride<index_t>(p.yStride.x) +
outY * get_stride<index_t>(p.yStride.y) +
mapOfsOut))) = (T)v;
}
}
}
@@ -1390,7 +1391,7 @@ static __global__ void filtered_lrelu_act_kernel(
if (x < p.xShape.x && y < p.xShape.y) {
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z +
w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
T *pv = ((T *)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
// Gain, LReLU, clamp.
@@ -1420,7 +1421,7 @@ static __global__ void filtered_lrelu_act_kernel(
{
uint64_t is =
x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
((uint32_t*)p.s)[is >> 4] = s;
((uint32_t *)p.s)[is >> 4] = s;
}
} else if (signRead) {
// Process value if in p.x.
@@ -1428,7 +1429,7 @@ static __global__ void filtered_lrelu_act_kernel(
{
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z +
w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
T *pv = ((T *)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
v *= p.gain;
@@ -1457,7 +1458,7 @@ static __global__ void filtered_lrelu_act_kernel(
{
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z +
w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
T *pv = ((T *)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
v *= p.gain;
if (v < 0.f) v *= p.slope;
@@ -1469,8 +1470,8 @@ static __global__ void filtered_lrelu_act_kernel(
}
template <class T, bool signWrite, bool signRead>
void* choose_filtered_lrelu_act_kernel(void) {
return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
void *choose_filtered_lrelu_act_kernel(void) {
return (void *)filtered_lrelu_act_kernel<T, signWrite, signRead>;
}
//------------------------------------------------------------------------
@@ -1478,7 +1479,7 @@ void* choose_filtered_lrelu_act_kernel(void) {
template <class T, class index_t, bool signWrite, bool signRead>
filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
const filtered_lrelu_kernel_params& p, int sharedKB) {
const filtered_lrelu_kernel_params &p, int sharedKB) {
filtered_lrelu_kernel_spec s = {0};
// Return the first matching kernel.
@@ -1498,8 +1499,8 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
static_assert(FD % D == 0, \
"downscaling filter size must be multiple of " \
"downscaling factor"); \
s.setup = (void*)setup_filters_kernel; \
s.exec = (void*) \
s.setup = (void *)setup_filters_kernel; \
s.exec = (void *) \
filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, \
U, FU, D, FD, TW, TH, W * 32, !!XR, !!WS>; \
s.tileOut = make_int2(TW, TH); \
@@ -1583,6 +1584,21 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
//------------------------------------------------------------------------
#define BUILD_FILTERED_LRELU_OP 1
#ifdef __GNUC__
#if __GNUC__ < 6
#undef BUILD_FILTERED_LRELU_OP
#define BUILD_FILTERED_LRELU_OP 0
#endif
#endif
#if CUDA_VERSION < 10020
#undef BUILD_FILTERED_LRELU_OP
#define BUILD_FILTERED_LRELU_OP 0
#endif
#if BUILD_FILTERED_LRELU_OP == 1
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
@@ -1770,8 +1786,8 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x.scalar_type(), "filtered_lrelu_cuda", [&] {
if constexpr (sizeof(scalar_t) <=
4) // Exclude doubles. constexpr prevents template
// instantiation.
4) // Exclude doubles. constexpr
// prevents template instantiation.
{
// Choose kernel based on index type, datatype and sign read/write
// modes.
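The if constexpr in the dispatch above is what the reflowed comment describes: because the condition is evaluated at compile time, the branch body is discarded entirely for double, so no 64-bit kernel template is ever instantiated. A stripped-down sketch of the same pattern (launch_kernel_for is a hypothetical helper, not a function from this file):

//------------------------------------------------------------------------
// Illustrative sketch (not part of the diff): excluding double from a
// type-dispatched kernel launch at compile time.
#include <torch/extension.h>

template <class scalar_t>
void launch_kernel_for(const torch::Tensor &x) {
  // placeholder for the real kernel launch
  (void)x;
}

void dispatch_without_doubles(const torch::Tensor &x) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "example_dispatch", [&] {
        if constexpr (sizeof(scalar_t) <= 4) {  // half/float: yes; double: branch removed
          launch_kernel_for<scalar_t>(x);
        }
      });
}
//------------------------------------------------------------------------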
@@ -1804,7 +1820,7 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
// that kernel exists.
// Launch CUDA kernel.
void* args[] = {&p};
void *args[] = {&p};
int bx = spec.numWarps * 32;
int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
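The launch configuration above follows directly from the chosen spec: each block runs spec.numWarps * 32 threads and covers one spec.tileOut.x by spec.tileOut.y output tile, so the grid is a ceiling division of the output extent by the tile size. A hedged sketch of how such a spec-driven launch typically comes together with the void *args[] = {&p} convention (struct and helper names are placeholders where they differ from the diff):

//------------------------------------------------------------------------
// Illustrative sketch (not part of the diff): launching a kernel that was
// selected at runtime as a function pointer, passing the parameter struct
// by address through the args array.
#include <cuda_runtime.h>

struct kernel_spec {    // stand-in for filtered_lrelu_kernel_spec
  void *exec;           // kernel entry point
  int2 tileOut;         // output pixels covered by one block
  int numWarps;         // warps per block
  int dynamicSharedKB;  // dynamic shared memory request
};

template <class Params>
cudaError_t launch_spec(const kernel_spec &spec, Params &p, int out_w,
                        int out_h, int num_maps, cudaStream_t stream) {
  void *args[] = {&p};  // the kernel takes the parameter struct by value
  dim3 block(spec.numWarps * 32, 1, 1);
  dim3 grid((out_w - 1) / spec.tileOut.x + 1,  // ceil(out_w / tileOut.x)
            (out_h - 1) / spec.tileOut.y + 1,  // ceil(out_h / tileOut.y)
            num_maps);  // e.g. channels * batch (an assumption in this sketch)
  return cudaLaunchKernel(spec.exec, grid, block, args,
                          static_cast<size_t>(spec.dynamicSharedKB) << 10,
                          stream);
}
//------------------------------------------------------------------------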
@@ -1859,6 +1875,23 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
return std::make_tuple(y, so, 0);
}
std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op_impl(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
bool writeSigns);
REGISTER_DEVICE_IMPL(filtered_lrelu_op_impl, CUDA, filtered_lrelu_op);
#else
#pragma message( \
"filtered_lrelu_op is not available. " \
"Please update your compiler and cuda version.")
#endif
#undef BUILD_FILTERED_LRELU_OP
//------------------------------------------------------------------------
torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
@@ -1920,7 +1953,7 @@ torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
p.sOfs = make_int2(sx, sy);
// Choose CUDA kernel.
void* func = 0;
void *func = 0;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x.scalar_type(), "filtered_lrelu_act_cuda", [&] {
if (writeSigns)
@@ -1933,7 +1966,7 @@ torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
TORCH_CHECK(func, "internal error - CUDA kernel not found");
// Launch CUDA kernel.
void* args[] = {&p};
void *args[] = {&p};
int bx = 128; // 4 warps per block.
// Logical size of launch = writeSigns ? p.s : p.x
......