Commit cf0b0f01 authored by Hubert Lu

Fix some bugs related to THCState and cutlass

parent 9615983e
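The recurring change in the hunks below drops the deprecated THCState* handle from every gemm_switch_fp32accum call site. As a hypothetical sketch (illustrative names only, not code from this commit), a GEMM helper can fetch the cuBLAS/rocBLAS handle and stream from ATen itself, which is what makes the extra parameter unnecessary:

#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>

// Hypothetical sketch: how a gemm helper can run without a THCState* argument.
void gemm_fp32accum_sketch(/* layouts, sizes, scaling, pointers, strides... */) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
  // ... dispatch to cublasGemmStridedBatchedEx / rocblas_gemm_strided_batched_ex ...
}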
@@ -140,8 +140,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -194,8 +193,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -371,8 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -394,8 +391,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -434,8 +430,7 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -457,8 +452,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -595,3 +589,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace encdec
} // end namespace multihead_attn
@@ -166,8 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -220,8 +219,7 @@ std::vector<torch::Tensor> fwd_cuda(
}
// Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -435,8 +433,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -458,8 +455,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -498,8 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -521,8 +516,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -675,3 +669,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
@@ -116,8 +116,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -162,8 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
}
// Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -327,8 +325,7 @@ std::vector<torch::Tensor> bwd_cuda(
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -350,8 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -388,8 +384,7 @@ std::vector<torch::Tensor> bwd_cuda(
stream);
// Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -411,8 +406,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -108,8 +108,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -162,8 +161,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
}
// Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -327,8 +325,7 @@ std::vector<torch::Tensor> bwd_cuda(
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -350,8 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -383,8 +379,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches * q_seq_len, stream);
// Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -406,8 +401,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -489,3 +483,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
@@ -106,8 +106,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -160,8 +159,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -322,8 +320,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -345,8 +342,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -385,8 +381,7 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -408,8 +403,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -493,3 +487,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
@@ -128,8 +128,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -182,8 +181,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -380,8 +378,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -403,8 +400,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -443,8 +439,7 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -466,8 +461,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -565,3 +559,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
@@ -161,7 +161,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src,
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
@@ -186,7 +186,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src,
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
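The two hunks above (and the matching ones below) reduce max_value[] and sum[] across a warp with XOR shuffles. A standalone sketch of that butterfly pattern, assuming plain CUDA intrinsics and a 32-lane warp (illustrative, not code from this commit):

// Each step combines lanes whose ids differ by `offset`; after log2(32) steps
// every lane holds the warp-wide sum.
__device__ __forceinline__ float warp_reduce_sum_sketch(float v) {
  const unsigned full_mask = 0xffffffffu;
  for (int offset = 32 / 2; offset > 0; offset /= 2)
    v += __shfl_xor_sync(full_mask, v, offset, 32);
  return v;
}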
@@ -402,7 +402,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
@@ -426,7 +426,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
auto seeds = at::cuda::philox::unpack(philox_args);
@@ -564,7 +564,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
@@ -588,7 +588,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
curandStatePhilox4_32_10_t state;
@@ -874,7 +874,7 @@ __global__ void additive_masked_softmax_warp_forward(
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
@@ -899,7 +899,7 @@ __global__ void additive_masked_softmax_warp_forward(
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -1164,7 +1164,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src,
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1189,7 +1189,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src,
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -1414,7 +1414,7 @@ __global__ void time_masked_softmax_warp_forward(
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1439,7 +1439,7 @@ __global__ void time_masked_softmax_warp_forward(
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -1586,13 +1586,13 @@ int log2_ceil_native(int value) {
}
template <typename T>
-__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
-{
+__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE>
__device__ __forceinline__ void warp_reduce_sum(acc_t *sum) {
@@ -2149,7 +2149,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
@@ -2174,7 +2174,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -2754,7 +2754,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad,
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -2988,7 +2988,7 @@ masked_softmax_warp_backward(__half *gradInput, const __half *grad,
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
@@ -3137,3 +3137,4 @@ bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad,
}
return false;
}
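Throughout this file the kernels now go through APEX_WARP_SHFL_XOR instead of calling __shfl_xor_sync directly. The macro's definition is not shown in this diff; a plausible definition, assumed here and modeled on the WARP_SHFL_XOR_NATIVE helper above, dispatches on the platform so the same call compiles on HIP and on CUDA:

// Assumed definition (not part of this commit): use the synchronizing intrinsic
// on CUDA >= 9.0 and the legacy __shfl_xor on HIP or older CUDA.
#if defined(__HIP_PLATFORM_HCC__) || (defined(CUDA_VERSION) && CUDA_VERSION < 9000)
#define APEX_WARP_SHFL_XOR(mask, var, lane_mask, width) \
  __shfl_xor((var), (lane_mask), (width))
#else
#define APEX_WARP_SHFL_XOR(mask, var, lane_mask, width) \
  __shfl_xor_sync((mask), (var), (lane_mask), (width))
#endif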
@@ -10,9 +10,9 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
-#include "cutlass/cutlass.h"
-#include "cutlass/gemm/gemm.h"
-#include "cutlass/gemm/wmma_gemm_traits.h"
+//#include "cutlass/cutlass.h"
+//#include "cutlass/gemm/gemm.h"
+//#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs
@@ -110,7 +110,8 @@ void HgemmStridedBatched(char transa, char transb, long m,
long n, long k, float alpha, const half *a, long lda,
long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC,
-long batchCount) {
+half *d, long ldd, long strideD, long batchCount) {
if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
(ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX))
@@ -129,3 +130,4 @@ void HgemmStridedBatched(char transa, char transb, long m,
b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount);
}
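HgemmStridedBatched now takes a separate output matrix d with its own leading dimension and stride, following the rocBLAS gemm_ex convention in which C is the read-only accumulator input and D is the output buffer. A hedged usage sketch (illustrative values, not from this commit) where the output simply aliases C so an in-place caller keeps its old behaviour:

#include <cuda_fp16.h>

// Declaration as extended in the hunk above.
void HgemmStridedBatched(char transa, char transb, long m, long n, long k,
                         float alpha, const half *a, long lda, long strideA,
                         const half *b, long ldb, long strideB, float beta,
                         half *c, long ldc, long strideC,
                         half *d, long ldd, long strideD, long batchCount);

// Hedged sketch: pass the C buffer again as D with the same ldc/strideC.
void call_hgemm_sketch(const half *a, const half *b, half *c,
                       long m, long n, long k, long batch) {
  HgemmStridedBatched('n', 'n', m, n, k, /*alpha=*/1.0f,
                      a, /*lda=*/m, /*strideA=*/m * k,
                      b, /*ldb=*/k, /*strideB=*/k * n,
                      /*beta=*/0.0f,
                      c, /*ldc=*/m, /*strideC=*/m * n,
                      /*d=*/c, /*ldd=*/m, /*strideD=*/m * n,
                      batch);
}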