Commit 9ee2997d authored by guangzlu's avatar guangzlu
Browse files

Added BF16 forward attention dropout verification.

parent 067e71a8
......@@ -27,12 +27,14 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
// Shorthand for a compile-time integer sequence used to describe block-transfer
// thread arrangements in the device GEMM template parameters below.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// Element-type shorthands used throughout this example.
using BF16 = ck::bhalf_t;     // 16-bit brain floating point
using F32 = float;            // 32-bit float (accumulation precision)
using U16 = unsigned short;   // 16-bit unsigned integer
// Identity element-wise operation (no transformation applied to elements).
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
......@@ -42,6 +44,7 @@ using B1DataType = BF16;
using AccDataType = F32;          // accumulator precision for the GEMMs
using CShuffleDataType = F32;     // intermediate type used during the C-shuffle stage
using CDataType = BF16;           // output (C) tensor element type
using ZDataType = U16;            // presumably the per-element dropout random values/mask — TODO confirm against ReferenceDropout usage
using LSEDataType = F32;          // NOTE(review): likely log-sum-exp values from softmax — verify against kernel docs
using Acc0BiasDataType = ck::Tuple<>;   // no bias applied after the first GEMM
using Acc1BiasDataType = ck::Tuple<>;   // no bias applied after the second GEMM
......@@ -78,6 +81,7 @@ using DeviceGemmInstance =
B0DataType,
B1DataType,
CDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
......@@ -98,8 +102,8 @@ using DeviceGemmInstance =
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
......@@ -107,7 +111,7 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -157,6 +161,10 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp,
CElementOp>;
// Host-side reference dropout used to verify the device kernel's dropout:
// templated on the random-value type (ZDataType) and the input/output element
// type (ADataType for both).
using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
#include "run_grouped_multihead_attention_forward.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
......@@ -31,14 +31,14 @@ struct ReferenceDropout : public device::BaseOperator
in_(in),
out_(out),
p_dropout_in_16bits_(p_dropout_in_16bits),
rp_dropout_(ck::type_convert<OutDataType>(rp_dropout))
rp_dropout_(rp_dropout)
{
}
const Tensor<RefDataType>& ref_;
const Tensor<InDataType>& in_;
Tensor<OutDataType>& out_;
RefDataType p_dropout_in_16bits_;
OutDataType rp_dropout_;
float rp_dropout_;
};
// Invoker
......@@ -48,7 +48,10 @@ struct ReferenceDropout : public device::BaseOperator
{
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) =
arg.ref_(idx) < arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
arg.ref_(idx) < arg.p_dropout_in_16bits_
? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) *
ck::type_convert<float>(arg.rp_dropout_))
: 0;
});
return 0;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment