Commit 51ec5aa0 authored by danyao12
Browse files

modify argc and macro

parent 80ef43a2
...@@ -25,7 +25,7 @@ Kernel outputs: ...@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define RANGE_HDKO 1 // 0~2 #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -91,11 +91,11 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali ...@@ -91,11 +91,11 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
// Headdim/K/O should be a multiple of 8. // DIM should be a multiple of 8.
// If Headdim/K/O <= 32 , use prototype1 1st template. // If DIM <= 32 , use prototype1 1st template.
// If 32 < Headdim/K/O <= 64 , use prototype1 2nd template. // If 32 < DIM <= 64 , use prototype1 2nd template.
// If 64 < Headdim/K/O <= 128, use prototype2 2nd template. // If 64 < DIM <= 128, use prototype2 2nd template.
#if(RANGE_HDKO == 0) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
...@@ -163,7 +163,7 @@ using DeviceGemmInstance = ...@@ -163,7 +163,7 @@ using DeviceGemmInstance =
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 1) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
...@@ -299,7 +299,7 @@ using DeviceGemmInstance = ...@@ -299,7 +299,7 @@ using DeviceGemmInstance =
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock // S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock // 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization // MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 2) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
...@@ -478,21 +478,13 @@ int run(int argc, char* argv[]) ...@@ -478,21 +478,13 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512; ck::index_t M = 512;
ck::index_t N = 512; ck::index_t N = 512;
#if(RANGE_HDKO == 0) ck::index_t K = DIM;
ck::index_t K = 32; // K/O<=32 ck::index_t O = DIM;
#elif(RANGE_HDKO == 1)
ck::index_t K = 64; // 32<K/O<=64
#elif(RANGE_HDKO == 2)
ck::index_t K = 80; // 64<K/O<=128
#endif
ck::index_t O = K;
ck::index_t G0 = 54; ck::index_t G0 = 54;
ck::index_t G1 = 16; ck::index_t G1 = 16;
float alpha = 1.f / std::sqrt(K);
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
...@@ -510,7 +502,7 @@ int run(int argc, char* argv[]) ...@@ -510,7 +502,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 14) else if(argc == 13)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -523,11 +515,10 @@ int run(int argc, char* argv[]) ...@@ -523,11 +515,10 @@ int run(int argc, char* argv[])
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]); p_drop = std::stof(argv[10]);
p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[12]); input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[13]); output_permute = std::stoi(argv[12]);
} }
else else
{ {
...@@ -543,6 +534,7 @@ int run(int argc, char* argv[]) ...@@ -543,6 +534,7 @@ int run(int argc, char* argv[])
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::cout << "do_verification: " << do_verification << std::endl; std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl; std::cout << "init_method: " << init_method << std::endl;
......
...@@ -32,7 +32,7 @@ Kernel outputs: ...@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define RANGE_HDKO 0 // 0~2 #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -100,11 +100,11 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali ...@@ -100,11 +100,11 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
// Headdim/K/O should be a multiple of 8. // DIM should be a multiple of 8.
// If Headdim/K/O <= 32 , use bwd prototype1 1st template. // If DIM <= 32 , use prototype1 1st template.
// If 32 < Headdim/K/O <= 64 , use bwd prototype1 2nd template. // If 32 < DIM <= 64 , use prototype1 2nd template.
// If 64 < Headdim/K/O <= 128, use bwd prototype2 2nd template. // If 64 < DIM <= 128, use prototype2 2nd template.
#if(RANGE_HDKO == 0) #if(DIM <= 32)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
...@@ -242,7 +242,7 @@ using DeviceGemmInstanceBWD = ...@@ -242,7 +242,7 @@ using DeviceGemmInstanceBWD =
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 1) #elif(DIM <= 64)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
...@@ -448,7 +448,7 @@ using DeviceGemmInstanceBWD = ...@@ -448,7 +448,7 @@ using DeviceGemmInstanceBWD =
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock // S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock // 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization // MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 2) #elif(DIM <= 128)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG, NumDimG,
...@@ -657,14 +657,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -657,14 +657,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
#if USING_MASK
auto N = s_g_m_n.GetLengths()[2]; auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstanceFWD::C0MatrixMask(N); const auto mask = DeviceGemmInstanceFWD::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) { s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
}); });
#endif
// P = Softmax(S) // P = Softmax(S)
auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax = ReferenceSoftmaxInstance{};
...@@ -699,25 +697,17 @@ int run(int argc, char* argv[]) ...@@ -699,25 +697,17 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512; // 512 ck::index_t M = 512; // 512
ck::index_t N = 512; // 512 ck::index_t N = 512; // 512
#if(RANGE_HDKO == 0) ck::index_t K = DIM;
ck::index_t K = 32; // K/O<=32 ck::index_t O = DIM;
#elif(RANGE_HDKO == 1)
ck::index_t K = 64; // 32<K/O<=64
#elif(RANGE_HDKO == 2)
ck::index_t K = 72; // 64<K/O<=128
#endif
ck::index_t O = K;
ck::index_t G0 = 4; // 54 ck::index_t G0 = 4; // 54
ck::index_t G1 = 6; // 16 ck::index_t G1 = 6; // 16
float alpha = 1.f / std::sqrt(K); bool input_permute = false;
bool output_permute = false;
bool input_permute = true;
bool output_permute = true;
float p_drop = 0.3; float p_drop = 0.2;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -731,7 +721,7 @@ int run(int argc, char* argv[]) ...@@ -731,7 +721,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 14) else if(argc == 13)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -744,11 +734,10 @@ int run(int argc, char* argv[]) ...@@ -744,11 +734,10 @@ int run(int argc, char* argv[])
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]); p_drop = std::stof(argv[10]);
p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[12]); input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[13]); output_permute = std::stoi(argv[12]);
} }
else else
{ {
...@@ -761,9 +750,10 @@ int run(int argc, char* argv[]) ...@@ -761,9 +750,10 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::cout << "do_verification: " << do_verification << std::endl; std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl; std::cout << "init_method: " << init_method << std::endl;
......
...@@ -480,8 +480,8 @@ int run(int argc, char* argv[]) ...@@ -480,8 +480,8 @@ int run(int argc, char* argv[])
float alpha = 1.f / std::sqrt(DIM); float alpha = 1.f / std::sqrt(DIM);
float p_drop = 0.2; float p_drop = 0.2;
bool input_permute = false; bool input_permute = true;
bool output_permute = false; bool output_permute = true;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment