"...composable_kernel_rocm.git" did not exist on "73b67f290f6602fe0461d48a2c103de460f14084"
Commit 55057f09 authored by danyao12's avatar danyao12
Browse files

rename and sync macro

parent 51ec5aa0
...@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1 Gemm1
*/ */
#define RANGE_HDKO 0 // 0~2 #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -75,7 +75,7 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial ...@@ -75,7 +75,7 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
#if(RANGE_HDKO == 0) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
...@@ -145,7 +145,7 @@ using DeviceGemmInstance = ...@@ -145,7 +145,7 @@ using DeviceGemmInstance =
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 1) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
...@@ -215,7 +215,7 @@ using DeviceGemmInstance = ...@@ -215,7 +215,7 @@ using DeviceGemmInstance =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 2) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
......
...@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1 Gemm1
*/ */
#define RANGE_HDKO 0 // 0~2 #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -75,7 +75,7 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial ...@@ -75,7 +75,7 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
#if(RANGE_HDKO == 0) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
...@@ -145,7 +145,7 @@ using DeviceGemmInstance = ...@@ -145,7 +145,7 @@ using DeviceGemmInstance =
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 1) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
...@@ -215,7 +215,7 @@ using DeviceGemmInstance = ...@@ -215,7 +215,7 @@ using DeviceGemmInstance =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 2) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
......
...@@ -11,14 +11,8 @@ int run(int argc, char* argv[]) ...@@ -11,14 +11,8 @@ int run(int argc, char* argv[])
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 1000; // 120 ck::index_t M = 1000; // 120
ck::index_t N = 1000; // 1000 ck::index_t N = 1000; // 1000
#if(RANGE_HDKO == 0) ck::index_t K = DIM;
ck::index_t K = 32; // K/O<=32 ck::index_t O = DIM;
#elif(RANGE_HDKO == 1)
ck::index_t K = 64; // 32<K/O<=64
#elif(RANGE_HDKO == 2)
ck::index_t K = 72; // 64<K/O<=128
#endif
ck::index_t O = K;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
......
...@@ -85,14 +85,8 @@ int run(int argc, char* argv[]) ...@@ -85,14 +85,8 @@ int run(int argc, char* argv[])
int M = 128 * (rand() % 8) + (rand() % 128); int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128); int N = 128 * (rand() % 8) + (rand() % 128);
#if(RANGE_HDKO == 0) int K = DIM;
int K = 32; // K/O<=32 int O = DIM;
#elif(RANGE_HDKO == 1)
int K = 56; // 32<K/O<=64
#elif(RANGE_HDKO == 2)
int K = 80; // 64<K/O<=128
#endif
int O = K;
int G0 = rand() % 3 + 1; int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1; int G1 = rand() % 5 + 1;
......
...@@ -39,7 +39,7 @@ __global__ void ...@@ -39,7 +39,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_pt1( kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
...@@ -909,7 +909,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -909,7 +909,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_pt1< const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
GroupKernelArg, GroupKernelArg,
AElementwiseOperation, AElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment