Commit b57c3879 authored by Anthony Chang

harmonize interface between ref_gemm and ref_batched_gemm

parent 237371ad
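
ReferenceBatchedGemm previously fixed its accumulator to float; this commit threads an explicit AccDataType template parameter through it, presumably to match the parameter list ReferenceGemm already exposes. A minimal sketch of the harmonized signature, using the parameter names from the hunks below (surrounding code elided):

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType, // new: accumulation type, previously hard-coded float
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceBatchedGemm;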
@@ -66,8 +66,14 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
     < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
 // clang-format on

-using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
-    ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceBatchedGemmInstance =
+    ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
+                                                     BDataType,
+                                                     CDataType,
+                                                     ReduceAccDataType,
+                                                     AElementOp,
+                                                     BElementOp,
+                                                     CElementOp>;

 int main(int argc, char* argv[])
 {
...
@@ -51,8 +51,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermu
     < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // clang-format on

-using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
-    ReferenceBatchedGemm<ADataType, BDataType, EDataType, AElementOp, BElementOp, CDEElementOp>;
+using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
+                                                                                      BDataType,
+                                                                                      CDataType,
+                                                                                      AccDataType,
+                                                                                      AElementOp,
+                                                                                      BElementOp,
+                                                                                      CDEElementOp>;

 int main(int argc, char* argv[])
 {
...
@@ -16,6 +16,7 @@ namespace host {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename AccDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
@@ -58,7 +59,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
         auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
             const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];

-            float v_acc = 0;
+            AccDataType v_acc = 0;

             for(int k = 0; k < K; ++k)
             {
@@ -68,10 +69,10 @@ struct ReferenceBatchedGemm : public device::BaseOperator
                 arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
                 arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));

-                v_acc += ck::type_convert<float>(v_a) * ck::type_convert<float>(v_b);
+                v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
             }

-            float v_c;
+            AccDataType v_c;

             arg.c_element_op_(v_c, v_acc);
@@ -81,8 +82,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
             make_ParallelTensorFunctor(f_gmk_gkn_gmn,
                                        arg.c_g_m_n_.mDesc.GetLengths()[0],
                                        arg.c_g_m_n_.mDesc.GetLengths()[1],
-                                       arg.c_g_m_n_.mDesc.GetLengths()[2])(
-                std::thread::hardware_concurrency());
+                                       arg.c_g_m_n_.mDesc.GetLengths()[2])();

         return 0;
     }
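
The core of the change in the reference header: the accumulator, the element-wise products, and the pre-epilogue value now use AccDataType instead of a hard-coded float, so half-precision inputs can still accumulate at higher precision when the caller asks for it. The last hunk also drops the explicit std::thread::hardware_concurrency() argument, presumably leaving make_ParallelTensorFunctor's call operator to choose a default thread count. A standalone sketch of the templated inner-product pattern (illustrative names only; static_cast stands in for ck::type_convert):

#include <cstddef>

// Convert each operand to AccDataType before the multiply-add, then return
// the accumulator; mirrors the per-(g, m, n) loop in the hunks above.
template <typename ADataType, typename BDataType, typename AccDataType>
AccDataType batched_gemm_dot(const ADataType* a_row, const BDataType* b_col, std::size_t K)
{
    AccDataType v_acc = 0;
    for(std::size_t k = 0; k < K; ++k)
    {
        v_acc += static_cast<AccDataType>(a_row[k]) * static_cast<AccDataType>(b_col[k]);
    }
    return v_acc;
}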
...
@@ -101,6 +101,7 @@ bool profile_batched_gemm_impl(int do_verification,
         ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                          BDataType,
                                                          CDataType,
+                                                         float,
                                                          AElementOp,
                                                          BElementOp,
                                                          CElementOp>;
...
@@ -155,6 +155,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
         ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                          BDataType,
                                                          CDataType,
+                                                         float,
                                                          AElementOp,
                                                          BElementOp,
                                                          CElementOp>;
...
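
Both profiler call sites pin the new parameter to float, so verification numerics are unchanged from the old hard-coded behavior. Callers opt in the same way, as in the hunks above:

using ReferenceBatchedGemmInstance =
    ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     float, // AccDataType: keep the previous behavior
                                                     AElementOp,
                                                     BElementOp,
                                                     CElementOp>;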