Commit dded3450 authored by danyao12's avatar danyao12
Browse files

prototype2 qloop direction w/ layout change

parent 59088eca
......@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -37,7 +37,7 @@ Kernel outputs:
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -95,7 +95,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
static constexpr bool Deterministic = false;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
......@@ -298,6 +298,75 @@ using DeviceGemmInstance =
// MaskingSpec,
// Deterministic>;
#elif(DIM <= 128)
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
// NumDimK,
// NumDimO,
// InputDataType,
// OutputDataType,
// GemmDataType,
// ZDataType,
// LSEDataType,
// Acc0BiasDataType,
// Acc1BiasDataType,
// AccDataType,
// ShuffleDataType,
// QKVElementOp,
// QKVElementOp,
// Scale,
// QKVElementOp,
// YElementOp,
// GemmSpec,
// TensorSpecQ,
// TensorSpecK,
// TensorSpecV,
// TensorSpecY,
// 1,
// 256,
// 128, // MPerBlock
// 128, // NPerBlock
// 64, // KPerBlock
// 128, // Gemm1NPerBlock
// 32, // Gemm1KPerBlock
// 8, // AK1
// 8, // BK1
// 2, // B1K1
// 32, // MPerXDL
// 32, // NPerXDL
// 4, // MXdlPerWave
// 1, // NXdlPerWave
// 4, // Gemm1NXdlPerWave
// 2, // Gemm2NXdlPerWave
// S<4, 64, 1>, // ABlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<4, 64, 1>, // BBlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<8, 32, 1>, // B1BlockTransfer
// S<0, 2, 1>,
// S<0, 2, 1>,
// 1,
// 4,
// 2,
// false,
// 1, // CShuffleMXdlPerWavePerShuffle
// 4, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// CShuffleBlockTransferScalarPerVector_NPerBlock, //
// CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec, // MaskingSpecialization
// Deterministic>;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
......@@ -326,7 +395,7 @@ using DeviceGemmInstance =
TensorSpecY,
1,
256,
128, // MPerBlock
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
128, // Gemm1NPerBlock
......@@ -336,10 +405,10 @@ using DeviceGemmInstance =
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // MXdlPerWave
1, // NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
1, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment