Commit 180290ba authored by rocking
Browse files

Add gemm layernorm host code

parent c13776be
...@@ -66,14 +66,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern ...@@ -66,14 +66,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>; < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>;
// clang-format on // clang-format on
// File-scope alias for the host-side reference GEMM used to verify device results.
// It is instantiated with the A/B input types, AccDataType for both of the following
// two type parameters, the A/B element-wise ops, and PassThrough as the final op.
// NOTE(review): the exact meaning of each template slot is defined by
// ck::tensor_operation::host::ReferenceGemm, which is not visible here - confirm there.
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}), return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride})); std::vector<std::size_t>({stride}));
...@@ -93,6 +85,78 @@ auto f_host_tensor_descriptor2d = ...@@ -93,6 +85,78 @@ auto f_host_tensor_descriptor2d =
} }
}; };
/// Host reference for the fused "GEMM + element-wise epilogue + layer normalization" path.
///
/// The function performs three steps on the host:
///   1. Runs the reference GEMM on `a_m_k` and `b_k_n` into a temporary M x N
///      accumulator tensor, using PassThrough as the GEMM's final element-wise op.
///   2. For every (m, n), calls `cde_element_op(e_m_n(m, n), c(m, n), bias_n(n), d1_m_n(m, n))`;
///      `e_m_n` is read back afterwards as the input of the normalization.
///   3. For every row m, accumulates the mean and the mean of squares of `e_m_n(m, :)`
///      over N elements in AccDataType, then invokes the Normalize functor per element
///      with (e, mean, meanSquare, gamma_n(n), beta_n(n)) and stores the result in `h_m_n`.
///
/// @param e_m_n          Receives the result of `cde_element_op` for each (m, n).
/// @param h_m_n          Receives the normalized output for each (m, n).
/// @param a_m_k          Left GEMM operand.
/// @param b_k_n          Right GEMM operand.
/// @param bias_n         Per-column operand passed to `cde_element_op`, indexed by n.
/// @param d1_m_n         Per-element operand passed to `cde_element_op`, indexed by (m, n).
/// @param gamma_n        Per-column scale forwarded to the Normalize functor, indexed by n.
/// @param beta_n         Per-column shift forwarded to the Normalize functor, indexed by n.
/// @param a_element_op   Element-wise op forwarded to the reference GEMM for A.
/// @param b_element_op   Element-wise op forwarded to the reference GEMM for B.
/// @param cde_element_op Epilogue op combining the GEMM result with `bias_n` and `d1_m_n`.
/// @param M              Row count: loop bound and first extent of the temporary tensor.
/// @param N              Column count: loop bound, second extent of the temporary tensor,
///                       and divisor of the per-row statistics.
/// @param epsilon        Forwarded to the Normalize functor's constructor.
///
/// NOTE(review): Normalize presumably derives the variance as meanSquare - mean^2 and adds
/// `epsilon` before the square root - confirm against the functor's definition.
void host_gemm_layernorm(Tensor<HDataType>& e_m_n,
                         Tensor<HDataType>& h_m_n,
                         const Tensor<ADataType>& a_m_k,
                         const Tensor<BDataType>& b_k_n,
                         const Tensor<D0DataType>& bias_n,
                         const Tensor<D1DataType>& d1_m_n,
                         const Tensor<GammaDataType>& gamma_n,
                         const Tensor<BetaDataType>& beta_n,
                         AElementOp a_element_op,
                         BElementOp b_element_op,
                         CDEElementOp cde_element_op,
                         int M,
                         int N,
                         float epsilon = 1e-5)
{
    using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                            BDataType,
                                                                            AccDataType,
                                                                            AccDataType,
                                                                            AElementOp,
                                                                            BElementOp,
                                                                            PassThrough>;
    using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;

    // Step 1: plain GEMM into an accumulator-typed temporary; the epilogue is applied
    // separately below, so the GEMM itself uses PassThrough as its output op.
    Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});

    auto ref_gemm    = ReferenceGemmInstance{};
    auto ref_invoker = ref_gemm.MakeInvoker();

    auto ref_argument =
        ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});

    ref_invoker.Run(ref_argument);

    // Step 2: element-wise epilogue combining the GEMM result with bias and d1.
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            cde_element_op(e_m_n(m, n), c_m_n(m, n), bias_n(n), d1_m_n(m, n));
        }
    }

    // Step 3: LayerNorm over each row. The row statistics are kept in scalars; no
    // per-row statistic tensors are needed because each row is consumed immediately.
    auto layerNormInst = NormalizeFunctor{epsilon};

    for(int m = 0; m < M; ++m)
    {
        AccDataType mean       = 0;
        AccDataType meanSquare = 0;

        // First pass: accumulate sum and sum of squares of the row in AccDataType.
        for(int n = 0; n < N; ++n)
        {
            auto e_val = ck::type_convert<AccDataType>(e_m_n(m, n));
            mean += e_val;
            meanSquare += e_val * e_val;
        }

        mean /= N;
        meanSquare /= N;

        // Second pass: normalize every element of the row with the row statistics.
        for(int n = 0; n < N; ++n)
        {
            AccDataType h_val     = 0;
            AccDataType e_val     = ck::type_convert<AccDataType>(e_m_n(m, n));
            AccDataType gamma_val = ck::type_convert<AccDataType>(gamma_n(n));
            AccDataType beta_val  = ck::type_convert<AccDataType>(beta_n(n));

            layerNormInst(h_val, e_val, mean, meanSquare, gamma_val, beta_val);
            h_m_n(m, n) = ck::type_convert<HDataType>(h_val);
        }
    }
}
int main() int main()
{ {
bool do_verification = true; bool do_verification = true;
...@@ -181,28 +245,34 @@ int main() ...@@ -181,28 +245,34 @@ int main()
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
if(do_verification) if(do_verification)
{ {
Tensor<AccDataType> c_m_n_host(HostTensorDescriptor{M, N});
Tensor<HDataType> e_m_n_host(HostTensorDescriptor{M, N}); Tensor<HDataType> e_m_n_host(HostTensorDescriptor{M, N});
Tensor<HDataType> h_m_n_host(HostTensorDescriptor{M, N});
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); host_gemm_layernorm(e_m_n_host,
h_m_n_host,
auto ref_argument = ref_gemm.MakeArgument( a_m_k,
a_m_k, b_k_n, c_m_n_host, a_element_op, b_element_op, PassThrough{}); b_k_n,
d0_n,
ref_invoker.Run(ref_argument); d1_m_n,
gamma_n,
for(int m = 0; m < M; ++m) beta_n,
{ a_element_op,
for(int n = 0; n < N; ++n) b_element_op,
{ cde_element_op,
cde_element_op(e_m_n_host(m, n), c_m_n_host(m, n), d0_n(n), d1_m_n(m, n)); M,
} N,
} epsilon);
e_device_buf.FromDevice(e_m_n.mData.data()); e_device_buf.FromDevice(e_m_n.mData.data());
return ck::utils::check_err(e_m_n, e_m_n_host) ? 0 : 1; h_device_buf.FromDevice(h_m_n.mData.data());
pass &= ck::utils::check_err(e_m_n, e_m_n_host);
pass &= ck::utils::check_err(h_m_n, h_m_n_host);
} }
return pass ? 0 : 1;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment