Commit 83b53ec8 authored by fsx950223

add pt1 backward

parent b68a890d
@@ -24,7 +24,7 @@ Kernel outputs:
 */
 #define PRINT_HOST 0
-#define USING_MASK 1
+#define USING_MASK 0
 #include <iostream>
 #include <numeric>
@@ -35,6 +35,7 @@ Kernel outputs:
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_pt1.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -45,7 +46,8 @@ Kernel outputs:
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
-#define FLASH_ATTN_IMPLENTATION 0
+#define DIM 32
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -88,15 +90,16 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
-#if FLASH_ATTN_IMPLENTATION
+#if DIM == 32
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
         NumDimG,
         NumDimM,
         NumDimN,
         NumDimK,
         NumDimO,
         DataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
@@ -117,8 +120,8 @@ using DeviceGemmInstance =
         128,        // MPerBlock
         128,        // NPerBlock
         32,         // KPerBlock
-        128,        // Gemm1NPerBlock
-        64,         // Gemm1KPerBlock
+        32,         // Gemm1NPerBlock
+        32,         // Gemm1KPerBlock
         8,          // AK1
         8,          // BK1
         2,          // B1K1
@@ -126,7 +129,8 @@ using DeviceGemmInstance =
         32,         // NPerXDL
         1,          // MXdlPerWave
         4,          // NXdlPerWave
-        4,          // Gemm1NXdlPerWave
+        1,          // Gemm1NXdlPerWave
+        1,          // Gemm2NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
@@ -149,13 +153,80 @@ using DeviceGemmInstance =
         2,
         false,
         1,              // CShuffleMXdlPerWavePerShuffle
-        4,              // CShuffleNXdlPerWavePerShuffle
+        1,              // CShuffleNXdlPerWavePerShuffle
+        S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+        MaskingSpec>;   // MaskingSpecialization
+#elif DIM == 64
+// 2nd template
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+        NumDimG,
+        NumDimM,
+        NumDimN,
+        NumDimK,
+        NumDimO,
+        DataType,
+        ZDataType,
+        LSEDataType,
+        Acc0BiasDataType,
+        Acc1BiasDataType,
+        AccDataType,
+        ShuffleDataType,
+        QKVElementOp,
+        QKVElementOp,
+        Scale,
+        QKVElementOp,
+        YElementOp,
+        GemmSpec,
+        TensorSpecQ,
+        TensorSpecK,
+        TensorSpecV,
+        TensorSpecY,
+        1,
+        256,
+        128,        // MPerBlock
+        128,        // NPerBlock
+        64,         // KPerBlock
+        64,         // Gemm1NPerBlock
+        32,         // Gemm1KPerBlock
+        8,          // AK1
+        8,          // BK1
+        2,          // B1K1
+        32,         // MPerXDL
+        32,         // NPerXDL
+        1,          // MXdlPerWave
+        4,          // NXdlPerWave
+        2,          // Gemm1NXdlPerWave
+        2,          // Gemm2NXdlPerWave
+        S<4, 64, 1>, // ABlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<4, 64, 1>, // BBlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<8, 32, 1>, // B1BlockTransfer
+        S<0, 2, 1>,
+        S<0, 2, 1>,
+        1,
+        4,
+        2,
+        false,
+        1,              // CShuffleMXdlPerWavePerShuffle
+        2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>;   // MaskingSpecialization
 #else
-using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
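In both PT1 instances the second-GEMM tile tracks the head dimension (Gemm1NPerBlock == DIM), so a single block tile spans the whole output width. A rough tiling sanity check using the DIM == 64 numbers above (my own static_asserts, not CK code):

// A 256-thread block holds 4 waves of 64 lanes; each wave computes
// MXdlPerWave x NXdlPerWave tiles of size MPerXDL x NPerXDL.
constexpr int BlockSize   = 256;
constexpr int WaveSize    = 64;
constexpr int MPerBlock   = 128;
constexpr int NPerBlock   = 128;
constexpr int MPerXDL     = 32;
constexpr int NPerXDL     = 32;
constexpr int MXdlPerWave = 1;
constexpr int NXdlPerWave = 4;
static_assert(MPerBlock % (MXdlPerWave * MPerXDL) == 0, "M tiling must divide MPerBlock");
static_assert(NPerBlock % (NXdlPerWave * NPerXDL) == 0, "N tiling must divide NPerBlock");
static_assert((MPerBlock / (MXdlPerWave * MPerXDL)) *
                  (NPerBlock / (NXdlPerWave * NPerXDL)) ==
                  BlockSize / WaveSize,
              "wave grid must cover the block tile");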
@@ -331,8 +402,7 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    float K     = 128;
-    float alpha = 1.f / std::sqrt(K);
+    float alpha = 1.f / std::sqrt(DIM);
     float p_drop = 0.2;
     float p_dropout = 1 - p_drop;
     uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
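For concreteness, the constants this hunk produces with DIM == 32, as a standalone sketch mirroring the lines above (the printed values are plain arithmetic):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float alpha     = 1.f / std::sqrt(32.f); // softmax scale ~= 0.176777
    float p_drop    = 0.2f;
    float p_dropout = 1 - p_drop;            // keep probability = 0.8
    uint16_t p_dropout_in_16bits =
        uint16_t(std::floor(p_dropout * 65535.0)); // 52428: threshold for 16-bit RNG draws
    std::printf("alpha=%f threshold=%u\n", alpha, p_dropout_in_16bits);
    return 0;
}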
@@ -341,7 +411,7 @@ int run(int argc, char* argv[])
     const unsigned long long offset = 0;
     bool input_permute  = false;
-    bool output_permute = false;
+    bool output_permute = true;
     if(argc == 1)
     {
@@ -428,8 +498,8 @@ int run(int argc, char* argv[])
     {
         int M  = 128 * (rand() % 4 + 1);
         int N  = 128 * (rand() % 4 + 1);
-        int K  = 128;
-        int O  = 128;
+        int K  = DIM;
+        int O  = DIM;
         int G0 = rand() % 3 + 1;
         int G1 = rand() % 2 + 1;
         std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
...
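Since output_permute is now true, y is produced in the permuted [G0, M, G1, O] layout described by the earlier comment (y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])). A sketch of how the CK attention examples typically express this via strides, in the same scope as q_gs_ms_ks_lengths above (the y_* variable names are my assumption, not visible in this diff):

// Lengths always describe [G0, G1, M, O]; the memory layout lives in the strides.
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
    output_permute
        ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1}  // y stored as [G0, M, G1, O]
        : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1};  // y stored as [G0, G1, M, O]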