Commit 7e3a5613 authored by Jing Zhang
Browse files

clean up

parent 50530c17
...@@ -90,10 +90,10 @@ struct ExecutionConfig final ...@@ -90,10 +90,10 @@ struct ExecutionConfig final
bool time_kernel = true; bool time_kernel = true;
}; };
#define DefaultConvParam \ #define DefaultConvParam \
ck::utils::conv::ConvParam \ ck::utils::conv::ConvParam \
{ \ { \
2, 32, 2, 32, 32, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \ 2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
} }
inline void print_help_msg() inline void print_help_msg()
......
...@@ -90,10 +90,10 @@ struct ExecutionConfig final ...@@ -90,10 +90,10 @@ struct ExecutionConfig final
bool time_kernel = true; bool time_kernel = true;
}; };
#define DefaultConvParam \ #define DefaultConvParam \
ck::utils::conv::ConvParam \ ck::utils::conv::ConvParam \
{ \ { \
2, 32, 2, 32, 32, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \ 2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
} }
inline void print_help_msg() inline void print_help_msg()
......
...@@ -9,10 +9,10 @@ int run(int argc, char* argv[]) ...@@ -9,10 +9,10 @@ int run(int argc, char* argv[])
// GEMM shape for A/B0/B1/C // GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 60; ck::index_t M = 120;
ck::index_t N = 100; ck::index_t N = 1000;
ck::index_t K = 64; ck::index_t K = 64;
ck::index_t O = 64; ck::index_t O = 128;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
......
...@@ -194,9 +194,6 @@ struct BlockwiseGemmWMMA ...@@ -194,9 +194,6 @@ struct BlockwiseGemmWMMA
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0, NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!"); "wrong!");
// static_assert(AEnableLds == true, "only support EnableLds");
// static_assert(BEnableLds == true, "only support EnableLds");
} }
// transposed WMMA output C' = B' * A' // transposed WMMA output C' = B' * A'
......
...@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = true; static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = true; static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
......
...@@ -562,7 +562,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -562,7 +562,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_manu = true; static constexpr auto AEnableLds_manu = false;
static constexpr auto B0EnableLds_manu = true; static constexpr auto B0EnableLds_manu = true;
static constexpr auto B1EnableLds_manu = true; static constexpr auto B1EnableLds_manu = true;
......
...@@ -300,7 +300,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma ...@@ -300,7 +300,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_manu = true; static constexpr auto AEnableLds_manu = false;
static constexpr auto B0EnableLds_manu = true; static constexpr auto B0EnableLds_manu = true;
static constexpr auto B1EnableLds_manu = true; static constexpr auto B1EnableLds_manu = true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment