Commit 9ee2997d authored by guangzlu

added bf16 fwd attn dropout verify

parent 067e71a8
@@ -27,12 +27,14 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

 using BF16 = ck::bhalf_t;
 using F32  = float;
+using U16  = unsigned short;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -42,6 +44,7 @@ using B1DataType = BF16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = BF16;
+using ZDataType        = U16;
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -78,6 +81,7 @@ using DeviceGemmInstance =
         B0DataType,
         B1DataType,
         CDataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
@@ -98,8 +102,8 @@ using DeviceGemmInstance =
         128,         // MPerBlock
         128,         // NPerBlock
         32,          // KPerBlock
-        64,          // Gemm1NPerBlock
-        32,          // Gemm1KPerBlock
+        128,         // Gemm1NPerBlock
+        64,          // Gemm1KPerBlock
         8,           // AK1
         8,           // BK1
         2,           // B1K1
@@ -107,7 +111,7 @@ using DeviceGemmInstance =
         32,          // NPerXDL
         1,           // MXdlPerWave
         4,           // NXdlPerWave
-        2,           // Gemm1NXdlPerWave
+        4,           // Gemm1NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
@@ -157,6 +161,10 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
         B1ElementOp,
         CElementOp>;

+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
+
 #include "run_grouped_multihead_attention_forward.inc"

 int main(int argc, char* argv[]) { return run(argc, argv); }
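The new `ZDataType` (`U16`) tensor carries one uniform 16-bit random value per attention-score element: the reference keeps an element only when its random value falls below a 16-bit threshold, and rescales survivors by `rp_dropout` (the reciprocal of the keep probability) so the expected value is unchanged. Below is a minimal standalone sketch of that threshold-and-rescale scheme; it is not CK code, and the names and constants (`p_keep_u16`, the seed, the sample inputs) are illustrative assumptions, not values taken from this commit.

```cpp
// Standalone sketch of 16-bit threshold dropout with rescale (assumed scheme;
// names and constants are illustrative, not from the CK sources).
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

int main()
{
    const float p_dropout = 0.2f;              // probability of zeroing an element
    const float p_keep    = 1.0f - p_dropout;
    // Keep an element when its random 16-bit value falls below p_keep * 65535.
    const uint16_t p_keep_u16 = static_cast<uint16_t>(p_keep * 65535.0f);
    const float rp_dropout    = 1.0f / p_keep; // rescale so E[out] == E[in]

    std::mt19937 gen(42);
    std::uniform_int_distribution<uint32_t> dist(0, 65535);

    std::vector<float> in{1.0f, 2.0f, 3.0f, 4.0f};
    std::vector<float> out(in.size());
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        // z plays the role of one ZDataType element produced by the kernel.
        const uint16_t z = static_cast<uint16_t>(dist(gen));
        out[i] = (z < p_keep_u16) ? in[i] * rp_dropout : 0.0f;
        std::printf("in=%g  z=%u  out=%g\n", in[i], static_cast<unsigned>(z), out[i]);
    }
    return 0;
}
```

The commit also updates the host-side reference included above (`reference_dropout.hpp`); the hunks below are from that file and change how the reference applies the rescale.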
@@ -31,14 +31,14 @@ struct ReferenceDropout : public device::BaseOperator
               in_(in),
               out_(out),
               p_dropout_in_16bits_(p_dropout_in_16bits),
-              rp_dropout_(ck::type_convert<OutDataType>(rp_dropout))
+              rp_dropout_(rp_dropout)
         {
         }

         const Tensor<RefDataType>& ref_;
         const Tensor<InDataType>& in_;
         Tensor<OutDataType>& out_;
         RefDataType p_dropout_in_16bits_;
-        OutDataType rp_dropout_;
+        float rp_dropout_;
     };

     // Invoker
@@ -48,7 +48,10 @@ struct ReferenceDropout : public device::BaseOperator
     {
         arg.out_.ForEach([&](auto& self, auto idx) {
             self(idx) =
-                arg.ref_(idx) < arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
+                arg.ref_(idx) < arg.p_dropout_in_16bits_
+                    ? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) *
+                                                    ck::type_convert<float>(arg.rp_dropout_))
+                    : 0;
         });

         return 0;
     }
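The fix stores `rp_dropout` as `float` instead of converting it to `OutDataType` at construction, and performs the multiply in `float` with a single final conversion to the output type. With `OutDataType = BF16` (8 significant bits), the old path rounded the scale factor before every multiply, double-rounding each kept element. The sketch below demonstrates the effect; `to_bf16()` is an assumed stand-in for `ck::type_convert<ck::bhalf_t>` that truncates rather than rounding to nearest-even, which is enough to show the difference, and the input values are illustrative.

```cpp
// Standalone sketch of why the reference now multiplies in float: rounding
// rp_dropout to bf16 up front ("old") double-rounds, while converting only
// the final product ("new") rounds once.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Assumed stand-in for a float -> bf16 -> float round trip (truncating).
static float to_bf16(float x)
{
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    u &= 0xFFFF0000u; // keep sign, exponent, and top 7 mantissa bits
    float y;
    std::memcpy(&y, &u, sizeof(y));
    return y;
}

int main()
{
    const float p_dropout  = 0.17f;
    const float rp_dropout = 1.0f / (1.0f - p_dropout); // ~1.2048
    const float in         = to_bf16(0.3137f);          // a bf16 input element

    const float old_way = to_bf16(in * to_bf16(rp_dropout)); // scale pre-rounded
    const float new_way = to_bf16(in * rp_dropout);          // single final round
    std::printf("exact=%.6f  old=%.6f  new=%.6f\n", in * rp_dropout, old_way, new_way);
    return 0;
}
```

Keeping the member as `float` also decouples it from `OutDataType`, and the explicit `ck::type_convert<float>` on `arg.in_(idx)` presumably guards the bf16 case, where arithmetic on the raw storage type would not be ordinary floating-point math.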