Commit 27d764eb authored by ltqin's avatar ltqin
Browse files

Merge branch 'attn-bwd-develop' into attn-bwd-bf16-rtz

parents 022ce136 55057f09
...@@ -7,8 +7,8 @@ add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_pe ...@@ -7,8 +7,8 @@ add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_pe
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp) add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp) add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1 batched_multihead_attention_backward_pt1.cpp) add_example_executable(example_grouped_multihead_attention_backward grouped_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt2 batched_multihead_attention_backward_pt2.cpp) add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp) add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
add_custom_target(example_gemm_scale_softmax_gemm) add_custom_target(example_gemm_scale_softmax_gemm)
......
...@@ -24,8 +24,8 @@ Kernel outputs: ...@@ -24,8 +24,8 @@ Kernel outputs:
*/ */
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 1 #define USING_MASK 0
#define USING_K128 1 #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -36,7 +36,8 @@ Kernel outputs: ...@@ -36,7 +36,8 @@ Kernel outputs:
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -90,9 +91,81 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali ...@@ -90,9 +91,81 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
#if USING_K128 // DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
// If 32 < DIM <= 64 , ues prototype1 2nd template.
// If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
32, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -121,7 +194,7 @@ using DeviceGemmInstance = ...@@ -121,7 +194,7 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
64, // KPerBlock 64, // KPerBlock
128, // Gemm1NPerBlock 64, // Gemm1NPerBlock
32, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -130,7 +203,7 @@ using DeviceGemmInstance = ...@@ -130,7 +203,7 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave 2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
...@@ -154,14 +227,81 @@ using DeviceGemmInstance = ...@@ -154,14 +227,81 @@ using DeviceGemmInstance =
2, 2,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#else // using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
// NumDimK,
// NumDimO,
// DataType,
// GemmDataType,
// ZDataType,
// LSEDataType,
// Acc0BiasDataType,
// Acc1BiasDataType,
// AccDataType,
// ShuffleDataType,
// QKVElementOp,
// QKVElementOp,
// Scale,
// QKVElementOp,
// YElementOp,
// GemmSpec,
// TensorSpecQ,
// TensorSpecK,
// TensorSpecV,
// TensorSpecY,
// 1,
// 256,
// 128, // MPerBlock
// 128, // NPerBlock
// 64, // KPerBlock
// 64, // Gemm1NPerBlock
// 64, // Gemm1KPerBlock
// 8, // AK1
// 8, // BK1
// 2, // B1K1
// 32, // MPerXDL
// 32, // NPerXDL
// 1, // MXdlPerWave
// 4, // NXdlPerWave
// 2, // Gemm1NXdlPerWave
// 2, // Gemm2NXdlPerWave
// S<4, 64, 1>, // ABlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<4, 64, 1>, // BBlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<8, 32, 1>, // B1BlockTransfer
// S<0, 2, 1>,
// S<0, 2, 1>,
// 1,
// 2,
// 2,
// false,
// 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -190,8 +330,8 @@ using DeviceGemmInstance = ...@@ -190,8 +330,8 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
64, // KPerBlock 64, // KPerBlock
64, // Gemm1NPerBlock 128, // Gemm1NPerBlock
64, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
2, // B1K1 2, // B1K1
...@@ -199,7 +339,7 @@ using DeviceGemmInstance = ...@@ -199,7 +339,7 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 4, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave 2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
...@@ -219,15 +359,16 @@ using DeviceGemmInstance = ...@@ -219,15 +359,16 @@ using DeviceGemmInstance =
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
2, 4,
2, 2,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle 4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#endif #endif
// Ref Gemm0: S = alpha * Q * K^T // Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out // fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
...@@ -337,27 +478,17 @@ int run(int argc, char* argv[]) ...@@ -337,27 +478,17 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512; ck::index_t M = 512;
ck::index_t N = 512; ck::index_t N = 512;
#if USING_K128 ck::index_t K = DIM;
ck::index_t K = 128; ck::index_t O = DIM;
ck::index_t O = 128; ck::index_t G0 = 54;
#else ck::index_t G1 = 16;
ck::index_t K = 64;
ck::index_t O = 64;
#endif
ck::index_t G0 = 3;
ck::index_t G1 = 2;
float alpha = 1.f / std::sqrt(K);
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
float p_drop = 0.2; float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -384,12 +515,10 @@ int run(int argc, char* argv[]) ...@@ -384,12 +515,10 @@ int run(int argc, char* argv[])
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]); p_drop = std::stof(argv[10]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[12]);
p_drop = std::stoi(argv[13]);
} }
else else
{ {
...@@ -402,6 +531,11 @@ int run(int argc, char* argv[]) ...@@ -402,6 +531,11 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::cout << "do_verification: " << do_verification << std::endl; std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl; std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl; std::cout << "time_kernel: " << time_kernel << std::endl;
...@@ -536,7 +670,6 @@ int run(int argc, char* argv[]) ...@@ -536,7 +670,6 @@ int run(int argc, char* argv[])
// = 0 // = 0
} }
// calculate y & log-sum-exp beforehand
Tensor<DataType> q_g_m_k({BatchCount, M, K}); Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K}); Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N}); Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
......
...@@ -9,6 +9,8 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -9,6 +9,8 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1 Gemm1
*/ */
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list> #include <initializer_list>
...@@ -73,6 +75,77 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial ...@@ -73,6 +75,77 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
32, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
...@@ -142,6 +215,77 @@ using DeviceGemmInstance = ...@@ -142,6 +215,77 @@ using DeviceGemmInstance =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#endif
// Ref Gemm0: DataType in, AccDataType out // Ref Gemm0: DataType in, AccDataType out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
...@@ -32,7 +32,7 @@ Kernel outputs: ...@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define USING_HD32 0 #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -43,8 +43,9 @@ Kernel outputs: ...@@ -43,8 +43,9 @@ Kernel outputs:
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -99,6 +100,11 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali ...@@ -99,6 +100,11 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
// If 32 < DIM <= 64 , ues prototype1 2nd template.
// If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
...@@ -132,7 +138,7 @@ using DeviceGemmInstanceFWD = ...@@ -132,7 +138,7 @@ using DeviceGemmInstanceFWD =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
64, // Gemm1NPerBlock 32, // Gemm1NPerBlock
32, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -141,7 +147,7 @@ using DeviceGemmInstanceFWD = ...@@ -141,7 +147,7 @@ using DeviceGemmInstanceFWD =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -160,23 +166,17 @@ using DeviceGemmInstanceFWD = ...@@ -160,23 +166,17 @@ using DeviceGemmInstanceFWD =
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
4, 2,
2, 2,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
// Headdim/K/O should be a multiple of 8, and it's only supported up to 64 in prototype1.
// If Headdim/K/O <= 32, ues 1st template.
// If 32 < Headdim/K/O <= 64, ues 2nd template.
#if USING_HD32
// 1st template
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -242,10 +242,79 @@ using DeviceGemmInstanceBWD = ...@@ -242,10 +242,79 @@ using DeviceGemmInstanceBWD =
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#else #elif(DIM <= 64)
// 2nd template using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
DataType,
DataType,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -311,6 +380,212 @@ using DeviceGemmInstanceBWD = ...@@ -311,6 +380,212 @@ using DeviceGemmInstanceBWD =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
// NumDimK,
// NumDimO,
// DataType,
// GemmDataType,
// ZDataType,
// LSEDataType,
// Acc0BiasDataType,
// Acc1BiasDataType,
// AccDataType,
// ShuffleDataType,
// QKVElementOp,
// QKVElementOp,
// Scale,
// QKVElementOp,
// YElementOp,
// GemmSpec,
// TensorSpecQ,
// TensorSpecK,
// TensorSpecV,
// TensorSpecY,
// 1,
// 256,
// 128, // MPerBlock
// 128, // NPerBlock
// 64, // KPerBlock
// 64, // Gemm1NPerBlock
// 64, // Gemm1KPerBlock
// 8, // AK1
// 8, // BK1
// 2, // B1K1
// 32, // MPerXDL
// 32, // NPerXDL
// 1, // MXdlPerWave
// 4, // NXdlPerWave
// 2, // Gemm1NXdlPerWave
// 2, // Gemm2NXdlPerWave
// S<4, 64, 1>, // ABlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<4, 64, 1>, // BBlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<8, 32, 1>, // B1BlockTransfer
// S<0, 2, 1>,
// S<0, 2, 1>,
// 1,
// 2,
// 2,
// false,
// 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
DataType,
DataType,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#endif #endif
// Ref Gemm0: S = alpha * Q * K^T // Ref Gemm0: S = alpha * Q * K^T
...@@ -382,14 +657,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -382,14 +657,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
#if USING_MASK
auto N = s_g_m_n.GetLengths()[2]; auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstanceFWD::C0MatrixMask(N); const auto mask = DeviceGemmInstanceFWD::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) { s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
}); });
#endif
// P = Softmax(S) // P = Softmax(S)
auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax = ReferenceSoftmaxInstance{};
...@@ -424,22 +697,17 @@ int run(int argc, char* argv[]) ...@@ -424,22 +697,17 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 129; // 512 ck::index_t M = 512; // 512
ck::index_t N = 129; // 512 ck::index_t N = 512; // 512
ck::index_t K = 64; ck::index_t K = DIM;
ck::index_t O = 64; ck::index_t O = DIM;
ck::index_t G0 = 4; // 54 ck::index_t G0 = 4; // 54
ck::index_t G1 = 6; // 16 ck::index_t G1 = 6; // 16
float alpha = 1.f / std::sqrt(K); bool input_permute = false;
bool output_permute = false;
bool input_permute = true; float p_drop = 0.2;
bool output_permute = true;
float p_drop = 0.0;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -466,12 +734,10 @@ int run(int argc, char* argv[]) ...@@ -466,12 +734,10 @@ int run(int argc, char* argv[])
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]); p_drop = std::stof(argv[10]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[12]);
p_drop = std::stoi(argv[13]);
} }
else else
{ {
...@@ -484,6 +750,11 @@ int run(int argc, char* argv[]) ...@@ -484,6 +750,11 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::cout << "do_verification: " << do_verification << std::endl; std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl; std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl; std::cout << "time_kernel: " << time_kernel << std::endl;
......
...@@ -9,6 +9,8 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -9,6 +9,8 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1 Gemm1
*/ */
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list> #include <initializer_list>
...@@ -73,6 +75,77 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial ...@@ -73,6 +75,77 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
32, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
...@@ -142,6 +215,77 @@ using DeviceGemmInstance = ...@@ -142,6 +215,77 @@ using DeviceGemmInstance =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#endif
// Ref Gemm0: DataType in, AccDataType out // Ref Gemm0: DataType in, AccDataType out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
...@@ -11,8 +11,8 @@ int run(int argc, char* argv[]) ...@@ -11,8 +11,8 @@ int run(int argc, char* argv[])
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 1000; // 120 ck::index_t M = 1000; // 120
ck::index_t N = 1000; // 1000 ck::index_t N = 1000; // 1000
ck::index_t K = 64; ck::index_t K = DIM;
ck::index_t O = 64; ck::index_t O = DIM;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
......
...@@ -10,10 +10,7 @@ int run(int argc, char* argv[]) ...@@ -10,10 +10,7 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
float p_drop = 0.1; float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -27,14 +24,15 @@ int run(int argc, char* argv[]) ...@@ -27,14 +24,15 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 6) else if(argc == 7)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
input_permute = std::stoi(argv[4]); p_drop = std::stoi(argv[4]);
output_permute = std::stoi(argv[5]); input_permute = std::stoi(argv[5]);
output_permute = std::stoi(argv[6]);
} }
else else
{ {
...@@ -45,6 +43,10 @@ int run(int argc, char* argv[]) ...@@ -45,6 +43,10 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1; // scaling after 1st gemm float alpha = 1; // scaling after 1st gemm
std::size_t group_count = 8; std::size_t group_count = 8;
...@@ -81,10 +83,10 @@ int run(int argc, char* argv[]) ...@@ -81,10 +83,10 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int M = 128 * (rand() % 8 + 1); int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8 + 1); int N = 128 * (rand() % 8) + (rand() % 128);
int K = 64; int K = DIM;
int O = 64; int O = DIM;
int G0 = rand() % 3 + 1; int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1; int G1 = rand() % 5 + 1;
......
...@@ -50,10 +50,9 @@ template <typename GridwiseGemm, ...@@ -50,10 +50,9 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1( kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
const DataType* __restrict__ p_a_grid, const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid, const DataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
...@@ -233,7 +232,7 @@ template <index_t NumDimG, ...@@ -233,7 +232,7 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -255,7 +254,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -255,7 +254,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimN;
#endif #endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -597,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -597,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype DataType, // TODO: distinguish A/B datatype
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
...@@ -900,7 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -900,7 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1< const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
DataType, DataType,
ZDataType, ZDataType,
...@@ -1231,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1231,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1" str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -231,7 +231,7 @@ template <index_t NumDimG, ...@@ -231,7 +231,7 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -253,7 +253,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -253,7 +253,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimN;
#endif #endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -1230,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1230,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle" str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -413,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -413,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
......
...@@ -424,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -424,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
......
...@@ -85,7 +85,7 @@ template <typename DataType, ...@@ -85,7 +85,7 @@ template <typename DataType,
bool PadN, bool PadN,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
namespace ck { namespace ck {
template <typename FloatAB, template <typename FloatAB,
typename ZDataType,
typename FloatGemm, typename FloatGemm,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
...@@ -274,11 +275,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -274,11 +275,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
if(Gemm1N != K) // if(Gemm1N != K)
{ // {
std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n'; // std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false; // return false;
} // }
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1))) if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{ {
...@@ -424,7 +425,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -424,7 +425,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid, const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
unsigned short* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid, FloatLSE* __restrict__ p_lse_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
...@@ -876,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -876,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort, ushort,
ushort, ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -891,8 +892,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -891,8 +892,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3, // NInputNum n3, // NInputNum
n4>, n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim 9, // DstVectorDim
n4, // DstScalarPerVector 1, // DstScalarPerVector
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment