Commit ac407086 authored by ltqin's avatar ltqin
Browse files

change k=64 config

parent a465a936
......@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 1
#define USING_K128 1
#define USING_K128 0
#include <iostream>
#include <numeric>
......@@ -213,7 +213,7 @@ using DeviceGemmInstance =
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
......@@ -340,8 +340,8 @@ int run(int argc, char* argv[])
ck::index_t K = 64;
ck::index_t O = 64;
#endif
ck::index_t G0 = 54;
ck::index_t G1 = 16;
ck::index_t G0 = 3;
ck::index_t G1 = 2;
float alpha = 1.f / std::sqrt(K);
......
......@@ -49,7 +49,7 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
const DataType* __restrict__ p_a_grid,
......
......@@ -90,7 +90,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
using type = T;
};
#if defined(__gfx90a__)
#if defined(__gfx90a_masking__)
template <>
struct TypeMap<ck::half_t>
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment