Commit 92c615b6 authored by Jing Zhang

add reference_gemm_transpose

parent f1cdecfb
@@ -15,7 +15,7 @@
 #include "device_tensor.hpp"
 #include "device_grouped_gemm_transpose_xdl.hpp"
 #include "element_wise_operation.hpp"
-#include "reference_gemm.hpp"
+#include "reference_gemm_transpose.hpp"
 #include "gemm_specialization.hpp"
 
 template <ck::index_t... Is>
@@ -55,8 +55,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmTransp
     < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>;
 // clang-format on
 
-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmTransposeInstance = ck::tensor_operation::host::
+    ReferenceGemmTranspose<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
 
 int main(int argc, char* argv[])
 {
@@ -89,13 +89,30 @@ int main(int argc, char* argv[])
     for(int i = 0; i < group_count; i++)
     {
         int B = 16;
         int S = 64;
-        int nH = 16;
-        int hD = 64;
+        int NumHead = 16;
+        int HeadDim = 64;
+
+        int M0 = B;
+        int M1 = S;
+        int N0 = NumHead;
+        int N1 = HeadDim;
+
+        int M = M0 * N1;
+        int N = N0 * N1;
+        int K = NumHead * HeadDim;
+
+        int StrideA = K;
+        int StrideB = K;
+
+        int StrideM0 = S * NumHead * HeadDim;
+        int StrideM1 = HeadDim;
+        int StrideN0 = S * HeadDim;
+        int StrideN1 = 1;
 
         gemm_descs.push_back(
-            {B * S, nH * hD, nH * hD, nH * hD, nH * hD, B, S, nH, hD, S * nH * hD, S * hD, hD, 1});
+            {M, N, K, StrideA, StrideB, M0, M1, N0, N1, StrideM0, StrideM1, StrideN0, StrideN1});
     }
 
     auto f_host_tensor_descriptor =
@@ -250,7 +267,7 @@ int main(int argc, char* argv[])
     for(std::size_t i = 0; i < gemm_descs.size(); i++)
     {
         c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
 
-        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_gemm    = ReferenceGemmTransposeInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
 
         auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
...
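To make the new gemm_descs entry easier to read, here is a minimal standalone sketch (plain C++, not part of the commit) of the size and stride arithmetic it encodes for the example's fixed sizes B = 16, S = 64, NumHead = 16, HeadDim = 64, together with the output layout those strides imply:

// Sketch only: mirrors the variable names used in the diff above.
#include <cstdio>

int main()
{
    const int B = 16, S = 64, NumHead = 16, HeadDim = 64;

    const int M0 = B, M1 = S, N0 = NumHead, N1 = HeadDim;

    const int M = M0 * M1;           // 1024; the diff writes M0 * N1, which gives the
                                     // same value only because S == HeadDim == 64
    const int N = N0 * N1;           // 1024
    const int K = NumHead * HeadDim; // 1024

    const int StrideA = K; // leading dimension of A (Row layout in the instance)
    const int StrideB = K; // leading dimension of B (Col layout in the instance)

    // Offset of C(m0, m1, n0, n1) = m0*StrideM0 + m1*StrideM1 + n0*StrideN0 + n1*StrideN1
    //                             = m0*(S*NumHead*HeadDim) + n0*(S*HeadDim) + m1*HeadDim + n1,
    // i.e. memory order (B, NumHead, S, HeadDim): the S and NumHead axes are swapped,
    // which is the "transpose" the new device op performs on its output.
    const int StrideM0 = S * NumHead * HeadDim;
    const int StrideM1 = HeadDim;
    const int StrideN0 = S * HeadDim;
    const int StrideN1 = 1;

    std::printf("M=%d N=%d K=%d StrideA=%d StrideB=%d\n", M, N, K, StrideA, StrideB);
    std::printf("StrideM0=%d StrideM1=%d StrideN0=%d StrideN1=%d\n",
                StrideM0, StrideM1, StrideN0, StrideN1);
}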
New file (included above as "reference_gemm_transpose.hpp"):

#pragma once
#include <iostream>
#include <memory>
#include <sstream>
#include <thread>

#include "device_base.hpp"
#include "host_tensor.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

// Host reference for a GEMM whose output is indexed as (m0, m1, n0, n1),
// with m = m0 * M1 + m1 and n = n0 * N1 + n1; A and B keep their usual
// (m, k) and (k, n) indexing.
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceGemmTranspose : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_k_n,
                 Tensor<CDataType>& c_m0_m1_n0_n1,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_m_k_{a_m_k},
              b_k_n_{b_k_n},
              c_m0_m1_n0_n1_{c_m0_m1_n0_n1},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_k_n_;
        Tensor<CDataType>& c_m0_m1_n0_n1_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceGemmTranspose::Argument;

        float Run(const Argument& arg)
        {
            // Compute one output element C(m0, m1, n0, n1) by reducing over K,
            // reading A and B through their flat (m, k) / (k, n) coordinates.
            auto f_mk_kn_m0m1n0n1 = [&](auto m0, auto m1, auto n0, auto n1) {
                const int K = arg.a_m_k_.mDesc.GetLengths()[1];

                const int m = m0 * arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[1] + m1;
                const int n = n0 * arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[3] + n1;

                float v_acc = 0;

                for(int k = 0; k < K; ++k)
                {
                    float v_a;
                    float v_b;

                    arg.a_element_op_(v_a, static_cast<float>(arg.a_m_k_(m, k)));
                    arg.b_element_op_(v_b, static_cast<float>(arg.b_k_n_(k, n)));

                    v_acc += v_a * v_b;
                }

                float v_c;

                arg.c_element_op_(v_c, v_acc);

                arg.c_m0_m1_n0_n1_(m0, m1, n0, n1) = v_c;
            };

            make_ParallelTensorFunctor(f_mk_kn_m0m1n0n1,
                                       arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[0],
                                       arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[1],
                                       arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[2],
                                       arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[3])(
                std::thread::hardware_concurrency());

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
                             Tensor<CDataType>& c_m0_m1_n0_n1,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_m_k, b_k_n, c_m0_m1_n0_n1, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceGemmTranspose"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
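To make the indexing in Invoker::Run concrete, here is a small self-contained sketch (plain std::vector buffers, not the CK Tensor class) of the same computation: C(m0, m1, n0, n1) = sum_k A(m, k) * B(k, n) with m = m0 * M1 + m1 and n = n0 * N1 + n1. The sketch stores C in the (m0, n0, m1, n1) memory order implied by the example's strides; in the header itself the layout of C comes from whatever descriptor the tensor is built with, and the element-wise operators are assumed to be PassThrough.

#include <cstdio>
#include <vector>

// Standalone illustration of the reference loop above, not a CK API.
void reference_gemm_transpose(const std::vector<float>& a, // row-major M x K
                              const std::vector<float>& b, // row-major K x N
                              std::vector<float>& c,       // size M0 * M1 * N0 * N1
                              int M0, int M1, int N0, int N1, int K)
{
    const int N = N0 * N1;

    for(int m0 = 0; m0 < M0; ++m0)
        for(int m1 = 0; m1 < M1; ++m1)
            for(int n0 = 0; n0 < N0; ++n0)
                for(int n1 = 0; n1 < N1; ++n1)
                {
                    const int m = m0 * M1 + m1; // flat row index into A
                    const int n = n0 * N1 + n1; // flat column index into B

                    float acc = 0.f;
                    for(int k = 0; k < K; ++k)
                        acc += a[m * K + k] * b[k * N + n];

                    // Transposed output: StrideM0 = N0*M1*N1, StrideN0 = M1*N1,
                    // StrideM1 = N1, StrideN1 = 1 (as in the example's descriptor).
                    c[((m0 * N0 + n0) * M1 + m1) * N1 + n1] = acc;
                }
}

int main()
{
    // Tiny sizes so the sketch runs instantly; the example itself uses
    // M0 = B = 16, M1 = S = 64, N0 = NumHead = 16, N1 = HeadDim = 64.
    const int M0 = 2, M1 = 3, N0 = 2, N1 = 2, K = 4;

    std::vector<float> a(M0 * M1 * K, 1.f);
    std::vector<float> b(K * N0 * N1, 2.f);
    std::vector<float> c(M0 * M1 * N0 * N1, 0.f);

    reference_gemm_transpose(a, b, c, M0, M1, N0, N1, K);

    std::printf("c[0] = %g (expected %g)\n", c[0], 2.f * K);
}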