Commit 1aed5cb0 authored by guangzlu

Added grouped forward MHA dropout verification

parent cdc6f6ba
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -161,6 +162,10 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                B1ElementOp,
                                                                CElementOp>;
 
+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
+
 #include "run_grouped_multihead_attention_forward.inc"
 
 int main(int argc, char* argv[]) { return run(argc, argv); }
@@ -5,12 +5,12 @@ int run(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;
 
     bool input_permute  = false;
     bool output_permute = true;
 
-    float p_drop = 0.2;
+    float p_drop = 0.1;
     float p_dropout = 1 - p_drop;
     uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
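The three derived quantities above encode the dropout configuration: the keep probability `p_dropout` is quantized to a 16-bit threshold, and its reciprocal `rp_dropout` rescales surviving elements so the expected value is unchanged. A minimal standalone sketch of that arithmetic follows; the comparison direction and helper name are illustrative assumptions, not the library's code:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: keep element x iff a uniform 16-bit draw r falls
// at or below the quantized keep probability, then rescale by 1/p_keep so
// that E[dropout(x)] == x. Mirrors p_dropout_in_16bits / rp_dropout above.
float apply_dropout_elem(float x, uint16_t r, float p_keep)
{
    const auto threshold = uint16_t(std::floor(p_keep * 65535.0));
    const float rp_keep  = 1.0f / p_keep;
    return (r <= threshold) ? x * rp_keep : 0.0f;
}

int main()
{
    // p_drop = 0.1 -> p_keep = 0.9 -> threshold = 58981.
    std::printf("%f\n", apply_dropout_elem(2.0f, 1234u, 0.9f));  // kept: 2.2222
    std::printf("%f\n", apply_dropout_elem(2.0f, 65000u, 0.9f)); // dropped: 0
}
```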
@@ -45,9 +45,9 @@ int run(int argc, char* argv[])
         exit(0);
     }
 
-    float alpha = 1; // scaling after 1st gemm
+    float alpha = 0.25; // scaling after 1st gemm
 
-    std::size_t group_count = 7;
+    std::size_t group_count = 8;
 
     // Problem descs
     std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
@@ -79,13 +79,24 @@ int run(int argc, char* argv[])
     std::cout << "group count " << group_count << ". printing first 4 groups\n";
     for(std::size_t i = 0; i < group_count; i++)
     {
-        int M = 128 * (rand() % 8 + 1);
-        int N = 128 * (rand() % 8 + 1);
+        int M = 512;
+        int N = 512;
         int K = 40;
-        int O = 40 * (rand() % 2 + 1);
+        int O = 40;
         int G0 = rand() % 3 + 1;
         int G1 = rand() % 5 + 1;
+
+        // int M = 128 * (rand() % 8 + 1);
+        // int N = 128 * (rand() % 8 + 1);
+        // int K = 40;
+        // int O = 40 * (rand() % 2 + 1);
+        // int G0 = rand() % 3 + 1;
+        // int G1 = rand() % 5 + 1;
+
+        std::cout << "group id " << i << " M, N, K, O, G0, G1 is " << M << "," << N << "," << K
+                  << "," << O << "," << G0 << "," << G1 << std::endl;
 
         g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
 
         std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
@@ -229,25 +240,26 @@ int run(int argc, char* argv[])
     auto c_element_op = CElementOp{};
 
     // do GEMM
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
-    auto argument = gemm.MakeArgument(p_a,
-                                      p_b0,
-                                      p_b1,
-                                      p_c,
-                                      p_z,
-                                      p_lse,
-                                      {}, // p_acc0_biases
-                                      {}, // p_acc1_biases
-                                      problem_descs,
-                                      a_element_op,
-                                      b0_element_op,
-                                      acc0_element_op,
-                                      b1_element_op,
-                                      c_element_op,
-                                      p_drop,    // dropout ratio
-                                      {0, 448}); // dropout random seed and offset, offset should be
-                                                 // at least the number of elements on a thread
+    auto argument =
+        gemm.MakeArgument(p_a,
+                          p_b0,
+                          p_b1,
+                          p_c,
+                          p_z,
+                          p_lse,
+                          {}, // p_acc0_biases
+                          {}, // p_acc1_biases
+                          problem_descs,
+                          a_element_op,
+                          b0_element_op,
+                          acc0_element_op,
+                          b1_element_op,
+                          c_element_op,
+                          p_drop,          // dropout ratio
+                          {seed, offset}); // dropout random seed and offset, offset should be
+                                           // at least the number of elements on a thread
 
     // specify workspace for problem_desc
     DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
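The final `{seed, offset}` argument (previously the literals `{0, 448}`) parameterizes the kernel's counter-based random number generator, and the comment states the contract: `offset` must be at least the number of elements each thread generates so per-thread random subsequences never overlap. A toy illustration of that partitioning; the hash below is a stand-in, not the generator CK uses:

```cpp
#include <cstdint>
#include <cstdio>

// Toy stateless generator: stream position -> pseudo-random 16-bit value.
// A counter-based RNG lets thread t take its e-th draw at position
// t * offset + e with no shared state; if offset >= elements per thread,
// the per-thread subsequences are disjoint by construction.
uint16_t toy_draw(uint64_t seed, uint64_t position)
{
    uint64_t x = seed + position * 0x9E3779B97F4A7C15ull; // splitmix64-style mix
    x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ull;
    x = (x ^ (x >> 27)) * 0x94D049BB133111EBull;
    return uint16_t((x ^ (x >> 31)) >> 48);
}

int main()
{
    const uint64_t seed = 0, offset = 448; // values from the original call site
    for(uint64_t t = 0; t < 2; ++t)        // two threads
        for(uint64_t e = 0; e < 3; ++e)    // first three draws each
            std::printf("thread %llu draw %llu -> %u\n",
                        (unsigned long long)t, (unsigned long long)e,
                        toy_draw(seed, t * offset + e));
}
```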
@@ -291,11 +303,14 @@ int run(int argc, char* argv[])
             const auto& b0_gs_ns_ks = b0_tensors[i];
             const auto& b1_gs_os_ns = b1_tensors[i];
 
             auto& c_gs_ms_os_device_result = c_tensors[i];
+            auto& z_gs_ms_ns_device_result = z_tensors[i];
             auto& lse_gs_ms_device_result  = lse_tensors[i];
 
             auto& c_gs_ms_os_device_buf = *c_tensors_device[i];
+            auto& z_gs_ms_ns_device_buf = *z_tensors_device[i];
             auto& lse_gs_ms_device_buf  = *lse_tensors_device[i];
 
             c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
+            z_gs_ms_ns_device_buf.FromDevice(z_gs_ms_ns_device_result.mData.data());
             lse_gs_ms_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());
 
             Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
@@ -303,8 +318,11 @@ int run(int argc, char* argv[])
             Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
             Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
             Tensor<ADataType> a1_g_m_n({G0 * G1, M, N});     // scratch object after softmax
+            Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after dropout
             Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
             Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+            Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
+            // Tensor<CDataType> z_gs_ms_ns_host_result(z_gs_ms_os_lengths, z_gs_ms_os_strides);
             Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M}); // scratch object after gemm1
             Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
@@ -319,6 +337,10 @@ int run(int argc, char* argv[])
                 b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
             });
+            z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
+                z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+            });
 
             // gemm 0
             auto ref_gemm0         = ReferenceGemm0Instance{};
             auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
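The new `ForEach` block stages the device's z output (the per-element random draws) into a packed (G0 * G1, M, N) tensor, collapsing the two batch dimensions the same way the a/b/c tensors are staged. The equivalent loop nest written with plain strides; types and names here are illustrative, not library API:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Gather a strided {G0, G1, M, N} tensor into a packed (G0*G1, M, N)
// buffer, mirroring the z ForEach above. Strides are in elements and may
// describe a permuted device layout.
void collapse_to_g_m_n(const std::vector<uint16_t>& src,
                       const std::size_t stride[4], // {sG0, sG1, sM, sN}
                       int G0, int G1, int M, int N,
                       std::vector<uint16_t>& dst) // size G0*G1*M*N, packed
{
    for(int g0 = 0; g0 < G0; ++g0)
        for(int g1 = 0; g1 < G1; ++g1)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                {
                    const std::size_t s = g0 * stride[0] + g1 * stride[1] +
                                          m * stride[2] + n * stride[3];
                    dst[((std::size_t(g0) * G1 + g1) * M + m) * N + n] = src[s];
                }
}

int main()
{
    const int G0 = 1, G1 = 2, M = 2, N = 2;
    std::vector<uint16_t> src(std::size_t(G0) * G1 * M * N, 7);
    std::vector<uint16_t> dst(src.size());
    const std::size_t strides[4] = {8, 4, 2, 1}; // packed layout for this toy case
    collapse_to_g_m_n(src, strides, G0, G1, M, N, dst);
    return dst[0] == 7 ? 0 : 1;
}
```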
@@ -342,10 +364,20 @@ int run(int argc, char* argv[])
             ref_softmax_invoker.Run(ref_softmax_argument);
 
+            // printf("print z_g_m_n \n");
+            // z_g_m_n.ForEach([&](auto& self, auto idx) { printf("%u ", self(idx)); });
+
+            // dropout after softmax
+            auto ref_dropout          = ReferenceDropoutInstance{};
+            auto ref_dropout_invoker  = ref_dropout.MakeInvoker();
+            auto ref_dropout_argument = ref_dropout.MakeArgument(
+                z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+            ref_dropout_invoker.Run(ref_dropout_argument);
+
             // gemm 1
             auto ref_gemm1         = ReferenceGemm1Instance{};
             auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
-            auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
+            auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n_drop,
                                                              b1_g_n_o,
                                                              c_g_m_o_host_result,
                                                              PassThrough{},
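This hunk is the heart of the new verification: the kernel writes the raw 16-bit draws it used into z, so the host can replay exactly the same mask on the reference softmax output (`a1_g_m_n`) before feeding gemm1. A sketch of what the replay amounts to, assuming the threshold test sketched earlier; the actual signature lives in `reference_dropout.hpp`:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side dropout replay: z holds the device's 16-bit random draws, so
// applying the same threshold and rescale reproduces the kernel's mask
// bit-for-bit. The comparison direction is an assumption.
void reference_dropout(const std::vector<uint16_t>& z,    // device draws
                       const std::vector<float>& softmax, // a1_g_m_n
                       std::vector<float>& dropped,       // a1_g_m_n_drop
                       uint16_t p_dropout_in_16bits, float rp_dropout)
{
    for(std::size_t i = 0; i < z.size(); ++i)
        dropped[i] = (z[i] <= p_dropout_in_16bits) ? softmax[i] * rp_dropout : 0.0f;
}

int main()
{
    std::vector<uint16_t> z = {1000, 65000}; // second draw exceeds threshold
    std::vector<float> soft = {0.5f, 0.5f}, out(2);
    reference_dropout(z, soft, out, 58981 /* floor(0.9 * 65535) */, 1.0f / 0.9f);
    return (out[1] == 0.0f) ? 0 : 1; // dropped element is zeroed
}
```

Because both sides consume identical draws, the `check_err` comparison below can keep tight tolerances even with dropout enabled.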
@@ -384,9 +416,11 @@ int run(int argc, char* argv[])
                 atol = 1e-2;
             }
 
+            printf("group id is %lu \n", i);
+
             // bool pass_ =
             //     ck::utils::check_err(c_gs_ms_os_device_result.mData,
             //                          c_gs_ms_os_host_result.mData);
             bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData,
                                               c_gs_ms_os_host_result.mData,
                                               "Error: Incorrect results c!",
...
@@ -97,17 +97,18 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
         arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
 
-    // unsigned short* p_z_grid_in =
-    //     (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
-    //                                             : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
+    unsigned short* p_z_grid_in =
+        (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+                                                : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
 
     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
         arg_ptr[group_id].p_a_grid_ + a_batch_offset,
         arg_ptr[group_id].p_b_grid_ + b_batch_offset,
         arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
         arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-        arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
-                                               : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
+        p_z_grid_in,
+        // arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+        //     : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
         arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
         p_shared,
         a_element_op,
@@ -417,6 +418,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
         ZGridDesc_G_M_N z_grid_desc_g_m_n_;
 
         index_t BatchStrideLSE_;
     };
@@ -621,7 +623,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         // typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
         //     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
-        auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+        const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
             GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
                 z_grid_desc_m_n);
...
@@ -139,23 +139,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
     }
-
-    __host__ __device__ static constexpr auto
-    MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M,
-                                                      const index_t N) ////=> for z use
-    {
-        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
-        constexpr auto N3   = mfma.num_groups_per_blk;
-        constexpr auto N4   = mfma.num_input_blks;
-        constexpr auto N5   = mfma.group_size;
-
-        return transform_tensor_descriptor(
-            make_naive_tensor_descriptor_packed(make_tuple(M, N)),
-            make_tuple(make_unmerge_transform(
-                           make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
-                       make_unmerge_transform(
-                           make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
-    }
 
     __device__ static auto GetGemm0WaveIdx()
     {
@@ -852,7 +835,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             index_t gemm1_k_block_outer_index = 0;
 
             ///////////////////=>z for dropout
             //
             // z vgpr copy to global
             //
@@ -876,11 +858,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                 z_tenor_buffer;
             z_tenor_buffer.Clear();
 
             // z matrix global desc
-            /*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
-            const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
-            auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-                MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
 
             auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                 p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
@@ -1025,7 +1002,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                     // P_dropped
                     blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
                                                             decltype(z_tenor_buffer),
-                                                            true>(
+                                                            false>(
                         acc_thread_buf, ph, z_tenor_buffer);
 
                     z_thread_copy_vgpr_to_global.Run(
@@ -1034,20 +1011,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                         z_tenor_buffer,
                         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                         z_grid_buf);
+                    z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                        make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
                 }
                 else
                 {
                     // P_dropped
-                    blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), true>(
+                    blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>(
                         acc_thread_buf, ph);
                 }
             }
-            // if constexpr(IsDropout) // dropout
-            //{
-            //    blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
-            //}
 
             // TODO: may convert to log domain
             running_max_new = mathext::max(max, running_max);
             running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
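The added `MoveDstSliceWindow` call advances the z write window by one step along the second (N-block) dimension of the 10-D descriptor, so each iteration of the outer attention loop appears to deposit its tile of random draws next to the previous one instead of overwriting it. Conceptually, with a flat 2-D stand-in for the real descriptor:

```cpp
#include <cstdio>

// 2-D stand-in for the sliding destination window: after writing a tile,
// advance the origin by one tile along the chosen dimension, analogous to
// make_multi_index(0, 1, 0, ...) stepping the N-block coordinate.
struct SliceWindow
{
    int row0 = 0, col0 = 0; // window origin
    int rows, cols;         // tile extents
    void move_dst(int drow_tiles, int dcol_tiles)
    {
        row0 += drow_tiles * rows;
        col0 += dcol_tiles * cols;
    }
};

int main()
{
    SliceWindow w{0, 0, 16, 128};
    for(int iter = 0; iter < 3; ++iter)
    {
        std::printf("iter %d writes z tile at (%d, %d)\n", iter, w.row0, w.col0);
        w.move_dst(0, 1); // step along the N-block dimension
    }
}
```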
...