Commit a465a936 authored by ltqin

add k=64 config

parent d7244818
@@ -25,6 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 1
+#define USING_K128 1
 #include <iostream>
 #include <numeric>
@@ -87,6 +88,7 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+#if USING_K128
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
         NumDimG,
@@ -153,6 +155,73 @@ using DeviceGemmInstance =
         8,            // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>; // MaskingSpecialization
+#else
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
+        NumDimG,
+        NumDimM,
+        NumDimN,
+        NumDimK,
+        NumDimO,
+        DataType,
+        ZDataType,
+        LSEDataType,
+        Acc0BiasDataType,
+        Acc1BiasDataType,
+        AccDataType,
+        ShuffleDataType,
+        QKVElementOp,
+        QKVElementOp,
+        Scale,
+        QKVElementOp,
+        YElementOp,
+        GemmSpec,
+        TensorSpecQ,
+        TensorSpecK,
+        TensorSpecV,
+        TensorSpecY,
+        1,
+        256, // BlockSize
+        128, // MPerBlock
+        128, // NPerBlock
+        32,  // KPerBlock
+        64,  // Gemm1NPerBlock
+        32,  // Gemm1KPerBlock
+        8,   // AK1
+        8,   // BK1
+        2,   // B1K1
+        32,  // MPerXDL
+        32,  // NPerXDL
+        1,   // MXdlPerWave
+        4,   // NXdlPerWave
+        2,   // Gemm1NXdlPerWave
+        S<4, 64, 1>, // ABlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<4, 64, 1>, // BBlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<8, 32, 1>, // B1BlockTransfer
+        S<0, 2, 1>,
+        S<0, 2, 1>,
+        1,
+        4,
+        2,
+        false,
+        1, // CShuffleMXdlPerWavePerShuffle
+        2, // CShuffleNXdlPerWavePerShuffle
+        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+        MaskingSpec>;   // MaskingSpecialization
+#endif
 // Ref Gemm0: S = alpha * Q * K^T
 // fp16 in, fp32 out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
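A quick consistency check on the new K=64 instance (not part of the commit): CK's transfer thread clusters and XDL tile parameters have to agree with the block size and the block tile. The standalone sketch below recomputes those products from the values visible in the #else branch, under the usual 64-lane-wave convention and an assumed wave layout of 4 waves in M and 1 in N; the authoritative checks live in CK's own static_asserts and IsSupportedArgument, not here.

#include <cassert>
#include <cstdio>

int main()
{
    const int BlockSize = 256;

    // Transfer thread clusters must account for every thread in the block.
    assert(4 * 64 * 1 == BlockSize); // ABlockTransfer / BBlockTransfer: S<4, 64, 1>
    assert(8 * 32 * 1 == BlockSize); // B1BlockTransfer: S<8, 32, 1>

    // Block tile = XdlPerWave * PerXDL * waves (assumed: 4 waves in M, 1 in N).
    assert(1 * 32 * 4 == 128); // MXdlPerWave * MPerXDL * MWaves == MPerBlock
    assert(4 * 32 * 1 == 128); // NXdlPerWave * NPerXDL * NWaves == NPerBlock
    assert(2 * 32 * 1 == 64);  // Gemm1NXdlPerWave * NPerXDL    == Gemm1NPerBlock

    std::printf("K=64 tile parameters are internally consistent\n");
    return 0;
}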
@@ -222,14 +291,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     ref_gemm0_invoker.Run(ref_gemm0_argument);
     // masking
 #if USING_MASK
     auto N          = s_g_m_n.GetLengths()[2];
     const auto mask = DeviceGemmInstance::C0MatrixMask(N);
     s_g_m_n.ForEach([&](auto& self, auto idx) {
         if(mask.IsMaskedElement(idx[1], idx[2]))
             self(idx) = -ck::NumericLimits<float>::Infinity();
     });
 #endif
     // P = Softmax(S)
     auto ref_softmax = ReferenceSoftmaxInstance{};
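For reference (not in the diff): the masking block above writes -infinity into the masked entries of S before the softmax, so those positions receive zero attention weight. A minimal host-side sketch of the same idea, assuming the MaskingSpec masks out the upper triangle (causal attention), i.e. IsMaskedElement(m, n) is true for n > m; apply_causal_mask is a hypothetical helper, not a CK API:

#include <cstdio>
#include <limits>
#include <vector>

// Hypothetical helper: writes -inf into the upper triangle of an M x N
// row-major score matrix so softmax assigns those positions zero weight.
void apply_causal_mask(std::vector<float>& s, int M, int N)
{
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for(int m = 0; m < M; ++m)
        for(int n = m + 1; n < N; ++n) // n > m: "future" positions are masked
            s[m * N + n] = neg_inf;
}

int main()
{
    std::vector<float> s(4 * 4, 1.f);
    apply_causal_mask(s, 4, 4);
    for(int m = 0; m < 4; ++m, std::printf("\n"))
        for(int n = 0; n < 4; ++n)
            std::printf("%6.1f", s[m * 4 + n]);
    return 0;
}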
@@ -266,10 +333,15 @@ int run(int argc, char* argv[])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
     ck::index_t M = 512;
     ck::index_t N = 512;
+#if USING_K128
     ck::index_t K = 128;
     ck::index_t O = 128;
-    ck::index_t G0 = 3;
-    ck::index_t G1 = 2;
+#else
+    ck::index_t K = 64;
+    ck::index_t O = 64;
+#endif
+    ck::index_t G0 = 54;
+    ck::index_t G1 = 16;
     float alpha = 1.f / std::sqrt(K);
......
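One consequence of the k=64 config, worked out here rather than stated in the commit: the last line above uses the standard scaled-dot-product-attention scale, alpha = 1/sqrt(K), so changing the head dimension also changes the softmax scale. A small sketch that evaluates it for both configs:

#include <cmath>
#include <cstdio>
#include <initializer_list>

int main()
{
    // alpha = 1/sqrt(K) for the two head dimensions selected by USING_K128.
    for(int K : {64, 128})
        std::printf("K = %3d -> alpha = %.4f\n", K, 1.f / std::sqrt(static_cast<float>(K)));
    // Prints: K =  64 -> alpha = 0.1250
    //         K = 128 -> alpha = 0.0884
    return 0;
}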