Unverified commit 3344233f authored by Chaitanya Sri Krishna Lolla, committed by GitHub

[contrib] Support for xentropy extension. (#34)

* enable deprecated fused adam optimizer

* enable deprecated fused lamb

* enable xentropy extension

* add warp size 32 for NV and 64 for AMD

* update compiler arguments

* update the syncwarp conditions

* update syncwarp condition
parent 17fbbf91
@@ -85,6 +85,14 @@
 #define ALIGN_BYTES 16
+#ifdef __HIP_PLATFORM_HCC__
+#define WARP_SIZE 64
+#define SYNCWARP(mask)
+#else
+#define WARP_SIZE 32
+#define SYNCWARP(mask) __syncwarp(mask)
+#endif
 using Tensor = at::Tensor;
 using TensorList = at::TensorList;
 using ScalarType = at::ScalarType;
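The added macros are what let a single kernel source build for both back ends: a warp is 32 threads wide under CUDA, while a ROCm wavefront is 64 threads wide and executes in lockstep, so `SYNCWARP` expands to nothing there. Below is a minimal stand-alone sketch of the same pattern (hypothetical kernel and host code, not part of this commit, written for nvcc; under ROCm the same device-side pattern applies with the host API in its HIP form):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical sketch of the WARP_SIZE / SYNCWARP portability pattern:
// warps are 32 wide on CUDA, wavefronts are 64 wide on ROCm, where
// lockstep execution makes an explicit __syncwarp unnecessary.
#ifdef __HIP_PLATFORM_HCC__
#define WARP_SIZE 64
#define SYNCWARP(mask)                  // no-op on ROCm
#else
#define WARP_SIZE 32
#define SYNCWARP(mask) __syncwarp(mask)
#endif

// One warp/wavefront sums WARP_SIZE values staged in shared memory.
__global__ void warp_sum_demo(const float* in, float* out) {
  __shared__ float smem[WARP_SIZE];
  smem[threadIdx.x] = in[threadIdx.x];
  SYNCWARP(0xffffffffu);                // make the shared-memory writes visible in-warp

  if (threadIdx.x == 0) {
    float acc = 0.f;
    for (int i = 0; i < WARP_SIZE; ++i) acc += smem[i];
    *out = acc;
  }
}

int main() {
  float h_in[WARP_SIZE], h_out = 0.f;
  for (int i = 0; i < WARP_SIZE; ++i) h_in[i] = 1.f;

  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  warp_sum_demo<<<1, WARP_SIZE>>>(d_in, d_out);   // exactly one warp/wavefront
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f (expected %d)\n", h_out, WARP_SIZE);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```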
@@ -126,7 +134,7 @@ inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
   uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
   while (block_size < (max_block_size/2)) block_size *= 2;
   // Launch at least a single warp - the kernel assumes that.
-  block_size = std::max(block_size, static_cast<uint64_t>(32));
+  block_size = std::max(block_size, static_cast<uint64_t>(WARP_SIZE));
   return dim3(block_size);
 }
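The clamp at the end changes from a hard-coded 32 to `WARP_SIZE` because the kernel assumes at least one full warp per block, and on ROCm a full wavefront is 64 threads. A small host-side replay of the rounding logic makes the effect visible (illustrative helper only, not part of the diff; `max_threads` is assumed to be 1024 as in the surrounding file):

```cuda
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustrative replay of SoftMax_getBlockSize's rounding and clamping.
static uint64_t demo_block_size(int ILP, uint64_t dim_size, uint64_t warp_size) {
  const uint64_t max_threads = 1024;   // assumed value
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, max_threads);
  while (block_size < (max_block_size / 2)) block_size *= 2;
  // Launch at least a single warp - the kernel assumes that.
  return std::max(block_size, warp_size);
}

int main() {
  // dim_size = 64, ILP = 2: power-of-two rounding yields 16 threads, which the
  // clamp raises to one full warp: 32 under CUDA, 64 under ROCm.
  printf("CUDA: %llu\n", (unsigned long long)demo_block_size(2, 64, 32));
  printf("ROCm: %llu\n", (unsigned long long)demo_block_size(2, 64, 64));
  return 0;
}
```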
@@ -195,15 +203,15 @@ blockReduce(AccumT* smem, AccumT val,
   AccumT warpVal = defaultVal;
   // First warp will perform per-warp reductions for the remaining warps
-  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
-  if (threadIdx.x < 32) {
-    int lane = threadIdx.x % 32;
-    if (lane < blockDim.x / 32) {
+  uint32_t mask = (((uint64_t)1) << (blockDim.x / WARP_SIZE)) - 1;
+  if (threadIdx.x < WARP_SIZE) {
+    int lane = threadIdx.x % WARP_SIZE;
+    if (lane < blockDim.x / WARP_SIZE) {
 #pragma unroll
-      for (int i = 0; i < 32; ++i) {
-        warpVal = r(warpVal, smem[lane * 32 + i]);
+      for (int i = 0; i < WARP_SIZE; ++i) {
+        warpVal = r(warpVal, smem[lane * WARP_SIZE + i]);
       }
-      __syncwarp(mask);
+      SYNCWARP(mask);
       smem[lane] = warpVal;
     }
   }
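The mask passed to `SYNCWARP` carries one bit per per-warp partial result, so its value now depends on the warp size. A tiny host-side check of the expression (illustrative only, not part of the diff) shows what it evaluates to for a 128-thread block:

```cuda
#include <cstdint>
#include <cstdio>

// Illustrative only: replay the mask expression from blockReduce on the host.
static uint32_t reduce_mask(uint32_t block_dim_x, uint32_t warp_size) {
  return (uint32_t)((((uint64_t)1) << (block_dim_x / warp_size)) - 1);
}

int main() {
  // CUDA, blockDim.x = 128: 128 / 32 = 4 per-warp partials -> lanes 0..3, mask = 0xf.
  printf("warp 32: 0x%x\n", reduce_mask(128, 32));
  // ROCm, blockDim.x = 128: 128 / 64 = 2 partials -> mask = 0x3, but SYNCWARP(mask)
  // expands to nothing on that path, so the value is only consumed under CUDA.
  printf("warp 64: 0x%x\n", reduce_mask(128, 64));
  return 0;
}
```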
@@ -214,7 +222,7 @@ blockReduce(AccumT* smem, AccumT val,
   AccumT blockVal = defaultVal;
   if (threadIdx.x == 0) {
-    for (int i = 0; i < blockDim.x / 32; ++i) {
+    for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) {
       blockVal = r(blockVal, smem[i]);
     }
     smem[0] = blockVal;
@@ -249,16 +257,16 @@ blockReduce(AccumT* smem,
   AccumT warpVal2 = defaultVal2;
   // First warp will perform per-warp reductions for the remaining warps
-  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
-  if (threadIdx.x < 32) {
-    int lane = threadIdx.x % 32;
-    if (lane < blockDim.x / 32) {
+  uint32_t mask = (((uint64_t)1) << (blockDim.x / WARP_SIZE)) - 1;
+  if (threadIdx.x < WARP_SIZE) {
+    int lane = threadIdx.x % WARP_SIZE;
+    if (lane < blockDim.x / WARP_SIZE) {
 #pragma unroll
-      for (int i = 0; i < 32; ++i) {
-        warpVal1 = r1(warpVal1, smem[lane * 32 + i]);
-        warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);
+      for (int i = 0; i < WARP_SIZE; ++i) {
+        warpVal1 = r1(warpVal1, smem[lane * WARP_SIZE + i]);
+        warpVal2 = r2(warpVal2, smem[lane * WARP_SIZE + i + blockDim.x]);
       }
-      __syncwarp(mask);
+      SYNCWARP(mask);
       smem[lane] = warpVal1;
       smem[lane + blockDim.x] = warpVal2;
     }
@@ -271,7 +279,7 @@ blockReduce(AccumT* smem,
   AccumT blockVal2 = defaultVal2;
   if (threadIdx.x == 0) {
-    for (int i = 0; i < blockDim.x / 32; ++i) {
+    for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) {
       blockVal1 = r1(blockVal1, smem[i]);
       blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
     }
@@ -269,16 +269,27 @@ if "--xentropy" in sys.argv:
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
-    if torch.utils.cpp_extension.CUDA_HOME is None:
+    is_rocm_pytorch = check_if_rocm_pytorch()
+    if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
         raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
-        ext_modules.append(
-            CUDAExtension(name='xentropy_cuda',
-                          sources=['apex/contrib/csrc/xentropy/interface.cpp',
-                                   'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
-                          include_dirs=[os.path.join(this_dir, 'csrc')],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-O3'] + version_dependent_macros}))
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='xentropy_cuda',
+                              sources=['apex/contrib/csrc/xentropy/interface.cpp',
+                                       'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
+                              include_dirs=[os.path.join(this_dir, 'csrc')],
+                              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                                  'nvcc':['-O3'] + version_dependent_macros}))
+        else:
+            ext_modules.append(
+                CUDAExtension(name='xentropy_cuda',
+                              sources=['apex/contrib/csrc/xentropy/interface.cpp',
+                                       'apex/contrib/csrc/xentropy/hip/xentropy_kernel.hip'],
+                              include_dirs=[os.path.join(this_dir, 'csrc/hip')],
+                              extra_compile_args=['-O3'] + version_dependent_macros))
 if "--deprecated_fused_adam" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension