"...composable_kernel_rocm.git" did not exist on "9aa174066f769dd944ff32f263d06f8679e59a4d"
Commit 7e3a5613 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean up

parent 50530c17
......@@ -90,10 +90,10 @@ struct ExecutionConfig final
bool time_kernel = true;
};
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
2, 32, 2, 32, 32, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
}
inline void print_help_msg()
......
......@@ -90,10 +90,10 @@ struct ExecutionConfig final
bool time_kernel = true;
};
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
2, 32, 2, 32, 32, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
}
inline void print_help_msg()
......
......@@ -9,10 +9,10 @@ int run(int argc, char* argv[])
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 60;
ck::index_t N = 100;
ck::index_t M = 120;
ck::index_t N = 1000;
ck::index_t K = 64;
ck::index_t O = 64;
ck::index_t O = 128;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
......
......@@ -194,9 +194,6 @@ struct BlockwiseGemmWMMA
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
// static_assert(AEnableLds == true, "only support EnableLds");
// static_assert(BEnableLds == true, "only support EnableLds");
}
// transposed WMMA output C' = B' * A'
......
......@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
......
......@@ -562,7 +562,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_manu = true;
static constexpr auto AEnableLds_manu = false;
static constexpr auto B0EnableLds_manu = true;
static constexpr auto B1EnableLds_manu = true;
......
......@@ -300,7 +300,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_manu = true;
static constexpr auto AEnableLds_manu = false;
static constexpr auto B0EnableLds_manu = true;
static constexpr auto B1EnableLds_manu = true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment