Commit 88eee5fe authored by Jeff Daily

updates to MHA, compilation still broken

parent 1fd257e2
@@ -5,7 +5,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
...
@@ -6,7 +6,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
...
@@ -5,7 +5,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
...
@@ -5,7 +5,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
...
@@ -5,7 +5,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
...
@@ -11,7 +11,14 @@
 #include <cuda_fp16.h>
 #include <cmath>
+#ifdef __HIP_PLATFORM_HCC__
+#define WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
+#else
+#define WARP_SHFL_XOR __shfl_xor_sync
+#endif
 namespace {
 template <typename Datatype, int ELEMENTS_PER_LDG>
 __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
@@ -127,7 +134,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+    val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
@@ -152,7 +159,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+    sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -351,7 +358,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+    val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
@@ -375,7 +382,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+    sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
 auto seeds = at::cuda::philox::unpack(philox_args);
@@ -505,7 +512,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+    val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
@@ -529,7 +536,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+    sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
 curandStatePhilox4_32_10_t state;
@@ -765,7 +772,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+    val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
@@ -790,7 +797,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+    sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -1020,7 +1027,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+    val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
@@ -1045,7 +1052,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+    sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -1243,7 +1250,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+    val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
@@ -1268,7 +1275,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+    sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -1385,7 +1392,7 @@ bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8
 return false;
 }
-int log2_ceil_native(int value) {
+static int log2_ceil_native(int value) {
 int log2_value = 0;
 while ((1 << log2_value) < value) ++log2_value;
 return log2_value;
@@ -1394,7 +1401,7 @@ int log2_ceil_native(int value) {
 template <typename T>
 __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
 {
-#if CUDA_VERSION >= 9000
+#if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
 return __shfl_xor_sync(mask, value, laneMask, width);
 #else
 return __shfl_xor(value, laneMask, width);
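Aside from the shuffle guard, this hunk marks log2_ceil_native as static, which gives the helper internal linkage. A hypothetical sketch of what that buys (the header name below is illustrative, not apex's actual file layout):

// log2_util.h -- hypothetical header shared by several translation units
static int log2_ceil_native(int value) {
    // static => internal linkage: every translation unit that includes this
    // header gets its own private copy and exports no global symbol, so
    // linking those objects together cannot hit a duplicate definition.
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

Without static (or inline), two objects compiled from files that both include the header would each export log2_ceil_native and the link step would report a multiple-definition error.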
@@ -1835,7 +1842,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+    val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
@@ -1860,7 +1867,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+    sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -2305,7 +2312,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, con
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+    sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -2516,7 +2523,7 @@ __global__ void masked_softmax_warp_backward(__half *gradInput, const __half *gr
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0;i < WARP_BATCH;++i) {
-    sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+    sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
...
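Taken together, the changes in this file route every warp shuffle through a portability shim: direct __shfl_xor_sync calls become WARP_SHFL_XOR, which expands to __shfl_xor_sync under nvcc and to the unmasked __shfl_xor under hipcc, where the *_sync shuffle variants are not provided; the existing WARP_SHFL_XOR_NATIVE template gets the same guard. Below is a minimal standalone sketch of that pattern, not the apex softmax kernel itself; the warp_sum kernel and the FULL_MASK / WARP_SIZE values are illustrative assumptions.

// sketch.cu -- toy illustration of the WARP_SHFL_XOR shim added in the diff
#include <cstdio>
#include <cuda_runtime.h>
#ifdef __HIP_PLATFORM_HCC__
#define WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
#else
#define WARP_SHFL_XOR __shfl_xor_sync
#endif

constexpr unsigned FULL_MASK = 0xffffffffu; // ignored by the HIP branch
constexpr int WARP_SIZE = 32;               // illustrative shuffle width

__global__ void warp_sum(const float *src, float *dst) {
    float sum = src[threadIdx.x];
    // Butterfly reduction: after log2(WARP_SIZE) xor-shuffles every lane holds the total.
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        sum += WARP_SHFL_XOR(FULL_MASK, sum, offset, WARP_SIZE);
    }
    if (threadIdx.x == 0) {
        dst[0] = sum;
    }
}

int main() {
    // The host side uses the CUDA runtime API; a ROCm build would first run
    // this through hipify, as the apex extensions do.
    float *src, *dst;
    cudaMallocManaged(&src, WARP_SIZE * sizeof(float));
    cudaMallocManaged(&dst, sizeof(float));
    for (int i = 0; i < WARP_SIZE; ++i) src[i] = 1.0f;
    warp_sum<<<1, WARP_SIZE>>>(src, dst);
    cudaDeviceSynchronize();
    printf("sum = %f\n", dst[0]); // expect 32
    cudaFree(src);
    cudaFree(dst);
    return 0;
}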
@@ -39,7 +39,6 @@ if IS_ROCM_PYTORCH:
 else:
     rocm_include_dirs = []
-include_dirs=[os.path.join(this_dir, 'csrc')] + rocm_include_dirs
 if not torch.cuda.is_available() and not IS_ROCM_PYTORCH:
     # https://github.com/NVIDIA/apex/issues/486
     # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
@@ -157,9 +156,10 @@ if "--distributed_adam" in sys.argv:
     hipcc_args_adam = ['-O3'] + version_dependent_macros
     ext_modules.append(
         CUDAExtension(name='distributed_adam_cuda',
-            sources=['./apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
-                     './apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
-            include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/optimizers/'],
+            sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
+                     'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
             extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
                                 'nvcc':nvcc_args_adam if not IS_ROCM_PYTORCH else hipcc_args_adam}))
@@ -280,9 +280,10 @@ if "--xentropy" in sys.argv:
     print ("INFO: Building the xentropy extension.")
     ext_modules.append(
         CUDAExtension(name='xentropy_cuda',
-            sources=['./apex/contrib/csrc/xentropy/interface.cpp',
-                     './apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
-            include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/xentropy/'],
+            sources=['apex/contrib/csrc/xentropy/interface.cpp',
+                     'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/xentropy')],
             extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                 'nvcc':['-O3'] + version_dependent_macros}))
@@ -302,9 +303,10 @@ if "--deprecated_fused_adam" in sys.argv:
     hipcc_args_fused_adam = ['-O3'] + version_dependent_macros
     ext_modules.append(
         CUDAExtension(name='fused_adam_cuda',
-            sources=['./apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
-                     './apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
-            include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/optimizers/'],
+            sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
+                     'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
             extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                 'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
@@ -363,7 +365,7 @@ if "--fast_layer_norm" in sys.argv:
     '-gencode', 'arch=compute_70,code=sm_70',
     '-U__CUDA_NO_HALF_OPERATORS__',
     '-U__CUDA_NO_HALF_CONVERSIONS__',
-    '-I./apex/contrib/csrc/layer_norm/',
+    '-Iapex/contrib/csrc/layer_norm',
     '--expt-relaxed-constexpr',
     '--expt-extended-lambda',
     '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
@@ -375,7 +377,7 @@ if "--fast_multihead_attn" in sys.argv:
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
     if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
         # Check, if CUDA11 is installed for compute capability 8.0
@@ -387,99 +389,84 @@ if "--fast_multihead_attn" in sys.argv:
         cc_flag.append('arch=compute_80,code=sm_80')
     subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
-    nvcc_args_mha = ['-O3', '-gencode', 'arch=compute_70,code=sm_70', '-I./apex/contrib/csrc/multihead_attn/cutlass/', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', '--expt-extended-lambda', '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
-    hipcc_args_mha = ['-O3', '-I./apex/contrib/csrc/multihead_attn/cutlass/', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
+    nvcc_args_mha = ['-O3',
+                     '-gencode',
+                     'arch=compute_70,code=sm_70',
+                     '-Iapex/contrib/csrc/multihead_attn/cutlass',
+                     '-U__CUDA_NO_HALF_OPERATORS__',
+                     '-U__CUDA_NO_HALF_CONVERSIONS__',
+                     '--expt-relaxed-constexpr',
+                     '--expt-extended-lambda',
+                     '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
+    hipcc_args_mha = ['-O3',
+                      '-Iapex/contrib/csrc/multihead_attn/cutlass',
+                      '-I/opt/rocm/include/hiprand',
+                      '-I/opt/rocm/include/rocrand',
+                      '-U__HIP_NO_HALF_OPERATORS__',
+                      '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
     ext_modules.append(
         CUDAExtension(name='fast_additive_mask_softmax_dropout',
-            sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp',
+            sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp',
                      'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
             extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                'nvcc':['-O3',
-                                        '-gencode', 'arch=compute_70,code=sm_70',
-                                        '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                        '-U__CUDA_NO_HALF_OPERATORS__',
-                                        '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                        '--expt-relaxed-constexpr',
-                                        '--expt-extended-lambda',
-                                        '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_mask_softmax_dropout',
-            sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
+            sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp',
                      'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
             extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                'nvcc':['-O3',
-                                        '-gencode', 'arch=compute_70,code=sm_70',
-                                        '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                        '-U__CUDA_NO_HALF_OPERATORS__',
-                                        '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                        '--expt-relaxed-constexpr',
-                                        '--expt-extended-lambda',
-                                        '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
-            sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
+            sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp',
                      'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
             extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                'nvcc':['-O3',
-                                        '-gencode', 'arch=compute_70,code=sm_70',
-                                        '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                        '-U__CUDA_NO_HALF_OPERATORS__',
-                                        '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                        '--expt-relaxed-constexpr',
-                                        '--expt-extended-lambda',
-                                        '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_bias',
-            sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
+            sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp',
                      'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
             extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                'nvcc':['-O3',
-                                        '-gencode', 'arch=compute_70,code=sm_70',
-                                        '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                        '-U__CUDA_NO_HALF_OPERATORS__',
-                                        '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                        '--expt-relaxed-constexpr',
-                                        '--expt-extended-lambda',
-                                        '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn',
-            sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
+            sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp',
                      'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
             extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                'nvcc':['-O3',
-                                        '-gencode', 'arch=compute_70,code=sm_70',
-                                        '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                        '-U__CUDA_NO_HALF_OPERATORS__',
-                                        '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                        '--expt-relaxed-constexpr',
-                                        '--expt-extended-lambda',
-                                        '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_norm_add',
-            sources=['./apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
-                     './apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
-            include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/multihead_attn/'],
+            sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp',
+                     'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
             extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_encdec_multihead_attn',
-            sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
+            sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp',
                      'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
             extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                'nvcc':['-O3',
-                                        '-gencode', 'arch=compute_70,code=sm_70',
-                                        '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                        '-U__CUDA_NO_HALF_OPERATORS__',
-                                        '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                        '--expt-relaxed-constexpr',
-                                        '--expt-extended-lambda',
-                                        '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
-            sources=['./apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
-                     './apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
-            include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/multihead_attn/'],
+            sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp',
+                     'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
+            include_dirs=[os.path.join(this_dir, 'csrc'),
+                          os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
             extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
...
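On the setup.py side the pattern across all of the hunks above is the same: drop the './' prefixes and the removed module-level include_dirs, give every extension an explicit include_dirs built with os.path.join, and pick between the shared nvcc_args_mha and hipcc_args_mha flag lists depending on IS_ROCM_PYTORCH. A condensed sketch of that shape follows; the 'my_fast_attn' extension and its sources are placeholders, not modules added by this commit, and version_dependent_macros, generator_flag, and cc_flag are omitted for brevity.

import os
import torch
from torch.utils.cpp_extension import CUDAExtension

this_dir = os.path.dirname(os.path.abspath(__file__))
# approximation of the flag apex computes earlier in setup.py
IS_ROCM_PYTORCH = getattr(torch.version, 'hip', None) is not None

nvcc_args_mha = ['-O3',
                 '-gencode', 'arch=compute_70,code=sm_70',
                 '-Iapex/contrib/csrc/multihead_attn/cutlass',
                 '-U__CUDA_NO_HALF_OPERATORS__',
                 '-U__CUDA_NO_HALF_CONVERSIONS__',
                 '--expt-relaxed-constexpr',
                 '--expt-extended-lambda',
                 '--use_fast_math']
hipcc_args_mha = ['-O3',
                  '-Iapex/contrib/csrc/multihead_attn/cutlass',
                  '-I/opt/rocm/include/hiprand',
                  '-I/opt/rocm/include/rocrand',
                  '-U__HIP_NO_HALF_OPERATORS__',
                  '-U__HIP_NO_HALF_CONVERSIONS__']

ext = CUDAExtension(
    name='my_fast_attn',  # placeholder extension name
    sources=['apex/contrib/csrc/multihead_attn/my_fast_attn_cpp.cpp',   # placeholder sources
             'apex/contrib/csrc/multihead_attn/my_fast_attn_cuda.cu'],
    # each extension now spells out its own include dirs instead of reusing
    # a shared module-level include_dirs list
    include_dirs=[os.path.join(this_dir, 'csrc'),
                  os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
    extra_compile_args={'cxx': ['-O3'],
                        'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha})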