Commit cf0b0f01 authored by Hubert Lu

Fix some bugs related to THCState and cutlass

parent 9615983e
@@ -140,8 +140,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -194,8 +193,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 }
 // Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -371,8 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
 flags));
 // MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -394,8 +391,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -434,8 +430,7 @@ std::vector<torch::Tensor> bwd_cuda(
 assert(softmax_success);
 // Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -457,8 +452,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -595,3 +589,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace encdec
 } // end namespace multihead_attn
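
The same change repeats in every attention variant below: the leading THCState* argument is removed from gemm_switch_fp32accum, so the layout flag becomes the first parameter. As a hedged illustration only (the wrapper's real signature and body are not shown in this diff, and example_gemm_wrapper is a hypothetical name), dropping the state argument implies the wrapper obtains the BLAS handle and stream from ATen itself:

```cpp
// Hypothetical sketch, not code from this commit: with THCState gone, the handle
// and stream come from ATen directly (rocBLAS/hipBLAS builds follow the same idea).
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>

void example_gemm_wrapper(char a_layout, char b_layout /*, sizes, pointers, ... */) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();  // owned by ATen
  cublasSetStream(handle, at::cuda::getCurrentCUDAStream());     // current PyTorch stream
  // ... dispatch the batched GEMM with fp32 accumulation here ...
  (void)a_layout;
  (void)b_layout;
}
```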

@@ -166,8 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
 solution_index,
 flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -220,8 +219,7 @@ std::vector<torch::Tensor> fwd_cuda(
 }
 // Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -435,8 +433,7 @@ std::vector<torch::Tensor> bwd_cuda(
 flags));
 // MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -458,8 +455,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -498,8 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
 assert(softmax_success);
 // Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -521,8 +516,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -675,3 +669,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace encdec_norm_add
 } // end namespace multihead_attn

@@ -116,8 +116,7 @@ std::vector<torch::Tensor> fwd_cuda(
 flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -162,8 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
 }
 // Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -327,8 +325,7 @@ std::vector<torch::Tensor> bwd_cuda(
 auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
 // MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -350,8 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -388,8 +384,7 @@ std::vector<torch::Tensor> bwd_cuda(
 stream);
 // Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -411,8 +406,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
...

@@ -108,8 +108,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -162,8 +161,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 }
 // Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -327,8 +325,7 @@ std::vector<torch::Tensor> bwd_cuda(
 auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
 // MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -350,8 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -383,8 +379,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches * q_seq_len, stream);
 // Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -406,8 +401,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -489,3 +483,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn

@@ -106,8 +106,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -160,8 +159,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 }
 // Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -322,8 +320,7 @@ std::vector<torch::Tensor> bwd_cuda(
 flags));
 // MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -345,8 +342,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -385,8 +381,7 @@ std::vector<torch::Tensor> bwd_cuda(
 assert(softmax_success);
 // Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -408,8 +403,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -493,3 +487,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn

@@ -128,8 +128,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -182,8 +181,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 }
 // Matmul2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -380,8 +378,7 @@ std::vector<torch::Tensor> bwd_cuda(
 flags));
 // MatMul2 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_t,
+gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
 k_seq_len,
 q_seq_len,
@@ -403,8 +400,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul2 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -443,8 +439,7 @@ std::vector<torch::Tensor> bwd_cuda(
 assert(softmax_success);
 // Matmul1 Dgrad1
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
 head_dim,
 q_seq_len,
@@ -466,8 +461,7 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches);
 // Matmul1 Dgrad2
-gemm_switch_fp32accum( state,
-a_layout_n,
+gemm_switch_fp32accum( a_layout_n,
 b_layout_t,
 head_dim,
 k_seq_len,
@@ -565,3 +559,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace self_norm_add
 } // end namespace multihead_attn

@@ -161,7 +161,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src,
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
@@ -186,7 +186,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src,
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -402,7 +402,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
@@ -426,7 +426,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
 auto seeds = at::cuda::philox::unpack(philox_args);
@@ -564,7 +564,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
@@ -588,7 +588,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
 curandStatePhilox4_32_10_t state;
@@ -874,7 +874,7 @@ __global__ void additive_masked_softmax_warp_forward(
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
@@ -899,7 +899,7 @@ __global__ void additive_masked_softmax_warp_forward(
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -1164,7 +1164,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src,
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1189,7 +1189,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src,
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -1414,7 +1414,7 @@ __global__ void time_masked_softmax_warp_forward(
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1439,7 +1439,7 @@ __global__ void time_masked_softmax_warp_forward(
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -1586,13 +1586,13 @@ int log2_ceil_native(int value) {
 }
 template <typename T>
-__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
-{
+__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
 #if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
 return __shfl_xor_sync(mask, value, laneMask, width);
 #else
 return __shfl_xor(value, laneMask, width);
 #endif
 }
 template <typename acc_t, int WARP_BATCH, int WARP_SIZE>
 __device__ __forceinline__ void warp_reduce_sum(acc_t *sum) {
@@ -2149,7 +2149,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(
 float val[WARP_BATCH];
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
 }
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
@@ -2174,7 +2174,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -2754,7 +2754,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad,
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -2988,7 +2988,7 @@ masked_softmax_warp_backward(__half *gradInput, const __half *grad,
 for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
 for (int i = 0; i < WARP_BATCH; ++i) {
-sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
 }
 }
@@ -3137,3 +3137,4 @@ bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad,
 }
 return false;
 }
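
In the kernels above, every __shfl_xor_sync call is routed through APEX_WARP_SHFL_XOR so the same code builds for both CUDA and HIP. The macro's definition is not part of this diff; a plausible sketch, consistent with the WARP_SHFL_XOR_NATIVE helper in the @@ -1586 hunk, would be:

```cpp
// Hedged sketch only; the real definition lives elsewhere in the tree.
// HIP has no *_sync shuffle variants, so the mask argument is simply dropped there.
#ifndef APEX_WARP_SHFL_XOR
#if defined(__HIP_PLATFORM_HCC__)
#define APEX_WARP_SHFL_XOR(mask, value, laneMask, width) \
  __shfl_xor((value), (laneMask), (width))
#else
#define APEX_WARP_SHFL_XOR(mask, value, laneMask, width) \
  __shfl_xor_sync((mask), (value), (laneMask), (width))
#endif
#endif
```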

@@ -10,9 +10,9 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
-#include "cutlass/cutlass.h"
-#include "cutlass/gemm/gemm.h"
-#include "cutlass/gemm/wmma_gemm_traits.h"
+//#include "cutlass/cutlass.h"
+//#include "cutlass/gemm/gemm.h"
+//#include "cutlass/gemm/wmma_gemm_traits.h"
 // symbol to be automatically resolved by PyTorch libs
@@ -110,7 +110,8 @@ void HgemmStridedBatched(char transa, char transb, long m,
 long n, long k, float alpha, const half *a, long lda,
 long strideA, const half *b, long ldb, long strideB,
 float beta, half *c, long ldc, long strideC,
-long batchCount) {
+half *d, long ldd, long strideD, long batchCount) {
 if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
 (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX))
@@ -129,3 +130,4 @@ void HgemmStridedBatched(char transa, char transb, long m,
 b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount);
 }
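
The last hunk shows why the signature change is needed: the gemm_ex-style rocBLAS call at the bottom already passes d, ldd, strideD, but the old HgemmStridedBatched signature did not accept them. A hedged usage sketch (call_hgemm_inplace is an illustrative helper, not part of this commit): callers that want the previous in-place C = alpha*A*B + beta*C behaviour can alias the output D to C.

```cpp
// Hedged usage sketch, not from this commit. gemm_ex-style APIs take a separate
// output matrix D; aliasing D to C (with matching ldd/strideD) keeps the update in place.
#include <cuda_fp16.h>

// Prototype after this commit's change (as shown in the diff above).
void HgemmStridedBatched(char transa, char transb, long m, long n, long k, float alpha,
                         const half *a, long lda, long strideA,
                         const half *b, long ldb, long strideB,
                         float beta, half *c, long ldc, long strideC,
                         half *d, long ldd, long strideD, long batchCount);

void call_hgemm_inplace(long m, long n, long k, float alpha, float beta,
                        const half *a, long lda, long strideA,
                        const half *b, long ldb, long strideB,
                        half *c, long ldc, long strideC, long batchCount) {
  HgemmStridedBatched('n', 'n', m, n, k, alpha,
                      a, lda, strideA,
                      b, ldb, strideB,
                      beta, c, ldc, strideC,
                      /*d=*/c, /*ldd=*/ldc, /*strideD=*/strideC,
                      batchCount);
}
```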