Unverified Commit 3344233f authored by Chaitanya Sri Krishna Lolla's avatar Chaitanya Sri Krishna Lolla Committed by GitHub
Browse files

[contrib] Support for xentropy extension. (#34)

* enable deprecated fused adam optimizer

* enable deprecated fused lamb

* enable xentropy extension

* add warp size 32 for NVIDIA and 64 for AMD

* update compiler arguments

* update the syncwarp conditions

* update syncwarp condition
parent 17fbbf91
...@@ -85,6 +85,14 @@ ...@@ -85,6 +85,14 @@
#define ALIGN_BYTES 16 #define ALIGN_BYTES 16
#ifdef __HIP_PLATFORM_HCC__
#define WARP_SIZE 64
#define SYNCWARP(mask)
#else
#define WARP_SIZE 32
#define SYNCWARP(mask) __syncwarp(mask)
#endif
using Tensor = at::Tensor; using Tensor = at::Tensor;
using TensorList = at::TensorList; using TensorList = at::TensorList;
using ScalarType = at::ScalarType; using ScalarType = at::ScalarType;
...@@ -126,7 +134,7 @@ inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { ...@@ -126,7 +134,7 @@ inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads)); uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
while (block_size < (max_block_size/2)) block_size *= 2; while (block_size < (max_block_size/2)) block_size *= 2;
// Launch at least a single warp - the kernel assumes that. // Launch at least a single warp - the kernel assumes that.
block_size = std::max(block_size, static_cast<uint64_t>(32)); block_size = std::max(block_size, static_cast<uint64_t>(WARP_SIZE));
return dim3(block_size); return dim3(block_size);
} }
...@@ -195,15 +203,15 @@ blockReduce(AccumT* smem, AccumT val, ...@@ -195,15 +203,15 @@ blockReduce(AccumT* smem, AccumT val,
AccumT warpVal = defaultVal; AccumT warpVal = defaultVal;
// First warp will perform per-warp reductions for the remaining warps // First warp will perform per-warp reductions for the remaining warps
uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; uint32_t mask = (((uint64_t)1) << (blockDim.x / WARP_SIZE)) - 1;
if (threadIdx.x < 32) { if (threadIdx.x < WARP_SIZE) {
int lane = threadIdx.x % 32; int lane = threadIdx.x % WARP_SIZE;
if (lane < blockDim.x / 32) { if (lane < blockDim.x / WARP_SIZE) {
#pragma unroll #pragma unroll
for (int i = 0; i < 32; ++i) { for (int i = 0; i < WARP_SIZE; ++i) {
warpVal = r(warpVal, smem[lane * 32 + i]); warpVal = r(warpVal, smem[lane * WARP_SIZE + i]);
} }
__syncwarp(mask); SYNCWARP(mask);
smem[lane] = warpVal; smem[lane] = warpVal;
} }
} }
...@@ -214,7 +222,7 @@ blockReduce(AccumT* smem, AccumT val, ...@@ -214,7 +222,7 @@ blockReduce(AccumT* smem, AccumT val,
AccumT blockVal = defaultVal; AccumT blockVal = defaultVal;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
for (int i = 0; i < blockDim.x / 32; ++i) { for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) {
blockVal = r(blockVal, smem[i]); blockVal = r(blockVal, smem[i]);
} }
smem[0] = blockVal; smem[0] = blockVal;
...@@ -249,16 +257,16 @@ blockReduce(AccumT* smem, ...@@ -249,16 +257,16 @@ blockReduce(AccumT* smem,
AccumT warpVal2 = defaultVal2; AccumT warpVal2 = defaultVal2;
// First warp will perform per-warp reductions for the remaining warps // First warp will perform per-warp reductions for the remaining warps
uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; uint32_t mask = (((uint64_t)1) << (blockDim.x / WARP_SIZE)) - 1;
if (threadIdx.x < 32) { if (threadIdx.x < WARP_SIZE) {
int lane = threadIdx.x % 32; int lane = threadIdx.x % WARP_SIZE;
if (lane < blockDim.x / 32) { if (lane < blockDim.x / WARP_SIZE) {
#pragma unroll #pragma unroll
for (int i = 0; i < 32; ++i) { for (int i = 0; i < WARP_SIZE; ++i) {
warpVal1 = r1(warpVal1, smem[lane * 32 + i]); warpVal1 = r1(warpVal1, smem[lane * WARP_SIZE + i]);
warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]); warpVal2 = r2(warpVal2, smem[lane * WARP_SIZE + i + blockDim.x]);
} }
__syncwarp(mask); SYNCWARP(mask);
smem[lane] = warpVal1; smem[lane] = warpVal1;
smem[lane + blockDim.x] = warpVal2; smem[lane + blockDim.x] = warpVal2;
} }
...@@ -271,7 +279,7 @@ blockReduce(AccumT* smem, ...@@ -271,7 +279,7 @@ blockReduce(AccumT* smem,
AccumT blockVal2 = defaultVal2; AccumT blockVal2 = defaultVal2;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
for (int i = 0; i < blockDim.x / 32; ++i) { for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) {
blockVal1 = r1(blockVal1, smem[i]); blockVal1 = r1(blockVal1, smem[i]);
blockVal2 = r2(blockVal2, smem[i + blockDim.x]); blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
} }
......
...@@ -269,9 +269,12 @@ if "--xentropy" in sys.argv: ...@@ -269,9 +269,12 @@ if "--xentropy" in sys.argv:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None: is_rocm_pytorch = check_if_rocm_pytorch()
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
if not is_rocm_pytorch:
ext_modules.append( ext_modules.append(
CUDAExtension(name='xentropy_cuda', CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp', sources=['apex/contrib/csrc/xentropy/interface.cpp',
...@@ -279,6 +282,14 @@ if "--xentropy" in sys.argv: ...@@ -279,6 +282,14 @@ if "--xentropy" in sys.argv:
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
else:
ext_modules.append(
CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/hip/xentropy_kernel.hip'],
include_dirs=[os.path.join(this_dir, 'csrc/hip')],
extra_compile_args=['-O3'] + version_dependent_macros))
if "--deprecated_fused_adam" in sys.argv: if "--deprecated_fused_adam" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment