Commit c80db13f authored by danyao12

update fwd/train fp16/bf16

parent da80a2e3
@@ -9,7 +9,7 @@ add_example_executable(example_grouped_multihead_attention_forward grouped_multi
add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1 batched_multihead_attention_backward_pt1.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt2 batched_multihead_attention_backward_pt2.cpp)
-add_example_executable(example_batched_multihead_attention_train_fp16 batched_multihead_attention_train_fp16.cpp)
+add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define PRINT_HOST 0
-#define USING_MASK 0
+#define USING_MASK 1
#define USING_HD32 0
#include <iostream>
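Setting `USING_MASK` to 1 switches these examples onto the masked-attention path. A sketch of the masked logits, assuming the usual causal (upper-triangular) mask, where query position $m$ may only attend to key positions $n \le m$:

$$S_{mn} = \begin{cases} \alpha \, Q_m K_n^{\top}, & n \le m \\ -\infty, & n > m \end{cases}$$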
@@ -344,8 +344,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-ck::index_t M = 512; // 512
-ck::index_t N = 512; // 512
+ck::index_t M = 1536; // 512
+ck::index_t N = 1536; // 512
#if USING_HD32
ck::index_t K = 32; // K/O<=32
ck::index_t O = 32;
@@ -353,8 +353,8 @@ int run(int argc, char* argv[])
ck::index_t K = 64; // 32<K/O<=64
ck::index_t O = 64;
#endif
-ck::index_t G0 = 4; // 54
-ck::index_t G1 = 6; // 16
+ck::index_t G0 = 1; // 54
+ck::index_t G1 = 1; // 16
float alpha = 1.f / std::sqrt(K);
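The scale is the standard one for scaled dot-product attention: with head dimension $K$ (often written $d_k$), the examples compute

$$Y = \mathrm{softmax}\big(\alpha \, Q K^{\top}\big)\, V, \qquad \alpha = \frac{1}{\sqrt{d_k}},$$

matching the `y_g_m_o` comment above; the $1/\sqrt{d_k}$ factor keeps the logits' variance near 1 so the softmax does not saturate.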
@@ -395,6 +395,8 @@ int run(int argc, char* argv[])
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
+p_drop = std::stof(argv[13]); // stof: p_drop is a float, std::stoi would truncate it
}
else
{
@@ -407,6 +409,20 @@ int run(int argc, char* argv[])
exit(0);
}
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
std::cout << "M: " << M << std::endl;
std::cout << "N: " << N << std::endl;
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
std::cout << "p_drop: " << p_drop << std::endl;
const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
......
@@ -388,6 +388,8 @@ int run(int argc, char* argv[])
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
+p_drop = std::stof(argv[13]); // stof: p_drop is a float, std::stoi would truncate it
}
else
{
@@ -400,6 +402,20 @@ int run(int argc, char* argv[])
exit(0);
}
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
std::cout << "M: " << M << std::endl;
std::cout << "N: " << N << std::endl;
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
std::cout << "p_drop: " << p_drop << std::endl;
const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
......
@@ -60,6 +60,7 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
using F32 = float;
using U16 = unsigned short;
@@ -69,7 +70,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
-using DataType = F16;
+using DataType = BF16;
+using GemmDataType = BF16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
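A minimal sketch of the intent behind these aliases (the `USE_BF16` switch is hypothetical; this commit simply hard-codes the BF16 branch): the storage type and the GEMM input type move to bf16 together, while accumulation, shuffle, and LSE tensors stay in fp32.

```cpp
#if USE_BF16 // hypothetical toggle, not part of this commit
using DataType     = ck::bhalf_t; // element type of the Q/K/V/Y buffers
using GemmDataType = ck::bhalf_t; // input type fed to the XDL GEMM stages
#else
using DataType     = ck::half_t;
using GemmDataType = ck::half_t;
#endif
using AccDataType     = float; // GEMM/softmax accumulators stay fp32
using ShuffleDataType = float;
using LSEDataType     = float; // log-sum-exp saved for the backward pass
```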
@@ -108,6 +110,7 @@ using DeviceGemmInstanceFWD =
DataType,
DataType,
DataType,
+GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
@@ -180,6 +183,7 @@ using DeviceGemmInstanceBWD =
NumDimK,
NumDimO,
DataType,
+GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
@@ -248,6 +252,7 @@ using DeviceGemmInstanceBWD =
NumDimK,
NumDimO,
DataType,
+GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
@@ -419,8 +424,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-ck::index_t M = 200; // 512
-ck::index_t N = 200; // 512
+ck::index_t M = 129; // 512
+ck::index_t N = 129; // 512
ck::index_t K = 64;
ck::index_t O = 64;
ck::index_t G0 = 4; // 54
@@ -428,8 +433,8 @@ int run(int argc, char* argv[])
float alpha = 1.f / std::sqrt(K);
-bool input_permute = false;
-bool output_permute = false;
+bool input_permute = true;
+bool output_permute = true;
float p_drop = 0.0;
float p_dropout = 1 - p_drop;
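`p_drop` is the probability of zeroing a softmax element and `p_dropout = 1 - p_drop` its keep probability. Assuming standard inverted dropout, kept probabilities are rescaled by `1 / p_dropout` so the expected value is unchanged:

$$\tilde P_{mn} = \frac{B_{mn}}{1 - p_{\text{drop}}}\, P_{mn}, \qquad B_{mn} \sim \mathrm{Bernoulli}(1 - p_{\text{drop}}), \qquad \mathbb{E}\big[\tilde P_{mn}\big] = P_{mn}.$$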
@@ -465,6 +470,8 @@ int run(int argc, char* argv[])
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
+p_drop = std::stof(argv[13]); // stof: p_drop is a float, std::stoi would truncate it
}
else
{
@@ -477,6 +484,20 @@ int run(int argc, char* argv[])
exit(0);
}
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
std::cout << "M: " << M << std::endl;
std::cout << "N: " << N << std::endl;
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
std::cout << "p_drop: " << p_drop << std::endl;
const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
@@ -959,9 +980,12 @@ int run(int argc, char* argv[])
{
auto idx_gmo = idx_gmn;
idx_gmo[2] = o;
-ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
+ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
+               ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
}
-self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
+self(idx_gmn) = ck::type_convert<DataType>(
+    ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
+    (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
});
#if PRINT_HOST
{
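The host reference above now widens every half-precision operand to `AccDataType` (fp32) before multiplying and subtracting, converting back to `DataType` only for the final store, which mirrors what the device kernel does. A self-contained illustration of why this matters for bf16 (plain C++, not CK code; `bf16_round` emulates bf16 round-to-nearest-even): a bf16 accumulator stops growing once the sum's spacing exceeds twice the addend.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round a float to the nearest bf16 value (ties to even) and widen it back.
static float bf16_round(float x)
{
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    uint32_t lsb = (u >> 16) & 1u; // lowest kept bit, for ties-to-even
    u += 0x7FFFu + lsb;            // round away the discarded 16 bits
    u &= 0xFFFF0000u;              // keep bf16's 8 mantissa bits
    float y;
    std::memcpy(&y, &u, sizeof(y));
    return y;
}

int main()
{
    float acc_f32 = 0.f, acc_bf16 = 0.f;
    for(int i = 0; i < 4096; ++i)
    {
        acc_f32 += 0.001f;                        // fp32 accumulator
        acc_bf16 = bf16_round(acc_bf16 + 0.001f); // emulated bf16 accumulator
    }
    // fp32 reaches ~4.096, while the bf16 sum stalls near 1.0, where its
    // spacing (2^-8) dwarfs the 0.001 addend.
    std::printf("fp32: %f  bf16: %f\n", acc_f32, acc_bf16);
}
```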
@@ -1058,7 +1082,7 @@ int run(int argc, char* argv[])
double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01
-if(std::is_same_v<DataType, ck::bhalf_t>)
+if(std::is_same_v<DataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
......
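The same tolerance rule is applied in all three examples (the two hunks below make the matching change). A self-contained sketch of the policy, with `bhalf_t` standing in for `ck::bhalf_t` so the snippet compiles on its own: bf16 keeps only 8 mantissa bits, roughly 2-3 decimal digits, so both tolerances loosen from 1e-3 to 1e-2 whenever bf16 appears as the storage type or the GEMM input type.

```cpp
#include <type_traits>
#include <utility>

struct bhalf_t { unsigned short data; }; // stand-in for ck::bhalf_t

// Choose {rtol, atol} the way the updated examples do.
template <typename DataType, typename GemmDataType>
constexpr std::pair<double, double> pick_tolerances()
{
    if constexpr(std::is_same_v<DataType, bhalf_t> ||
                 std::is_same_v<GemmDataType, bhalf_t>)
    {
        return {1e-2, 1e-2}; // bf16 anywhere in the pipeline
    }
    return {1e-3, 1e-3}; // fp16 default in these examples
}
```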
@@ -9,8 +9,8 @@ int run(int argc, char* argv[])
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-ck::index_t M = 512; // 120
-ck::index_t N = 512; // 1000
+ck::index_t M = 1000; // 120
+ck::index_t N = 1000; // 1000
ck::index_t K = 64;
ck::index_t O = 64;
@@ -360,8 +360,7 @@ int run(int argc, char* argv[])
double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01
-if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
-   std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
+if(std::is_same_v<DataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
......
@@ -426,8 +426,7 @@ int run(int argc, char* argv[])
double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01
-if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
-   std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
+if(std::is_same_v<DataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
......
@@ -852,7 +852,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
-I1, // n0, // NRepeat
+I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......