Commit 72786aea authored by danyao12's avatar danyao12 Committed by wunhuang
Browse files

rename and sync macro

parent 02415adc
......@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
#define RANGE_HDKO 0 // 0~2
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -75,7 +75,7 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
#if(RANGE_HDKO == 0)
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
......@@ -145,7 +145,7 @@ using DeviceGemmInstance =
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 1)
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
......@@ -215,7 +215,7 @@ using DeviceGemmInstance =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 2)
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
......
......@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
#define RANGE_HDKO 0 // 0~2
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -75,7 +75,7 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
#if(RANGE_HDKO == 0)
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
......@@ -145,7 +145,7 @@ using DeviceGemmInstance =
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 1)
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
......@@ -215,7 +215,7 @@ using DeviceGemmInstance =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 2)
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
......
......@@ -11,14 +11,8 @@ int run(int argc, char* argv[])
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 1000; // 120
ck::index_t N = 1000; // 1000
#if(RANGE_HDKO == 0)
ck::index_t K = 32; // K/O<=32
#elif(RANGE_HDKO == 1)
ck::index_t K = 64; // 32<K/O<=64
#elif(RANGE_HDKO == 2)
ck::index_t K = 72; // 64<K/O<=128
#endif
ck::index_t O = K;
ck::index_t K = DIM;
ck::index_t O = DIM;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
......
......@@ -83,16 +83,10 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++)
{
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
#if(RANGE_HDKO == 0)
int K = 32; // K/O<=32
#elif(RANGE_HDKO == 1)
int K = 56; // 32<K/O<=64
#elif(RANGE_HDKO == 2)
int K = 80; // 64<K/O<=128
#endif
int O = K;
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1;
......
......@@ -39,7 +39,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_pt1(
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
......@@ -909,7 +909,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_pt1<
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment