Commit 9f6dbb55 authored by Anthony Chang

tighten up example code

parent ebdb48ae
@@ -14,7 +14,7 @@
 #include "device_gemm_xdl_layernorm_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "reduction_operator.hpp"
-#include "reference_gemm.hpp"
+#include "reference_gemm_layernorm.hpp"
 #include "gemm_specialization.hpp"
 
 template <ck::index_t... Is>
@@ -50,65 +50,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl
     < Row, Col, Row, F16, F16, F16, AccDataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
 // clang-format on
 
-// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
-template <typename InDataType, typename OutDataType, typename ComputeDataType>
-void Layernorm(Tensor<OutDataType>& result,
-               const Tensor<ComputeDataType>& acc, // MxN
-               const Tensor<InDataType>& bias,     // 1xN
-               const Tensor<InDataType>& gamma,    // 1xN
-               const Tensor<InDataType>& beta,     // 1xN
-               const InDataType epsilon = 1e-5)
-{
-    assert(acc.mDesc.GetLengths()[1] == bias.mDesc.GetLengths()[0] &&
-           acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
-           acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]);
-
-    size_t M = acc.mDesc.GetLengths()[0];
-    size_t N = acc.mDesc.GetLengths()[1];
-
-    Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
-    Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
-    Tensor<ComputeDataType> acc_layernorm(acc.mDesc);
-
-    // add bias
-    acc_layernorm.ForEach([&](auto& self, auto idx) {
-        self(idx[0], idx[1]) = acc(idx[0], idx[1]) + bias(idx[1]);
-    });
-
-    // reduce N dim
-    for(size_t i = 0; i < M; i++)
-    {
-        ComputeDataType sum_acc_sq = 0;
-        ComputeDataType sum_acc    = 0;
-
-        for(size_t j = 0; j < N; j++)
-        {
-            sum_acc_sq += acc_layernorm(i, j) * acc_layernorm(i, j);
-            sum_acc += acc_layernorm(i, j);
-        }
-
-        avg_acc_sq(i) = sum_acc_sq / N;
-        avg_acc(i)    = sum_acc / N;
-
-        // std::cout << "avg_acc_(" << i << ") =" << avg_acc(i) << std::endl;
-        // std::cout << "avg_acc_sq_(" << i << ") =" << avg_acc_sq(i) << std::endl;
-    }
-
-    // normalize
-    acc_layernorm.ForEach([&](auto& self, auto idx) {
-        self(idx[0], idx[1]) =
-            (self(idx[0], idx[1]) - avg_acc(idx[0])) /
-            sqrt(avg_acc_sq(idx[0]) - avg_acc(idx[0]) * avg_acc(idx[0]) + epsilon);
-    });
-
-    // affine
-    acc_layernorm.ForEach([&](auto& self, auto idx) {
-        self(idx[0], idx[1]) = self(idx[0], idx[1]) * gamma(idx[1]) + beta(idx[1]);
-    });
-
-    // cast
-    result = acc_layernorm.template CopyAsType<OutDataType>();
-}
-
-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceInstance = ck::tensor_operation::host::
+    ReferenceGemmLayernorm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
 
 int main(int argc, char* argv[])
 {
@@ -272,16 +215,14 @@ int main(int argc, char* argv[])
     {
         c_device_buf.FromDevice(c_m_n_device_result.mData.data());
 
-        auto ref_gemm     = ReferenceGemmInstance{};
+        auto ref_gemm     = ReferenceInstance{};
         auto ref_invoker  = ref_gemm.MakeInvoker();
 
         auto ref_argument = ref_gemm.MakeArgument(
-            a_m_k, b_k_n, acc_m_n_host_result, a_element_op, b_element_op, c_element_op);
+            a_m_k, b_k_n, c0_n_bias, c0_n_gamma, c0_n_beta, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
 
         ref_invoker.Run(ref_argument);
 
-        Layernorm(c_m_n_host_result, acc_m_n_host_result, c0_n_bias, c0_n_gamma, c0_n_beta);
-
         pass &= ck::utils::check_err(
             c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c");
New file: reference_gemm_layernorm.hpp
#pragma once
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "reference_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferenceGemmLayernorm : public device::BaseOperator
{
using ReferenceGemmInstance = ReferenceGemm<ADataType, BDataType, AccDataType, AccDataType,
AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>;
// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
template <typename InDataType, typename OutDataType, typename ComputeDataType>
static void RunLayernorm(Tensor<OutDataType>& result,
const Tensor<ComputeDataType>& acc, // MxN
const Tensor<InDataType>& bias, // 1xN
const Tensor<InDataType>& gamma, // 1xN
const Tensor<InDataType>& beta, // 1xN
const InDataType epsilon = 1e-5)
{
assert(acc.mDesc.GetLengths()[1] == bias.mDesc.GetLengths()[0] &&
acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]);
size_t M = acc.mDesc.GetLengths()[0];
size_t N = acc.mDesc.GetLengths()[1];
Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
Tensor<ComputeDataType> acc_layernorm(acc.mDesc);
// add bias
acc_layernorm.ForEach([&](auto& self, auto idx) {
self(idx[0], idx[1]) = acc(idx[0], idx[1]) + bias(idx[1]);
});
// reduce N dim
for(size_t i = 0; i < M; i++)
{
ComputeDataType sum_acc_sq = 0;
ComputeDataType sum_acc = 0;
for(size_t j = 0; j < N; j++)
{
sum_acc_sq += acc_layernorm(i, j) * acc_layernorm(i, j);
sum_acc += acc_layernorm(i, j);
}
avg_acc_sq(i) = sum_acc_sq / N;
avg_acc(i) = sum_acc / N;
// std::cout << "avg_acc_(" << i << ") =" << avg_acc(i) << std::endl;
// std::cout << "avg_acc_sq_(" << i << ") =" << avg_acc_sq(i) << std::endl;
}
// normalize
acc_layernorm.ForEach([&](auto& self, auto idx) {
self(idx[0], idx[1]) =
(self(idx[0], idx[1]) - avg_acc(idx[0])) /
sqrt(avg_acc_sq(idx[0]) - avg_acc(idx[0]) * avg_acc(idx[0]) + epsilon);
});
// affine
acc_layernorm.ForEach([&](auto& self, auto idx) {
self(idx[0], idx[1]) = self(idx[0], idx[1]) * gamma(idx[1]) + beta(idx[1]);
});
// cast
result = acc_layernorm.template CopyAsType<OutDataType>();
}
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<CDataType>& c0_n_bias, // 1xN
const Tensor<CDataType>& c0_n_gamma, // 1xN
const Tensor<CDataType>& c0_n_beta, // 1xN
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const CDataType epsilon = 1e-5)
: a_m_k_{a_m_k},
b_k_n_{b_k_n},
c0_n_bias_{c0_n_bias},
c0_n_gamma_{c0_n_gamma},
c0_n_beta_{c0_n_beta},
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
epsilon_{epsilon}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_k_n_;
const Tensor<CDataType>& c0_n_bias_;
const Tensor<CDataType>& c0_n_gamma_;
const Tensor<CDataType>& c0_n_beta_;
Tensor<CDataType>& c_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
const CDataType epsilon_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
// using Argument = ReferenceGemm::Argument;
float Run(const Argument& arg)
{
// run the reference GEMM into an AccDataType buffer first, then apply the
// layernorm epilogue to produce the final CDataType result
Tensor<AccDataType> acc_m_n(arg.c_m_n_.mDesc);
acc_m_n.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
arg.a_m_k_, arg.b_k_n_, acc_m_n, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_);
ref_invoker.Run(ref_argument);
// forward the epsilon stored in the argument rather than silently using the default
RunLayernorm(
arg.c_m_n_, acc_m_n, arg.c0_n_bias_, arg.c0_n_gamma_, arg.c0_n_beta_, arg.epsilon_);
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<CDataType>& c0_n_bias, // 1xN
const Tensor<CDataType>& c0_n_gamma, // 1xN
const Tensor<CDataType>& c0_n_beta, // 1xN
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const CDataType epsilon = 1e-5)
{
return Argument{a_m_k,
b_k_n,
c0_n_bias,
c0_n_gamma,
c0_n_beta,
c_m_n,
a_element_op,
b_element_op,
c_element_op,
epsilon};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemmLayernorm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
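
For orientation, below is a minimal usage sketch of the new host operator, mirroring the call sequence in the updated example above. The type aliases (ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp) and the host tensors (a_m_k, b_k_n, c0_n_bias, c0_n_gamma, c0_n_beta, c_m_n_host_result, and the element-op objects) are assumed to be the ones the example already defines, so this is a sketch rather than a drop-in snippet:

// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta), where acc = GEMM(A, B)
using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<
    ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

auto ref_gemm    = ReferenceInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

// bias/gamma/beta are 1xN host tensors; epsilon is optional and defaults to 1e-5
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
                                          b_k_n,
                                          c0_n_bias,
                                          c0_n_gamma,
                                          c0_n_beta,
                                          c_m_n_host_result,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);

ref_invoker.Run(ref_argument);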