Commit 1436a66a authored by hubertlu-tw

Merge remote-tracking branch 'origin/master' into IFU-master-2021-10-15

parents aee9f00d 08e88b1b
@@ -11,7 +11,14 @@
#include <cuda_fp16.h>
#include <cmath>
#ifdef __HIP_PLATFORM_HCC__
#define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
#else
#define APEX_WARP_SHFL_XOR __shfl_xor_sync
#endif
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
@@ -127,7 +134,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
@@ -152,7 +159,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
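Note: the unrolled loops above are a butterfly (XOR-shuffle) warp reduction; the APEX_WARP_SHFL_XOR macro introduced at the top of the file lets the same source build with CUDA's `__shfl_xor_sync` (which takes a participation mask) and with ROCm/HIP's `__shfl_xor` (which does not). Below is a minimal, self-contained sketch of the pattern; it is illustrative only, not part of the commit, and the macro/function names are made up for the example. The warp width is left as a template parameter because AMD wavefronts are 64 lanes wide rather than 32.

```cpp
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff

// Portability shim mirroring the APEX_WARP_SHFL_XOR macro added in this commit.
#ifdef __HIP_PLATFORM_HCC__
#define WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
#else
#define WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor_sync(mask, value, offset, width)
#endif

// After the loop, every lane of the warp holds the maximum of `val` across all lanes.
template <int WARP_SIZE>
__device__ float warp_reduce_max(float val) {
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float other = WARP_SHFL_XOR(FULL_MASK, val, offset, WARP_SIZE);
    val = other > val ? other : val;
  }
  return val;
}
```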
@@ -351,7 +358,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
@@ -375,7 +382,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
auto seeds = at::cuda::philox::unpack(philox_args);
@@ -505,7 +512,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
@@ -529,7 +536,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
curandStatePhilox4_32_10_t state;
@@ -765,7 +772,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
@@ -790,7 +797,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -1020,7 +1027,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
@@ -1045,7 +1052,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -1243,7 +1250,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
@@ -1268,7 +1275,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -1385,7 +1392,7 @@ bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8
return false;
}
int log2_ceil_native(int value) {
static int log2_ceil_native(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
@@ -1394,7 +1401,7 @@ int log2_ceil_native(int value) {
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
#if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
@@ -1835,7 +1842,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
@@ -1860,7 +1867,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -2305,7 +2312,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, con
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -2516,7 +2523,7 @@ __global__ void masked_softmax_warp_backward(__half *gradInput, const __half *gr
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......
#include <vector>
#include <iostream>
//#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
rocblas_datatype a_type = rocblas_datatype_f16_r;
rocblas_datatype b_type = rocblas_datatype_f16_r;
rocblas_datatype c_type = rocblas_datatype_f16_r;
rocblas_datatype d_type = rocblas_datatype_f16_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
int32_t solution_index = 0;
rocblas_int flags = 0;
cublasOperation_t convertTransToCublasOperation(char trans) {
if (trans == 't') return CUBLAS_OP_T;
else if (trans == 'n') return CUBLAS_OP_N;
@@ -28,9 +33,9 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
}
}
void CublasStridedBatchedGemm(THCState *state, char transa, char transb, long m, long n, long k,
void RocblasStridedBatchedGemm(THCState *state, char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, long batchCount, cublasGemmAlgo_t algo=CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo) {
cublasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb);
@@ -39,237 +44,28 @@ void CublasStridedBatchedGemm(THCState *state, char transa, char transb, long m,
cublasSetStream(handle, stream);
float fAlpha = alpha;
float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
THCublasCheck(cublasGemmStridedBatchedEx(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
b, CUDA_R_16F, (int)ldb, strideB,
(void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
(int)batchCount, CUDA_R_32F, algo));
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
THCublasCheck(rocblas_gemm_strided_batched_ex(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, a_type, (int)lda, strideA,
b, b_type, (int)ldb, strideB,
(void*)&fBeta, c, c_type, (int)ldc, strideC,
d, d_type, int(ldd), strideD,
(int)batchCount, compute_type, algo, solution_index, flags));
}
template<cutlass::MatrixLayout::Kind A_LAYOUT, cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>
void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, long batchCount) {
//printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);
typedef cutlass::gemm::WmmaGemmTraits<
A_LAYOUT,
B_LAYOUT,
cutlass::Shape<32, 16, 16>,
half,
half,
half,
cutlass::gemm::LinearScaling<float>,
float,
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
typename cutlass::Shape<16, 16, 16>,
SRC_A, //kScalarsPerLdgA_
SRC_B, //kScalarsPerLdgB_
SRC_A, //KScalarsPerLdsA_
SRC_B, //KScalarsPerLdsB_
DST_C, //kScalarsPerLdgCAndStgD_
DST_C/2, //kScalarsPerStsD_
DST_C/2 //kScalarsPerLdsD_
>
WmmaGemmTraits;
typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm;
typename Gemm::Params params;
int result = params.initialize(
m, // M dimension for each batch
n, // N dimension for each batch
k, // K dimension for each batch
alpha, // scalar alpha
a,
lda,
strideA, // distance in memory between the first element of neighboring batch
b,
ldb,
strideB, // distance in memory between the first element of neighboring batch
beta, // scalar beta
c, // source matrix C
ldc,
strideC, // distance in memory between the first element of neighboring batch
c, // destination matrix C (may be different memory than source C matrix)
ldc,
strideC, // distance in memory between the first element of neighboring batch
batchCount
);
AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
// batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is limited to 16 bits.
// To implement batched GEMM with larger batch size, we fragment it into
// smaller batched GEMMs of gridDim.z <= 64k
long batchesLeft = batchCount;
long iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
do {
//printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f TotalBatches: %ld iterBatchCount %ld\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);
int result = params.initialize(
m, // M dimension for each batch
n, // N dimension for each batch
k, // K dimension for each batch
alpha, // scalar alpha
a,
lda,
strideA, // distance in memory between the first element of neighboring batch
b,
ldb,
strideB, // distance in memory between the first element of neighboring batch
beta, // scalar beta
c, // source matrix C
ldc,
strideC, // distance in memory between the first element of neighboring batch
c, // destination matrix C (may be different memory than source C matrix)
ldc,
strideC, // distance in memory between the first element of neighboring batch
iterBatchCount
);
AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
// Launch the CUTLASS GEMM kernel.
THCudaCheck(Gemm::launch(params, stream));
// Update batched GEMM params based on completed work
batchesLeft = batchesLeft - iterBatchCount;
a += iterBatchCount * strideA;
b += iterBatchCount * strideB;
c += iterBatchCount * strideC;;
iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
} while(batchesLeft > 0);
}
void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, long batchCount) {
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount) {
auto stream = c10::cuda::getCurrentCUDAStream();
//printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
if ( (transa == 't') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
/*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
int m_rem = m % 64;
int n_rem = n % 64;
if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
} else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
} else {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
}
}*/
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
} else if ( (transa == 'n') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
/*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
int m_rem = m % 64;
int n_rem = n % 64;
if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
} else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
} else {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
}
}*/
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
} else if ( (transa == 'n') && (transb == 't') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
/*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
int m_rem = m % 64;
int n_rem = n % 64;
if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
} else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {
CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
} else {
CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
}
}*/
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
} else {
AT_ASSERTM(false, "TransA and TransB are invalid");
}
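Note: the `!(lda & 0x7)`-style guards kept in gemm_switch_fp32accum are power-of-two alignment tests; they previously selected among the 8/4/2-wide CUTLASS load variants that this commit removes, and both branches now fall through to RocblasStridedBatchedGemm. A tiny illustration of the test (not part of the commit):

```cpp
// (ld & 0x7) == 0  <=>  ld % 8 == 0, i.e. the leading dimension is a multiple of 8.
static_assert((64 & 0x7) == 0, "64 is 8-aligned");
static_assert((36 & 0x7) != 0, "36 is not 8-aligned");
```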
@@ -311,7 +107,7 @@ void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, i
void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, long batchCount)
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount)
{
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
@@ -323,7 +119,7 @@ void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long
adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
//gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount);
}
/******
......
@@ -234,12 +234,12 @@ void fused_adam_cuda(
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
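Note: DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16 is defined in apex's type_shim header and does not appear in this diff. As a rough, hedged sketch (the real macro may differ in detail, and the _SKETCH name is made up here), such AT_DISPATCH-style macros expand to a switch over the runtime scalar type and alias it to a compile-time type name:

```cpp
#include <ATen/ATen.h>

// Illustrative sketch of an AT_DISPATCH-style macro that adds a BFloat16 case.
#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16_SKETCH(TYPE, LEVEL, NAME, ...)   \
  switch (TYPE) {                                                             \
    case at::ScalarType::Float: {                                             \
      using scalar_t_##LEVEL = float;                                         \
      __VA_ARGS__;                                                            \
      break;                                                                  \
    }                                                                         \
    case at::ScalarType::Half: {                                              \
      using scalar_t_##LEVEL = at::Half;                                      \
      __VA_ARGS__;                                                            \
      break;                                                                  \
    }                                                                         \
    case at::ScalarType::BFloat16: {                                          \
      using scalar_t_##LEVEL = at::BFloat16;                                  \
      __VA_ARGS__;                                                            \
      break;                                                                  \
    }                                                                         \
    default:                                                                  \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");         \
  }
```

The numeric second argument becomes the suffix of `scalar_t_0`, which is why the kernel bodies in this diff refer to that name.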
@@ -308,12 +308,12 @@ void fused_adam_cuda_mt(
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half || tensor_lists[3][0].scalar_type() == at::ScalarType::BFloat16) {
//alher values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dich is done on the gradient type
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
@@ -330,7 +330,7 @@ void fused_adam_cuda_mt(
decay);
);
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
@@ -846,13 +846,13 @@ void fused_reversible_adam_cuda(
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
@@ -871,7 +871,7 @@ void fused_reversible_adam_cuda(
);
} else {
AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_e5m2_kernel",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_e5m2_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
@@ -991,12 +991,12 @@ void fused_maybe_adam_undo_cuda(
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
......
@@ -187,7 +187,7 @@ void multi_tensor_fused_adam_cuda(
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
@@ -206,7 +206,7 @@ void multi_tensor_fused_adam_cuda(
(adamMode_t) mode);
);
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
......
@@ -585,7 +585,7 @@ std::vector<Tensor> host_softmax_xentropy(
const Tensor & labels_,
const float smoothing,
const bool half_to_float){
if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half,"conversion is supported for Half type only");
if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half || input_.type().scalarType() == ScalarType::BFloat16,"conversion is supported for Half and BFloat16 type only");
AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long,"Label type should be CUDA Long");
auto input = input_.contiguous();
@@ -616,7 +616,7 @@ std::vector<Tensor> host_softmax_xentropy(
dim3 grid(outer_size);
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "host_softmax_xentropy",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(input.scalar_type(), 0, "host_softmax_xentropy",
using accscalar_t = at::acc_type<scalar_t_0, true>;
const int ILP = sizeof(float4)/sizeof(scalar_t_0);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
@@ -684,7 +684,7 @@ Tensor host_softmax_xentropy_backward(
dim3 grid(outer_size);
DISPATCH_FLOAT_AND_HALF(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
using accscalar_t = acc_type<scalar_t_0, true>;
const int ILP = sizeof(float4)/sizeof(scalar_t_0);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
...@@ -723,7 +723,7 @@ at::Tensor softmax_xentropy_backward_cuda( ...@@ -723,7 +723,7 @@ at::Tensor softmax_xentropy_backward_cuda(
const float smoothing) { const float smoothing) {
bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType(); bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType();
if (half_to_float) { if (half_to_float) {
AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && logits.type().scalarType() == ScalarType::Half), "expected input and grad types to match, or input to be at::Half and grad to be at::Float"); AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && (logits.type().scalarType() == ScalarType::Half || logits.type().scalarType() == ScalarType::BFloat16)), "expected input and grad types to match, or input to be at::Half or at::Bfloat16 and grad to be at::Float");
} }
return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float); return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float);
} }
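For reference, the type rule the updated asserts enforce, written out as a hedged Python sketch (it mirrors the hunks above and is not the extension's actual entry point): half_to_float means the loss and gradients stay in float32 while the logits are float16 or bfloat16.

    import torch

    def infer_half_to_float(grad_loss: torch.Tensor, logits: torch.Tensor) -> bool:
        half_to_float = grad_loss.dtype != logits.dtype
        if half_to_float:
            assert grad_loss.dtype == torch.float32 and \
                   logits.dtype in (torch.float16, torch.bfloat16), \
                "expected input and grad types to match, or input to be Half/BFloat16 and grad to be Float"
        return half_to_float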
...@@ -263,6 +263,6 @@ class EncdecAttnFunc(torch.autograd.Function): ...@@ -263,6 +263,6 @@ class EncdecAttnFunc(torch.autograd.Function):
input_q_grads, input_kv_grads, \ input_q_grads, input_kv_grads, \
input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \ input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \
input_bias_grads_q, input_bias_grads_kv, output_bias_grads, \ input_bias_grads_q, input_bias_grads_kv, output_bias_grads, \
None, None None, None, None
encdec_attn_func = EncdecAttnFunc.apply encdec_attn_func = EncdecAttnFunc.apply
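The extra trailing None is needed because torch.autograd.Function.backward must return exactly one value per forward() argument, with None for non-differentiable arguments; adding a forward argument therefore adds a None slot. A minimal sketch with a hypothetical signature:

    import torch

    class ToyAttnFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inputs, weights, pad_mask, dropout_prob):
            ctx.save_for_backward(inputs, weights)
            return inputs @ weights

        @staticmethod
        def backward(ctx, grad_out):
            inputs, weights = ctx.saved_tensors
            # one return slot per forward argument; None for pad_mask and dropout_prob
            return grad_out @ weights.t(), inputs.t() @ grad_out, None, None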
...@@ -9,7 +9,7 @@ class FastSelfAttnNormAddFunc(torch.autograd.Function): ...@@ -9,7 +9,7 @@ class FastSelfAttnNormAddFunc(torch.autograd.Function):
dropout_prob_t = torch.tensor([dropout_prob]) dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([]) null_tensor = torch.tensor([])
use_mask = (pad_mask is not None) use_mask = (pad_mask is not None)
print("---use_mask-----",use_mask)
lyr_nrm_results, \ lyr_nrm_results, \
lyr_nrm_mean, \ lyr_nrm_mean, \
lyr_nrm_invvar, \ lyr_nrm_invvar, \
......
...@@ -230,6 +230,6 @@ class SelfAttnFunc(torch.autograd.Function): ...@@ -230,6 +230,6 @@ class SelfAttnFunc(torch.autograd.Function):
input_grads, \ input_grads, \
input_weight_grads, output_weight_grads, \ input_weight_grads, output_weight_grads, \
input_bias_grads, output_bias_grads, \ input_bias_grads, output_bias_grads, \
None, None None, None, None
self_attn_func = SelfAttnFunc.apply self_attn_func = SelfAttnFunc.apply
...@@ -87,10 +87,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -87,10 +87,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
adam_w_mode=True, use_nvlamb=False, adam_w_mode=True, use_nvlamb=False,
step_supports_amp_scaling=True, overlap_reductions=True, step_supports_amp_scaling=True, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4, dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False, dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
e5m2_allgather=False, verbose=False, clip_after_ar=True, e5m2_allgather=False, verbose=False, clip_after_ar=True):
full_ar=False, set_param_views_to_flat_buffer=False, skip_allgather=False,
fuse_scale=False, param_order=None, nccl_allgather_channels=0):
defaults = dict(lr=lr, bias_correction=bias_correction, defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay, betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging, grad_averaging=grad_averaging,
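Illustrative construction only (hyperparameter values are placeholders, `model` is a hypothetical nn.Module, and the import path is assumed); the keyword names are the ones visible in the hunk above:

    from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB  # assumed path

    optimizer = DistributedFusedLAMB(
        model.parameters(),
        lr=2e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
        adam_w_mode=True, use_nvlamb=False, overlap_reductions=True,
        dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
        dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
        e5m2_allgather=False, verbose=False, clip_after_ar=True)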
...@@ -122,12 +120,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -122,12 +120,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._e5m2_allgather = e5m2_allgather self._e5m2_allgather = e5m2_allgather
self._verbose = verbose self._verbose = verbose
self._clip_after_ar = clip_after_ar self._clip_after_ar = clip_after_ar
self._full_ar = full_ar
self._fuse_scale = fuse_scale
self._L2_grad_norm = None self._L2_grad_norm = None
self._set_flat_param_view = set_param_views_to_flat_buffer
self._skip_ag = skip_allgather
self._fused_norm = fused_norm
self._current_process_group = c10d._get_default_group() self._current_process_group = c10d._get_default_group()
self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys()) self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
...@@ -144,108 +138,63 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -144,108 +138,63 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
import inspect import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option" #assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
self._num_rs_pg = dwu_num_rs_pg self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg self._num_ag_pg = dwu_num_ag_pg
if self._num_groups > 1:
if self._full_ar: # full all reduce, only need AR and AG groups
self._ar_pg = [] self._ar_pg = []
# consider all the ranks for dev_i in range(self._group_size):
ranks = list(range(0, self._world_size)) ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks) for i in range(self._num_ar_pg):
if torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = l2_grad_norm_pg
for i in range(self._num_ar_pg):
if self._verbose:
print(f"creating new AR group {i}: {ranks}")
grp = torch.distributed.new_group(ranks=ranks)
if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
if self._verbose: if self._verbose:
print(f"group {i}: init barrier (device: {torch.cuda.current_device()})") print(f"creating new group {i}: {ranks}")
torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()]) grp = torch.distributed.new_group(ranks=ranks)
if self._verbose: if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
print(f"created new AR group {i}: {ranks}")
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
if nccl_allgather_channels > 0:
os.putenv('NCCL_MAX_NCHANNELS', str(nccl_allgather_channels))
if self._num_ag_pg == 0:
self._ag_pg = self._ar_pg
self._ag_st = self._ar_st
self._num_ag_pg = self._num_ar_pg
else:
self._ag_pg = []
ranks = []
stride = torch.cuda.device_count()
for i in range(self._num_groups):
rs = list(range(i*stride, (i+1)*stride))
ranks.append(rs)
for rs in ranks:
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=rs)
if torch.distributed.get_rank() in rs:
if self._verbose:
print(f"creating AG group {i}: {rs}")
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
else: # reduce-scatter + all-reduce, need RS, AR, AG groups
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
if self._verbose:
print(f"creating new AR group {i}: {ranks}")
grp = torch.distributed.new_group(ranks=ranks)
if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
if self._verbose:
print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
if self._verbose: if self._verbose:
print(f"created new AR group {i}: {ranks}") print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
if self._verbose:
print(f"created new group {i}")
if torch.distributed.get_rank() in ranks: if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp) self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)] self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
rs_ranks = [] #for ar_pg in self._ar_pg:
for group_i in range(self._num_groups): # torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)]) rs_ranks = []
self._rs_pg = [] for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = l2_grad_norm_pg
#torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
#for rs_pg in self._rs_pg:
# torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups): for group_i in range(self._num_groups):
ranks = rs_ranks[group_i] ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg): for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks) grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks: if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp) self._ag_pg.append(grp)
if self._verbose: self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
print(f"creating RS group : {ranks}") #for ag_pg in self._ag_pg:
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks) # torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
if torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = l2_grad_norm_pg
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
if self._verbose:
print(f"creating AG group : {ranks}")
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.barrier(group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream() self._l2_grad_norm_st = torch.cuda.Stream()
self._completion_st = torch.cuda.Stream() self._completion_st = torch.cuda.Stream()
self._step.record_stream(self._completion_st) self._step.record_stream(self._completion_st)
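A small worked sketch (not part of the diff) of the rank layout built above: reduce-scatter groups are intra-node, while all-reduce groups span nodes at the same local device index.

    world_size, group_size = 16, 8          # e.g. 2 nodes x 8 GPUs
    num_groups = world_size // group_size

    # intra-node reduce-scatter / all-gather groups: one per node
    rs_ranks = [[g * group_size + j for j in range(group_size)] for g in range(num_groups)]
    # -> [[0, 1, ..., 7], [8, 9, ..., 15]]

    # inter-node all-reduce groups: one per local device index
    ar_ranks = [[dev + g * group_size for g in range(num_groups)] for dev in range(group_size)]
    # -> [[0, 8], [1, 9], ..., [7, 15]]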
...@@ -259,6 +208,9 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -259,6 +208,9 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
self._param_order = self.AtomicCounter() self._param_order = self.AtomicCounter()
def _lazy_init_stage1(self):
if self._lazy_init_stage1_done: return
p_offset = 0 p_offset = 0
p_i = 0 p_i = 0
self._model_params = [] self._model_params = []
...@@ -272,6 +224,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -272,6 +224,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
eps = group['eps'] eps = group['eps']
weight_decay = group['weight_decay'] weight_decay = group['weight_decay']
for p in group['params']: for p in group['params']:
torch.distributed.broadcast(p, 0)
if not p.requires_grad: if not p.requires_grad:
continue continue
self._model_params.append(p) self._model_params.append(p)
...@@ -284,12 +237,19 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -284,12 +237,19 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
eps eps
)) ))
p_grads_size = p.numel() p_grads_size = p.numel()
if self._set_flat_param_view: def wrapper(param, param_i):
if param_order: param_tmp = param.expand_as(param)
# this is executed when param_order is specified by the user grad_acc = param_tmp.grad_fn.next_functions[0][0]
self._param_order.add(param_order[p]) def allreduce_hook(*unused):
else: if self._first_step:
self._param_order.add(p_i) # first time
self._param_order.add(param_i)
else:
idx = self._param_order.order.index(param_i)
self._do_overlapped_reduction(idx, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
wrapper(p, p_i)
p_offset += p_grads_size p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters: # RNN is one example of consecutive parameters:
...@@ -298,8 +258,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -298,8 +258,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
p_offset = ((p_offset + 63) // 64) * 64 p_offset = ((p_offset + 63) // 64) * 64
prev = p prev = p
p_i += 1 p_i += 1
if param_order:
self._param_order.order = torch.argsort(torch.tensor(self._param_order.order)).tolist()
self._grads_generated = [False]*len(self._model_params) self._grads_generated = [False]*len(self._model_params)
self._grads_fp16, self._grads_fp32 = [], [] self._grads_fp16, self._grads_fp32 = [], []
if self._overlap_reductions: if self._overlap_reductions:
...@@ -348,6 +306,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -348,6 +306,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._block_size = self._total_param_size // self._num_blocks self._block_size = self._total_param_size // self._num_blocks
self._chunk_size = self._block_size // self._num_chunks self._chunk_size = self._block_size // self._num_chunks
self._shard_size = self._chunk_size // self._group_size self._shard_size = self._chunk_size // self._group_size
#print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda') self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
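A sketch of the hook-registration trick used in the rewritten _lazy_init_stage1 above: expand_as produces a non-leaf view whose grad_fn points at the parameter's AccumulateGrad node, and a hook on that node fires when the parameter's gradient becomes ready (names below are illustrative).

    import torch

    def register_grad_ready_hook(param: torch.Tensor, on_ready):
        assert param.requires_grad
        param_tmp = param.expand_as(param)
        grad_acc = param_tmp.grad_fn.next_functions[0][0]  # AccumulateGrad node

        def hook(*unused):
            on_ready(param)

        grad_acc.register_hook(hook)
        return grad_acc  # keep a reference so the node is not garbage collected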
...@@ -482,6 +441,45 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -482,6 +441,45 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._low_param_i[block_id] = p_i self._low_param_i[block_id] = p_i
#print("self._low_param_i", self._low_param_i) #print("self._low_param_i", self._low_param_i)
self._lazy_init_stage1_done = True
def _lazy_init_stage2(self):
if self._lazy_init_stage2_done: return
self._param_order.order.reverse()
# re-order model_params, grad_accs, group_properties lists
self._model_params = [self._model_params[i] for i in self._param_order.order]
self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
self._group_properties = [self._group_properties[i] for i in self._param_order.order]
# re-collect grads info (size, offset) after ordering
prev = None
p_offset = 0
self._grads_info = []
self._individual_flat_grads = []
for i, p in enumerate(self._model_params):
p_grads_size = p.numel()
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
# for the first iteration
self._do_overlapped_reduction(i, p)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
#print("self._low_param_i", self._low_param_i)
# This paragraph does two things: # This paragraph does two things:
# 1) Copy model parameters into master buffer # 1) Copy model parameters into master buffer
# 2) Create tensor lists for unpacking new parameter tensor after all-gather # 2) Create tensor lists for unpacking new parameter tensor after all-gather
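The 128-byte alignment rule applied in both _lazy_init stages above, in isolation: whenever the parameter just placed is not contiguous in memory with its predecessor, the running offset is rounded up to the next multiple of 64 fp16 elements (64 * 2 bytes = 128 bytes).

    def align_64(offset: int) -> int:
        return ((offset + 63) // 64) * 64

    assert align_64(100) == 128
    assert align_64(128) == 128
    assert align_64(129) == 192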
...@@ -589,90 +587,30 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -589,90 +587,30 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
return flush_block return flush_block
def _full_all_reduce_scale(self, block_id, scale):
works = [None]*self._num_chunks
if self._clip_after_ar:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
ar_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ar_stream):
works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
else:
glob_chunk_id = block_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
ar_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ar_stream):
works0 = torch.distributed.all_reduce(self._flat_grads_blocks[block_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
for i in range(self._num_chunks):
works[i]=works0
self._reductions_works[block_id] = works
def _full_all_reduce(self, block_id):
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
ar_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ar_stream):
works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
rs_stream.wait_stream(self._l2_grad_norm_st)
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
def _reduce_scatter_and_all_reduce(self, block_id):
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
rs_stream.wait_stream(self._l2_grad_norm_st)
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
def _pipeline_block_reductions(self, block_id): def _pipeline_block_reductions(self, block_id):
if self._clip_after_ar: if self._clip_after_ar:
self._flatten_grad_mt(1.0/self._world_size) self._flatten_grad_mt(1.0/self._world_size)
if self._full_ar: # Reduction within each node
self._full_all_reduce(block_id) # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
else: # The output format is the same as the fp32 master parameters
self._reduce_scatter_and_all_reduce(block_id) works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True, no_copy=False)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
# Compute L2 grad norm # Compute L2 grad norm
if block_id == 0: if block_id == 0:
...@@ -682,10 +620,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -682,10 +620,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._reductions_works[block_id][chunk_id].wait() self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed # Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = torch.empty([1], device='cuda') l2_grad_norm_sq = torch.empty([1], device='cuda')
if 0:#self._full_ar: l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
l2_grad_norm_sq = self._flat_grads_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)**2
else:
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg) torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt() self._L2_grad_norm = l2_grad_norm_sq.sqrt()
else: else:
...@@ -695,8 +630,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -695,8 +630,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
# Compute L2 grad norm # Compute L2 grad norm
self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream()) self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._l2_grad_norm_st): with torch.cuda.stream(self._l2_grad_norm_st):
if not self._fused_norm: self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st) torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
# Apply clipping & pre-reduction scaling on grads # Apply clipping & pre-reduction scaling on grads
...@@ -707,19 +641,29 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -707,19 +641,29 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
tmp = torch.cat(((self._one), (coeff))) tmp = torch.cat(((self._one), (coeff)))
index = (coeff+1>coeff).int() index = (coeff+1>coeff).int()
scale = tmp.index_select(0, index).half()/self._world_size scale = tmp.index_select(0, index).half()/self._world_size
if not self._fuse_scale: self._flat_grads.mul_(scale)
self._flat_grads.mul_(scale)
# Reduction within each node
if self._full_ar: # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
if self._fuse_scale: # The output format is the same as the fp32 master parameters
self._full_all_reduce_scale(block_id, scale) works = [None]*self._num_chunks
else: for chunk_id in range(self._num_chunks):
self._full_all_reduce(block_id) glob_chunk_id = block_id * self._num_chunks + chunk_id
else: rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
if self._fuse_scale: rs_stream.wait_stream(torch.cuda.current_stream())
self._reduce_scatter_and_all_reduce_scale(block_id, scale) rs_stream.wait_stream(self._l2_grad_norm_st)
else: with torch.cuda.stream(rs_stream):
self._reduce_scatter_and_all_reduce(block_id) works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True, no_copy=False)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
if block_id == 0: if block_id == 0:
for block_id in range(self._num_blocks): for block_id in range(self._num_blocks):
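What the index_select trick above is doing (a sketch; coeff itself is computed just above this hunk and is not shown): (coeff + 1 > coeff) is True for any finite coeff and False for NaN or +inf, so the select falls back to 1.0 instead of propagating a bad clipping scale.

    import torch

    def safe_scale(coeff: torch.Tensor, one: torch.Tensor, world_size: int) -> torch.Tensor:
        tmp = torch.cat((one, coeff))
        index = (coeff + 1 > coeff).int()   # 1 -> use coeff, 0 -> fall back to one
        return tmp.index_select(0, index) / world_size

    one = torch.ones(1)
    print(safe_scale(torch.tensor([0.5]), one, 8))           # tensor([0.0625])
    print(safe_scale(torch.tensor([float('nan')]), one, 8))  # tensor([0.1250])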
...@@ -757,14 +701,12 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -757,14 +701,12 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
# check global_grad_norm and fill overflow_buf # check global_grad_norm and fill overflow_buf
is_finite = (global_grad_norm + 1 > global_grad_norm).int() is_finite = (global_grad_norm + 1 > global_grad_norm).int()
self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1 self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
torch.distributed.all_reduce(is_finite,
if not self._clip_after_ar: op=torch.distributed.ReduceOp.MIN,
torch.distributed.all_reduce(is_finite, group=self._current_process_group)
op=torch.distributed.ReduceOp.MIN, torch.distributed.all_reduce(self._overflow_buf,
group=self._current_process_group) op=torch.distributed.ReduceOp.MAX,
torch.distributed.all_reduce(self._overflow_buf, group=self._current_process_group)
op=torch.distributed.ReduceOp.MAX,
group=self._current_process_group)
# increment step counter if no overflow # increment step counter if no overflow
self._step += is_finite self._step += is_finite
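A sketch of the cross-rank overflow agreement above: every rank must take the same skip/step decision, so the finiteness flag is combined with MIN (any non-finite rank forces 0) and the overflow flag with MAX (any overflowed rank forces 1). The helper below assumes torch.distributed is already initialized.

    import torch
    import torch.distributed as dist

    def sync_overflow(global_grad_norm: torch.Tensor, group=None):
        is_finite = (global_grad_norm + 1 > global_grad_norm).int()  # 0 if nan/inf
        overflow = 1 - is_finite
        dist.all_reduce(is_finite, op=dist.ReduceOp.MIN, group=group)
        dist.all_reduce(overflow, op=dist.ReduceOp.MAX, group=group)
        return is_finite, overflow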
...@@ -803,14 +745,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -803,14 +745,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._contrib_weight_decay, self._contrib_weight_decay,
global_grad_norm, global_grad_norm,
self._use_nvlamb) self._use_nvlamb)
if not self._skip_ag: torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0],no_copy=False)
# allgather chunking is currently not supported for clip after allreduce
if not self._clip_after_ar:
for block in range(self._num_blocks):
for chunk in range(self._num_chunks):
torch.distributed.all_gather(self._new_params2_shards[block][chunk], self._fp16_p_chunks[block][chunk], group=self._ag_pg[0], no_copy=True)
else:
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale): def _flatten_grad_mt(self, scale):
if len(self._grads_fp16) > 0: if len(self._grads_fp16) > 0:
...@@ -912,21 +847,21 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -912,21 +847,21 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
optimizer_state["found_inf_per_device"][current_device] = found_inf optimizer_state["found_inf_per_device"][current_device] = found_inf
self._completion_st.wait_stream(torch.cuda.current_stream()) self._completion_st.wait_stream(torch.cuda.current_stream())
if not self._set_flat_param_view:
with torch.cuda.stream(self._completion_st): with torch.cuda.stream(self._completion_st):
# Copy self._new_params to model params # Copy self._new_params to model params
with torch.no_grad(): with torch.no_grad():
if self._packed_flat_to_model_params_fp16 is not None: if self._packed_flat_to_model_params_fp16 is not None:
multi_tensor_applier( multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt, fused_adam_cuda.maybe_cast_mt,
self._overflow_buf, self._overflow_buf,
self._packed_flat_to_model_params_fp16) self._packed_flat_to_model_params_fp16)
if self._packed_flat_to_model_params_fp32 is not None: if self._packed_flat_to_model_params_fp32 is not None:
multi_tensor_applier( multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt, fused_adam_cuda.maybe_cast_mt,
self._overflow_buf, self._overflow_buf,
self._packed_flat_to_model_params_fp32) self._packed_flat_to_model_params_fp32)
torch.cuda.current_stream().wait_stream(self._completion_st) torch.cuda.current_stream().wait_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks self._reductions_works = [None]*self._num_blocks
......
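The stream-ordering pattern used by the completion, reduce-scatter, and all-reduce streams throughout this file, in isolation (a sketch): side work is queued on a dedicated stream, ordered after everything already enqueued on the current stream, and the current stream then waits for it before continuing.

    import torch

    def run_on_side_stream(side_stream: torch.cuda.Stream, fn, *args):
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            out = fn(*args)
        torch.cuda.current_stream().wait_stream(side_stream)
        return out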
...@@ -137,17 +137,22 @@ version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 ...@@ -137,17 +137,22 @@ version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
if "--distributed_adam" in sys.argv: if "--distributed_adam" in sys.argv:
sys.argv.remove("--distributed_adam") sys.argv.remove("--distributed_adam")
if CUDA_HOME is None: from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
nvcc_args_adam = ['-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_adam = ['-O3'] + version_dependent_macros
ext_modules.append( ext_modules.append(
CUDAExtension(name='distributed_adam_cuda', CUDAExtension(name='distributed_adam_cuda',
sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp', sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'], 'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3', 'nvcc':nvcc_args_adam if not IS_ROCM_PYTORCH else hipcc_args_adam}))
'--use_fast_math'] + version_dependent_macros}))
if "--distributed_lamb" in sys.argv: if "--distributed_lamb" in sys.argv:
sys.argv.remove("--distributed_lamb") sys.argv.remove("--distributed_lamb")
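The option-flag pattern used throughout this setup.py, in isolation (a sketch): each build option is consumed from sys.argv before setuptools parses the remaining arguments.

    import sys

    def pop_flag(flag: str) -> bool:
        if flag in sys.argv:
            sys.argv.remove(flag)
            return True
        return False

    build_fast_multihead_attn = pop_flag("--fast_multihead_attn")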
...@@ -296,7 +301,8 @@ if "--xentropy" in sys.argv: ...@@ -296,7 +301,8 @@ if "--xentropy" in sys.argv:
CUDAExtension(name='xentropy_cuda', CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp', sources=['apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'], 'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/xentropy')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
...@@ -317,7 +323,8 @@ if "--deprecated_fused_adam" in sys.argv: ...@@ -317,7 +323,8 @@ if "--deprecated_fused_adam" in sys.argv:
CUDAExtension(name='fused_adam_cuda', CUDAExtension(name='fused_adam_cuda',
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp', sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'], 'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam})) 'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
...@@ -371,6 +378,7 @@ if "--fast_layer_norm" in sys.argv: ...@@ -371,6 +378,7 @@ if "--fast_layer_norm" in sys.argv:
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'-Iapex/contrib/csrc/layer_norm',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}, '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
...@@ -416,121 +424,101 @@ if "--fmha" in sys.argv: ...@@ -416,121 +424,101 @@ if "--fmha" in sys.argv:
if "--fast_multihead_attn" in sys.argv: if "--fast_multihead_attn" in sys.argv:
sys.argv.remove("--fast_multihead_attn") sys.argv.remove("--fast_multihead_attn")
if CUDA_HOME is None: from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) if not IS_ROCM_PYTORCH:
if int(bare_metal_major) >= 11: _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
cc_flag.append('-gencode') if int(bare_metal_major) >= 11:
cc_flag.append('arch=compute_80,code=sm_80') cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"]) subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
nvcc_args_mha = ['-O3',
'-gencode',
'arch=compute_70,code=sm_70',
'-Iapex/contrib/csrc/multihead_attn/cutlass',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
hipcc_args_mha = ['-O3',
'-Iapex/contrib/csrc/multihead_attn/cutlass',
'-I/opt/rocm/include/hiprand',
'-I/opt/rocm/include/rocrand',
'-U__HIP_NO_HALF_OPERATORS__',
'-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_additive_mask_softmax_dropout', CUDAExtension(name='fast_additive_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp', sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp',
'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'], 'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_mask_softmax_dropout', CUDAExtension(name='fast_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp', sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp',
'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'], 'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask', CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp', sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'], 'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias', CUDAExtension(name='fast_self_multihead_attn_bias',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp', sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'], 'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn', CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp', sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'], 'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add', CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp', sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'], 'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn', CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp', sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'], 'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add', CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp', sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'], 'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
if "--transducer" in sys.argv: if "--transducer" in sys.argv:
sys.argv.remove("--transducer") sys.argv.remove("--transducer")
......