Commit fb075b86 authored by Michael Carilli

Compilation succeeds on 0.4, 18.04-6 containers, and current upstream master

parent bf855389
#pragma once
#include <ATen/Half.h>
#include <ATen/cuda/CUDAHalf.cuh>
// Type traits to convert types to CUDA-specific types. Used primarily to
// convert at::Half to CUDA's half type. This makes the conversion explicit,
// and the apex::cuda namespace disambiguates it from the version in ATen.
namespace apex { namespace cuda {

template <typename T>
struct TypeConversion {
  using type = T;
};

template <>
struct TypeConversion<at::Half> {
  using type = half;
};

template <typename T>
using type = typename TypeConversion<T>::type;

}} // namespace apex::cuda
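To make the intent of the trait concrete, here is a minimal usage sketch. The static_asserts are illustrative and not part of the commit: the trait is an identity mapping for ordinary arithmetic types and rewrites at::Half to CUDA's native half.

#include <type_traits>
#include "CUDATypeConversion.cuh"

// Plain arithmetic types pass through unchanged...
static_assert(std::is_same<apex::cuda::type<float>, float>::value,
              "float maps to float");
// ...while at::Half is rewritten to the CUDA half type, so kernels can be
// instantiated directly on the device-native type.
static_assert(std::is_same<apex::cuda::type<at::Half>, half>::value,
              "at::Half maps to half");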
#include <cuda.h>
#include <cuda_runtime.h>
+// Lock in a local version of CUDATypeConversion.cuh
+#include "CUDATypeConversion.cuh"
+#include <THC/THCNumerics.cuh>
#if __CUDACC_VER_MAJOR__ >= 9
#define __SHFL_DOWN(var, delta) __shfl_down_sync(0xffffffff, var, delta)
#else
@@ -13,19 +18,13 @@
#define __SYNCWARP
#endif
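These macros exist because CUDA 9 introduced the _sync warp intrinsics that take an explicit member mask and deprecated the implicit-sync forms. A minimal sketch of the kind of warp-level reduction they are used for; the function below is illustrative and not part of this diff.

// Sum one float per lane across a full 32-lane warp; lane 0 ends up with the
// total. __SHFL_DOWN expands to __shfl_down_sync on CUDA 9+ and to the
// pre-CUDA-9 shuffle otherwise, so the same source builds on both toolkits.
__device__ float warp_reduce_sum(float val)
{
  for (int offset = 16; offset > 0; offset /= 2)
    val += __SHFL_DOWN(val, offset);
  return val;  // only lane 0 holds the complete sum
}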
+// Not a long-term solution; this code needs to go upstream.
#ifdef VERSION_LE_04
#define USING_ACCSCALAR_T using accscalar_t = cuda::acc_type<cuda_scalar_t>;
#else
#define USING_ACCSCALAR_T using accscalar_t = acc_type<cuda_scalar_t, true>;
#endif
-#ifdef VERSION_LE_04
-#define REDUCE_ADD ReduceAdd<accscalar_t, accscalar_t>()
-#else
-#define REDUCE_ADD ReduceAdd<accscalar_t>()
-#endif
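The REDUCE_ADD macro goes away because THC's ReduceAdd took two template parameters on 0.4 and one on current master; the locally defined ReduceAdd below takes one and makes the version switch unnecessary. USING_ACCSCALAR_T stays, and in either branch the point is that half-precision data gets a float accumulator. A quick illustration of that mapping, assuming upstream ATen's AccumulateType.h (names assumed; this diff does not add the include):

#include <type_traits>
#include "ATen/AccumulateType.h"  // assumed; provides at::acc_type

// Half-precision storage accumulates in float, so the reductions in the
// weight_norm kernels keep full precision even for half inputs.
static_assert(std::is_same<at::acc_type<at::Half, true>, float>::value,
              "at::Half accumulates in float");
static_assert(std::is_same<at::acc_type<float, true>, float>::value,
              "float accumulates in float");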
// Block size for weight_norm_*_first_dim_kernel.
// Currently, kernels are non-persistent.
// Dialing up the block size to, say 1024, can improve performance by
@@ -44,13 +43,13 @@
// blocks across the slow dimension up to the hardware-max block size of 1024.
#define TILE_H 64
-// For reference, in THCTensorMathReduce.cuh:
-// template <typename T>
-// struct ReduceAdd {
-//   inline __device__ T operator()(const T a, const T b) const {
-//     return THCNumerics<T>::add(a, b);
-//   }
-// };
+// Lock in a local version of ReduceAdd, copied from THCTensorMathReduce.cuh:
+template <typename T>
+struct ReduceAdd {
+  inline __device__ T operator()(const T a, const T b) const {
+    return THCNumerics<T>::add(a, b);
+  }
+};
// lanes is intended to be <= 32.
template
...
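The body of reduce_block_into_lanes is elided above, but its calling convention can be read off the call sites in the kernels below: each thread accumulates a partial sum in accscalar_t, the block is reduced into the leading lanes of a shared-memory buffer, and the result is read back from s[0] (or s[threadIdx.x] in the last-dim kernels). A rough sketch of that pattern with an illustrative kernel; the kernel name, launch details, and the helper's exact requirements are assumptions, not taken from this diff.

// Illustrative caller of reduce_block_into_lanes with the local ReduceAdd.
// Assumes a 1D block and dynamic shared memory sized for the block.
__global__ void block_sum_example(const float* x, float* out, int n)
{
  extern __shared__ char smem[];
  float* s = (float*)smem;

  float thread_sum = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x)
    thread_sum += x[i];   // AccumOp, as in the kernels below

  // Reduce the whole block into a single lane; s[0] then holds the block total.
  reduce_block_into_lanes(s, thread_sum, 1, ReduceAdd<float>());
  if (threadIdx.x == 0)
    *out = s[0];
}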
@@ -2,7 +2,8 @@
// #include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDATensorMethods.cuh"
#include "ATen/cuda/CUDATypeConversion.cuh"
-#include <THC/THCTensorMathReduce.cuh>
+// #include <THC/THCTensorMathReduce.cuh>
+#include <THC/THCGeneral.h>
#include <assert.h>
...
@@ -10,7 +10,7 @@
#include "ATen/cuda/CUDATensorMethods.cuh"
#include "ATen/cuda/CUDATypeConversion.cuh"
-#include <THC/THCTensorMathReduce.cuh>
+// #include <THC/THCTensorMathReduce.cuh>
template
<typename scalar_t,
@@ -46,7 +46,7 @@ __global__ void weight_norm_bwd_first_dim_kernel
thread_sum += pLpwi*savedvi; // AccumOp, could do Kahan here
}
-reduce_block_into_lanes(s, thread_sum, 1, REDUCE_ADD);
+reduce_block_into_lanes(s, thread_sum, 1, ReduceAdd<accscalar_t>());
accscalar_t result = s[0];
// Could choose to save reciprocal of norm instead I suppose, but norms is probably
@@ -105,7 +105,7 @@ __global__ void weight_norm_bwd_last_dim_kernel
slower_dims_location += blockDim.y;
}
-reduce_block_into_lanes(s, thread_sum, blockDim.x, REDUCE_ADD);
+reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd<accscalar_t>());
accscalar_t result = s[threadIdx.x];
// Broadcast load; could use shared memory instead.
@@ -145,7 +145,7 @@ void weight_norm_bwd_cuda
{
#ifdef DEBUG_ANY
using namespace std;
-cout << "Hello from send_to_bwd with pLpw.type = " << pLpw.type << endl;
+cout << "Hello from send_to_bwd with pLpw.type() = " << pLpw.type() << endl;
#endif
const int ndims = savedv.ndimension();
@@ -164,7 +164,7 @@ void weight_norm_bwd_cuda
"weight_norm_bwd_first_dim_kernel",
[&]
{
-using cuda_scalar_t = cuda::type<scalar_t>;
+using cuda_scalar_t = apex::cuda::type<scalar_t>;
USING_ACCSCALAR_T
weight_norm_bwd_first_dim_kernel
@@ -197,7 +197,7 @@ void weight_norm_bwd_cuda
"weight_norm_bwd_last_dim_kernel",
[&]
{
-using cuda_scalar_t = cuda::type<scalar_t>;
+using cuda_scalar_t = apex::cuda::type<scalar_t>;
USING_ACCSCALAR_T
weight_norm_bwd_last_dim_kernel
...
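Each dispatch lambda above follows the same two-line pattern: convert the dispatched ATen scalar type to the raw CUDA type with the local apex::cuda::type trait, then pull in the matching accumulation type with USING_ACCSCALAR_T. Spelled out for the half-precision case; this is illustrative, the explicit at::Half stands in for whatever type the dispatch supplies, and at::acc_type is assumed from upstream ATen's AccumulateType.h.

#include "ATen/AccumulateType.h"    // assumed; provides at::acc_type
#include "CUDATypeConversion.cuh"   // the local trait added by this commit

// The half-precision case, spelled out (the last alias is roughly what
// USING_ACCSCALAR_T expands to on the non-VERSION_LE_04 branch):
using scalar_t      = at::Half;                           // normally supplied by the dispatch
using cuda_scalar_t = apex::cuda::type<scalar_t>;         // at::Half -> half
using accscalar_t   = at::acc_type<cuda_scalar_t, true>;  // half -> float accumulator

// The kernels are then instantiated on this pair, e.g.
//   weight_norm_bwd_first_dim_kernel<cuda_scalar_t, accscalar_t>
//     <<<grid, block, smem, stream>>>(...);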
@@ -10,7 +10,7 @@
#include "ATen/cuda/CUDATensorMethods.cuh"
#include "ATen/cuda/CUDATypeConversion.cuh"
-#include <THC/THCTensorMathReduce.cuh>
+// #include <THC/THCTensorMathReduce.cuh>
template
<typename scalar_t,
@@ -44,7 +44,7 @@ __global__ void weight_norm_fwd_first_dim_kernel
thread_sum += val_f*val_f; // AccumOp, could do Kahan here
}
-reduce_block_into_lanes(s, thread_sum, 1, REDUCE_ADD);
+reduce_block_into_lanes(s, thread_sum, 1, ReduceAdd<accscalar_t>());
accscalar_t result = s[0];
result = sqrtf(result);
@@ -98,7 +98,7 @@ __global__ void weight_norm_fwd_last_dim_kernel
slower_dims_location += blockDim.y;
}
-reduce_block_into_lanes(s, thread_sum, blockDim.x, REDUCE_ADD);
+reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd<accscalar_t>());
// Better to pass an EpilogueOp to reduce_block_into_lanes, implement later
if(threadIdx.y == 0)
@@ -136,7 +136,7 @@ void weight_norm_fwd_cuda
{
#ifdef DEBUG_ANY
using namespace std;
-cout << "hello from send_to_fwd with v.type = " << v.type << endl;
+cout << "hello from send_to_fwd with v.type() = " << v.type() << endl;
#endif
const int ndims = v.ndimension();
@@ -155,7 +155,7 @@ void weight_norm_fwd_cuda
"weight_norm_fwd_first_dim_kernel",
[&]
{
-using cuda_scalar_t = cuda::type<scalar_t>;
+using cuda_scalar_t = apex::cuda::type<scalar_t>;
USING_ACCSCALAR_T
weight_norm_fwd_first_dim_kernel
@@ -186,7 +186,7 @@ void weight_norm_fwd_cuda
"weight_norm_fwd_last_dim_kernel",
[&]
{
-using cuda_scalar_t = cuda::type<scalar_t>;
+using cuda_scalar_t = apex::cuda::type<scalar_t>;
USING_ACCSCALAR_T
// just trying this formatting out to see how it feels...
...
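For orientation, the math these kernels implement along the norm dimension is w = g * v / ||v||. The forward kernels accumulate val_f*val_f and take sqrtf to get ||v||; the backward kernels accumulate pLpwi*savedvi, i.e. the dot product of the incoming gradient with v that appears in both gradients. A scalar, host-side reference for a single row, illustrative only and not part of the commit:

#include <cmath>
#include <vector>

// Reference forward/backward for one row of weight norm, w = g * v / ||v||.
void weight_norm_row_reference(const std::vector<float>& v, float g,
                               const std::vector<float>& grad_w,
                               std::vector<float>& w,
                               std::vector<float>& grad_v, float& grad_g)
{
  float norm_sq = 0.f, dot = 0.f;
  for (size_t i = 0; i < v.size(); ++i) {
    norm_sq += v[i] * v[i];        // fwd kernels: thread_sum += val_f*val_f
    dot     += grad_w[i] * v[i];   // bwd kernels: thread_sum += pLpwi*savedvi
  }
  const float norm = std::sqrt(norm_sq);   // fwd kernels: result = sqrtf(result)

  w.resize(v.size());
  grad_v.resize(v.size());
  for (size_t i = 0; i < v.size(); ++i) {
    w[i]      = g * v[i] / norm;
    grad_v[i] = (g / norm) * grad_w[i] - (g * dot / (norm * norm * norm)) * v[i];
  }
  grad_g = dot / norm;
}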