Commit 4fec5ad3 authored by aska-0096

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into wmma_op

parents 24faa1fc 87fd1152
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 3, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<4, 3, UnarySquare, UnarySqrt>>&);
template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 4, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<4, 4, UnarySquare, UnarySqrt>>&);
template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 1, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<4, 1, UnarySquare, UnarySqrt>>&);
template void add_device_reduce_instance_threadwise<F64, F64, F64, 2, 1, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<2, 1, UnarySquare, UnarySqrt>>&);
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
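
The four instantiations above are the double-precision NORM2 instances: NORM2 is expressed as a ReduceAdd reduction with UnarySquare applied to each input element and UnarySqrt applied to the final accumulator, i.e. sqrt(sum(x*x)). A minimal standalone sketch of that composition (illustrative only, not CK code):

#include <cmath>
#include <vector>

// NORM2 composed the same way as the instances above:
// InElementwiseOp = UnarySquare, ReduceOperation = ReduceAdd, AccElementwiseOp = UnarySqrt
double norm2_reference(const std::vector<double>& x)
{
    double acc = 0.0; // AccDataType = F64
    for(double v : x)
        acc += v * v; // square each element, then accumulate with add
    return std::sqrt(acc); // sqrt applied to the accumulated value
}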
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+#include "ck/utility/reduction_enums.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
namespace ck {
@@ -9,15 +10,11 @@ namespace device {
namespace instance {
// clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
-ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
-ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
-ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
-ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
-ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
-ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
-ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
+// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
+template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 3, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
+template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 4, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
+template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
+template void add_device_reduce_instance_threadwise<I8, I32, I8, 2, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
// clang-format on
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 3, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, UnaryDivide>>&);
template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 4, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, UnaryDivide>>&);
template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, UnaryDivide>>&);
template void add_device_reduce_instance_threadwise<I8, I32, I8, 2, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, UnaryDivide>>&);
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
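
These are the int8 AVG instances: the average is expressed as a ReduceAdd reduction into an int32 accumulator, with UnaryDivide applied to the accumulator to divide by the reduce length. A standalone sketch of that composition (illustrative only, not CK code):

#include <cstdint>
#include <vector>

// AVG composed the same way as the instances above:
// ReduceOperation = ReduceAdd into I32, AccElementwiseOp = UnaryDivide(reduce_length)
int8_t average_reference(const std::vector<int8_t>& x) // assumes x is non-empty
{
    int32_t acc = 0; // AccDataType = I32 avoids int8 overflow while summing
    for(int8_t v : x)
        acc += v;
    return static_cast<int8_t>(acc / static_cast<int32_t>(x.size())); // divide by reduce length
}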
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, UnaryAbs, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, UnaryAbs, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, UnaryAbs, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, UnaryAbs, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 3, UnaryAbs, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 4, UnaryAbs, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 1, UnaryAbs, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<2, 1, UnaryAbs, PassThrough>>&);
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
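
The UseIndex = true group in the MIN instances above (and in the MAX and AMAX files before it) makes the reduction return the index of the selected element alongside its value, which is why only MIN/MAX/AMAX expose index-enabled instances. A standalone sketch of an index-tracking MIN reduction (illustrative only, not CK code):

#include <cstdint>
#include <utility>
#include <vector>

// MIN with UseIndex = true: return the minimum value and the position where it occurs.
// Assumes x is non-empty.
std::pair<int8_t, int> min_with_index(const std::vector<int8_t>& x)
{
    int8_t best_val = x[0];
    int best_idx = 0;
    for(int i = 1; i < static_cast<int>(x.size()); ++i)
        if(x[i] < best_val)
        {
            best_val = x[i];
            best_idx = i;
        }
    return {best_val, best_idx};
}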
@@ -29,7 +29,8 @@ template <typename ADataType,
typename ALayout,
typename B0Layout,
typename B1Layout,
-typename CLayout>
+typename CLayout,
+bool MaskOutUpperTriangle>
bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
int init_method,
bool do_log,
@@ -46,16 +47,18 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
int BatchStrideA = -1,
int BatchStrideB0 = -1,
int BatchStrideB1 = -1,
-int BatchStrideC = -1)
+int BatchStrideC = -1,
+float alpha = 1.f)
{
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
+using Scale = tensor_operation::element_wise::Scale;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
-using Acc0ElementOp = PassThrough;
+using Acc0ElementOp = Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
using AccDataType = float;
@@ -67,7 +70,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
AccDataType,
AElementOp,
B0ElementOp,
-CElementOp>;
+Acc0ElementOp>;
// Ref Softmax: fp32 in, various type out
using ReferenceSoftmaxInstance =
@@ -185,7 +188,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
-auto acc0_element_op = Acc0ElementOp{};
+auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
@@ -201,7 +204,8 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
-CElementOp>;
+CElementOp,
+MaskOutUpperTriangle>;
// get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
@@ -214,10 +218,16 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{});
+a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, Scale{alpha});
ref_gemm0_invoker.Run(ref_gemm0_argument);
+// mask out upper triangle
+acc0_g_m_n.ForEach([&](auto& self, auto idx) {
+if(MaskOutUpperTriangle && idx[1] < idx[2])
+self(idx) = -ck::NumericLimits<float>::Infinity();
+});
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
...
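
The masking block added above writes negative infinity into the upper-triangle positions of the Gemm0 accumulator before the reference softmax runs. Since exp(-inf) is 0, masked positions receive zero weight and the softmax renormalizes over the remaining lower-triangle entries. A standalone sketch of that behaviour for a single row (illustrative only, not the CK reference implementation):

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Mask columns above the diagonal with -inf, then apply a numerically stable softmax.
// The diagonal stays visible, so every row has at least one unmasked entry. Assumes row is non-empty.
std::vector<float> masked_softmax_row(std::vector<float> row, int row_idx)
{
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for(int col = 0; col < static_cast<int>(row.size()); ++col)
        if(row_idx < col) // same condition as idx[1] < idx[2] above
            row[col] = neg_inf;

    const float max_val = *std::max_element(row.begin(), row.end());

    float sum = 0.f;
    for(float& v : row)
    {
        v = std::exp(v - max_val); // exp(-inf) == 0 for masked entries
        sum += v;
    }
    for(float& v : row)
        v /= sum;

    return row;
}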
@@ -7,10 +7,10 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp"
+#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -22,36 +22,32 @@
namespace ck {
namespace profiler {
-template <typename ADataType,
+template <index_t NumDimG,
+index_t NumDimM,
+index_t NumDimN,
+index_t NumDimK,
+index_t NumDimO,
+typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
-typename ALayout,
-typename B0Layout,
-typename B1Layout,
-typename CPermuteNumDims_G_M_O>
-bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verification,
-int init_method,
-bool do_log,
-bool time_kernel,
-int M,
-int N,
-int K,
-int O,
-int G0,
-int G1,
-int StrideA = -1,
-int StrideB0 = -1,
-int StrideB1 = -1,
-int BatchStrideA = -1,
-int BatchStrideB0 = -1,
-int BatchStrideB1 = -1,
-float alpha = 1.f)
+typename Acc0BiasesDataType,
+typename Acc1BiasesDataType,
+tensor_operation::device::MaskingSpecialization MaskingSpec>
+bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
+int init_method,
+bool do_log,
+bool time_kernel,
+int M,
+int N,
+int K,
+int O,
+int G0,
+int G1,
+float alpha = 1.f)
{
-using Row = tensor_layout::gemm::RowMajor;
-using Col = tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using Scale = tensor_operation::element_wise::Scale;
using AElementOp = PassThrough;
@@ -60,6 +56,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
using AccDataType = float;
+using tensor_operation::device::MaskingSpecialization;
// Ref Gemm0: various type in, fp32 out
using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<ADataType,
@@ -85,67 +82,33 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
bool pass = true;
-std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
-std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
-const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
-const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
-const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
-StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
-StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
-StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
-const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
-const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
-const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
-BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
-BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
-BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
+// A layout [G0, M, G1, K]
+std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
+// B0 layout [G0, N, G1, K]
+std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
+// B1 layout [G0, N, G1, O]
+std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
+// C layout [G0, M, G1, O]
+std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
const int BatchCount = G0 * G1;
-auto f_host_tensor_descriptor = [](std::size_t batch_count,
-std::size_t row,
-std::size_t col,
-std::size_t stride,
-std::size_t batch_stride,
-auto layout) {
-if(std::is_same<decltype(layout), Row>::value)
-{
-return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-std::vector<std::size_t>({batch_stride, stride, 1}));
-}
-else
-{
-return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-std::vector<std::size_t>({batch_stride, 1, stride}));
-}
-};
-// C_m_o = A_m_k * B0_k_n * B1_n_o
-Tensor<ADataType> a_g_m_k(
-f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
-Tensor<B0DataType> b0_g_k_n(
-f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
-Tensor<B1DataType> b1_g_n_o(
-f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
-Tensor<CDataType> c_gs_ms_os_host_result(
-std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
-std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
-Tensor<CDataType> c_gs_ms_os_device_result(
-std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
-std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
-// Host verification: Output of Gemm0 is input A of Gemm1
-Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
-Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
-Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
-std::vector<int>{M * O, O, 1});
-std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
-std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
-std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
+Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
+Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
+Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
+Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
+std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
+std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
std::srand(1); // work around test flakiness
@@ -157,38 +120,38 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
// or not. May want to try exact same approach as the GPU kernel in the host reference
// GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away. Until then,
// shrink the input value range as it is less likely to produce errors of around ~1e-3.
-// a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
-// b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
-// b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
-a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
-b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
-b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
+// a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+// b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
+// b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
+a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
+b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 2:
-a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
-b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
-b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
+a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
+b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
-a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
-b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
-b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
+a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
+b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
+b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
default:
-a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
-b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
+a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
+b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
-DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize());
-DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
-DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
-DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
+DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
+DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
+DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
+DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
-a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
-b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
-b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
+a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
+b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
+b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
@@ -196,20 +159,23 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
-using DeviceOp =
-tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
-B0Layout,
-B1Layout,
-CPermuteNumDims_G_M_O,
+using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
+1,
+1,
+1,
+1,
ADataType,
B0DataType,
B1DataType,
CDataType,
-AElementOp,
-B0ElementOp,
-Acc0ElementOp,
-B1ElementOp,
-CElementOp>;
+ck::Tuple<>,
+ck::Tuple<>,
+AElementOp,
+B0ElementOp,
+Acc0ElementOp,
+B1ElementOp,
+CElementOp,
+MaskingSpec>;
// get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
@@ -219,6 +185,26 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
if(do_verification)
{
+c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
+Tensor<ADataType> a_g_m_k({BatchCount, M, K});
+Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
+Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
+Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
+Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
+Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
+// permute
+a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
+a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+});
+b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
+b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+});
+b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
+b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
@@ -228,7 +214,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
-if(idx[1] < idx[2])
+if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
});
@@ -265,23 +251,24 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
-static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
-static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
-static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
-static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
-M,
-N,
-K,
-O,
-BatchCount,
+static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
+static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
+static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
+static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+{}, // std::array<void*, 1> p_acc0_biases;
+{}, // std::array<void*, 1> p_acc1_biases;
+a_gs_ms_ks_lengths,
+a_gs_ms_ks_strides,
+b0_gs_ns_ks_lengths,
+b0_gs_ns_ks_strides,
+b1_gs_os_ns_lengths,
+b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
-StrideA,
-StrideB0,
-StrideB1,
-BatchStrideA,
-BatchStrideB0,
-BatchStrideB1,
+{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
+{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
+{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
+{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op,
b0_element_op,
acc0_element_op,
@@ -319,18 +306,18 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
if(do_verification)
{
-c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
+c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
pass = pass & ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData);
if(do_log)
{
-LogRangeAsType<float>(std::cout << "a_g_m_k: ", a_g_m_k.mData, ",")
+LogRangeAsType<float>(std::cout << "a_gs_ms_ks: ", a_gs_ms_ks.mData, ",")
<< std::endl;
-LogRangeAsType<float>(std::cout << "b0_g_k_n : ", b0_g_k_n.mData, ",")
+LogRangeAsType<float>(std::cout << "b0_gs_ns_ks : ", b0_gs_ns_ks.mData, ",")
<< std::endl;
-LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",")
+LogRangeAsType<float>(std::cout << "b1_gs_os_ns : ", b1_gs_os_ns.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
...
@@ -18,57 +18,61 @@ namespace tensor_operation {
namespace device {
namespace instance {
-template <int Rank, int NumReduceDim, int ReduceOpId, bool PropagateNan, bool UseIndex>
+template <index_t Rank,
+index_t NumReduceDim,
+ReduceTensorOp ReduceOpId,
+bool PropagateNan,
+bool UseIndex>
struct ReduceDescription
{
-static constexpr int Rank_ = Rank;
-static constexpr int NumReduceDim_ = NumReduceDim;
-static constexpr int ReduceOpId_ = ReduceOpId;
-static constexpr int PropagateNan_ = PropagateNan;
-static constexpr int UseIndex_ = UseIndex;
+static constexpr index_t Rank_ = Rank;
+static constexpr index_t NumReduceDim_ = NumReduceDim;
+static constexpr ReduceTensorOp ReduceOpId_ = ReduceOpId;
+static constexpr bool PropagateNan_ = PropagateNan;
+static constexpr bool UseIndex_ = UseIndex;
};
using reduce_description_instances =
-std::tuple<ReduceDescription<4, 3, 0, false, false>, // for ADD
-ReduceDescription<4, 4, 0, false, false>,
-ReduceDescription<4, 1, 0, false, false>,
-ReduceDescription<2, 1, 0, false, false>,
-ReduceDescription<4, 3, 5, false, false>, // for AVG
-ReduceDescription<4, 4, 5, false, false>,
-ReduceDescription<4, 1, 5, false, false>,
-ReduceDescription<2, 1, 5, false, false>,
-ReduceDescription<4, 3, 7, false, false>, // for NORM2
-ReduceDescription<4, 4, 7, false, false>,
-ReduceDescription<4, 1, 7, false, false>,
-ReduceDescription<2, 1, 7, false, false>,
-ReduceDescription<4, 3, 2, false, false>, // for MIN
-ReduceDescription<4, 4, 2, false, false>,
-ReduceDescription<4, 1, 2, false, false>,
-ReduceDescription<2, 1, 2, false, false>,
-ReduceDescription<4, 3, 3, false, false>, // for MAX
-ReduceDescription<4, 4, 3, false, false>,
-ReduceDescription<4, 1, 3, false, false>,
-ReduceDescription<2, 1, 3, false, false>,
-ReduceDescription<4, 3, 4, false, false>, // for AMAX
-ReduceDescription<4, 4, 4, false, false>,
-ReduceDescription<4, 1, 4, false, false>,
-ReduceDescription<2, 1, 4, false, false>,
-ReduceDescription<4, 3, 2, false, true>, // for MIN
-ReduceDescription<4, 4, 2, false, true>,
-ReduceDescription<4, 1, 2, false, true>,
-ReduceDescription<2, 1, 2, false, true>,
-ReduceDescription<4, 3, 3, false, true>, // for MAX
-ReduceDescription<4, 4, 3, false, true>,
-ReduceDescription<4, 1, 3, false, true>,
-ReduceDescription<2, 1, 3, false, true>,
-ReduceDescription<4, 3, 4, false, true>, // for AMAX
-ReduceDescription<4, 4, 4, false, true>,
-ReduceDescription<4, 1, 4, false, true>,
-ReduceDescription<2, 1, 4, false, true>>;
+std::tuple<ReduceDescription<4, 3, ReduceTensorOp::ADD, false, false>, // for ADD
+ReduceDescription<4, 4, ReduceTensorOp::ADD, false, false>,
+ReduceDescription<4, 1, ReduceTensorOp::ADD, false, false>,
+ReduceDescription<2, 1, ReduceTensorOp::ADD, false, false>,
+ReduceDescription<4, 3, ReduceTensorOp::AVG, false, false>, // for AVG
+ReduceDescription<4, 4, ReduceTensorOp::AVG, false, false>,
+ReduceDescription<4, 1, ReduceTensorOp::AVG, false, false>,
+ReduceDescription<2, 1, ReduceTensorOp::AVG, false, false>,
+ReduceDescription<4, 3, ReduceTensorOp::NORM2, false, false>, // for NORM2
+ReduceDescription<4, 4, ReduceTensorOp::NORM2, false, false>,
+ReduceDescription<4, 1, ReduceTensorOp::NORM2, false, false>,
+ReduceDescription<2, 1, ReduceTensorOp::NORM2, false, false>,
+ReduceDescription<4, 3, ReduceTensorOp::MIN, false, false>, // for MIN
+ReduceDescription<4, 4, ReduceTensorOp::MIN, false, false>,
+ReduceDescription<4, 1, ReduceTensorOp::MIN, false, false>,
+ReduceDescription<2, 1, ReduceTensorOp::MIN, false, false>,
+ReduceDescription<4, 3, ReduceTensorOp::MAX, false, false>, // for MAX
+ReduceDescription<4, 4, ReduceTensorOp::MAX, false, false>,
+ReduceDescription<4, 1, ReduceTensorOp::MAX, false, false>,
+ReduceDescription<2, 1, ReduceTensorOp::MAX, false, false>,
+ReduceDescription<4, 3, ReduceTensorOp::AMAX, false, false>, // for AMAX
+ReduceDescription<4, 4, ReduceTensorOp::AMAX, false, false>,
+ReduceDescription<4, 1, ReduceTensorOp::AMAX, false, false>,
+ReduceDescription<2, 1, ReduceTensorOp::AMAX, false, false>,
+ReduceDescription<4, 3, ReduceTensorOp::MIN, false, true>, // for MIN
+ReduceDescription<4, 4, ReduceTensorOp::MIN, false, true>,
+ReduceDescription<4, 1, ReduceTensorOp::MIN, false, true>,
+ReduceDescription<2, 1, ReduceTensorOp::MIN, false, true>,
+ReduceDescription<4, 3, ReduceTensorOp::MAX, false, true>, // for MAX
+ReduceDescription<4, 4, ReduceTensorOp::MAX, false, true>,
+ReduceDescription<4, 1, ReduceTensorOp::MAX, false, true>,
+ReduceDescription<2, 1, ReduceTensorOp::MAX, false, true>,
+ReduceDescription<4, 3, ReduceTensorOp::AMAX, false, true>, // for AMAX
+ReduceDescription<4, 4, ReduceTensorOp::AMAX, false, true>,
+ReduceDescription<4, 1, ReduceTensorOp::AMAX, false, true>,
+ReduceDescription<2, 1, ReduceTensorOp::AMAX, false, true>>;
template <typename DescriptionType>
bool description_match(const DescriptionType& description,
@@ -78,9 +82,8 @@ bool description_match(const DescriptionType& description,
bool PropagateNan,
bool UseIndex)
{
-if(description.Rank_ != Rank || description.ReduceOpId_ != static_cast<int>(ReduceOpId) ||
-description.PropagateNan_ != static_cast<int>(PropagateNan) ||
-description.UseIndex_ != static_cast<int>(UseIndex))
+if(description.Rank_ != Rank || description.ReduceOpId_ != ReduceOpId ||
+description.PropagateNan_ != PropagateNan || description.UseIndex_ != UseIndex)
return (false);
if(DescriptionType::NumReduceDim_ != reduceDims.size())
@@ -99,11 +102,10 @@ bool description_match(const DescriptionType& description,
namespace ck {
namespace profiler {
-template <index_t Rank, index_t NumReduceDim>
-static inline std::vector<int> get_invariant_dims(const std::vector<int>& reduceDims)
+template <int Rank, int NumReduceDim>
+static inline std::array<int, Rank - NumReduceDim>
+get_invariant_dims(const std::array<int, NumReduceDim>& reduceDims)
{
-assert(NumReduceDim == reduceDims.size());
int reduceFlag = 0;
// flag the bits for the reduceDims
@@ -112,13 +114,15 @@ static inline std::vector<int> get_invariant_dims(const std::vector<int>& reduce
reduceFlag |= 1 << reduceDims[i];
};
-std::vector<int> invariantDims;
+std::array<int, Rank - NumReduceDim> invariantDims;
// collect invariant dimensions
+int dim = 0;
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) == 0)
{
-invariantDims.push_back(i);
+invariantDims[dim] = i;
+dim++;
};
return invariantDims;
@@ -137,7 +141,7 @@ bool profile_reduce_impl_impl(bool do_verification,
bool do_dumpout,
bool time_kernel,
const std::vector<size_t>& inLengths,
-const std::vector<int>& reduceDims,
+const std::array<int, NumReduceDim>& reduceDims,
float alpha,
float beta)
{
@@ -145,6 +149,8 @@ bool profile_reduce_impl_impl(bool do_verification,
using namespace ck::tensor_operation::device::instance;
using ck::host_common::dumpBufferToFile;
+constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
constexpr bool op_support_indices =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
ReduceOpId == ReduceTensorOp::AMAX);
@@ -279,28 +285,32 @@ bool profile_reduce_impl_impl(bool do_verification,
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
static_cast<int32_t>(reduce_total_length));
-using DeviceReduceInstPtr0 =
-DeviceReducePtr<InElementwiseOperation, AccElementwiseOperation>;
-std::vector<DeviceReduceInstPtr0> reduce0_ptrs;
+using DeviceReduceInstPtr =
+DeviceReducePtr<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>;
+std::vector<DeviceReduceInstPtr> reduce_ptrs;
add_device_reduce_instance_threadwise<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
-ReduceOpId,
+ReduceOperation,
+InElementwiseOperation,
+AccElementwiseOperation,
PropagateNan,
-UseIndex>(reduce0_ptrs);
+UseIndex>(reduce_ptrs);
add_device_reduce_instance_blockwise<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
-ReduceOpId,
+ReduceOperation,
+InElementwiseOperation,
+AccElementwiseOperation,
PropagateNan,
-UseIndex>(reduce0_ptrs);
+UseIndex>(reduce_ptrs);
if constexpr(use_atomic_add)
{
@@ -309,12 +319,14 @@ bool profile_reduce_impl_impl(bool do_verification,
OutDataType,
Rank,
NumReduceDim,
-ReduceOpId,
+ReduceOperation,
+InElementwiseOperation,
+AccElementwiseOperation,
PropagateNan,
-UseIndex>(reduce0_ptrs);
+UseIndex>(reduce_ptrs);
}
-if(reduce0_ptrs.empty())
+if(reduce_ptrs.empty())
{
throw std::runtime_error("Wrong! No device REDUCE instance found");
};
@@ -342,22 +354,22 @@ bool profile_reduce_impl_impl(bool do_verification,
acc_elementwise_op);
};
-std::vector<ck::index_t> i_inLengths;
-std::vector<ck::index_t> i_inStrides;
-std::vector<ck::index_t> i_outLengths;
-std::vector<ck::index_t> i_outStrides;
-i_inLengths.assign(inLengths.begin(), inLengths.end());
-i_inStrides.assign(inStrides.begin(), inStrides.end());
-i_outLengths.assign(outLengths.begin(), outLengths.end());
-i_outStrides.assign(outStrides.begin(), outStrides.end());
+std::array<index_t, Rank> arrInLengths;
+std::array<index_t, Rank> arrInStrides;
+std::array<index_t, NumOutDim> arrOutLengths;
+std::array<index_t, NumOutDim> arrOutStrides;
+std::copy(inLengths.begin(), inLengths.end(), arrInLengths.begin());
+std::copy(inStrides.begin(), inStrides.end(), arrInStrides.begin());
+std::copy(outLengths.begin(), outLengths.end(), arrOutLengths.begin());
+std::copy(outStrides.begin(), outStrides.end(), arrOutStrides.begin());
-for(auto& reduce_ptr : reduce0_ptrs)
+for(auto& reduce_ptr : reduce_ptrs)
{
-auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths,
-i_inStrides,
-i_outLengths,
-i_outStrides,
+auto argument_ptr = reduce_ptr->MakeArgumentPointer(arrInLengths,
+arrInStrides,
+arrOutLengths,
+arrOutStrides,
reduceDims,
alpha,
beta,
@@ -478,22 +490,25 @@ bool profile_reduce_impl(bool do_verification,
descType{}, inLengths.size(), reduceDims, ReduceOpId, PropagateNan, UseIndex))
return;
-pass = pass &&
-profile_reduce_impl_impl<InDataType,
-AccDataType,
-OutDataType,
-descType::Rank_,
-descType::NumReduceDim_,
-static_cast<ReduceTensorOp>(descType::ReduceOpId_),
-static_cast<bool>(descType::PropagateNan_),
-static_cast<bool>(descType::UseIndex_)>(do_verification,
-init_method,
-do_dumpout,
-time_kernel,
-inLengths,
-reduceDims,
-alpha,
-beta);
+std::array<ck::index_t, descType::NumReduceDim_> arrReduceDims;
+std::copy(reduceDims.begin(), reduceDims.end(), arrReduceDims.begin());
+pass = pass && profile_reduce_impl_impl<InDataType,
+AccDataType,
+OutDataType,
+descType::Rank_,
+descType::NumReduceDim_,
+static_cast<ReduceTensorOp>(descType::ReduceOpId_),
+descType::PropagateNan_,
+descType::UseIndex_>(do_verification,
+init_method,
+do_dumpout,
+time_kernel,
+inLengths,
+arrReduceDims,
+alpha,
+beta);
matched = true;
});
...
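
The reworked get_invariant_dims above keeps the same bit-flag logic while switching from std::vector to fixed-size std::array: every reduced dimension sets one bit in reduceFlag, and each dimension whose bit stays clear is collected, in order, as an invariant (kept) dimension. A small standalone example of that logic (the names below are local to this sketch):

#include <array>
#include <cstdio>

// Same bit-flag scheme as get_invariant_dims in the diff above, in standalone form.
template <int Rank, int NumReduceDim>
std::array<int, Rank - NumReduceDim> invariant_dims(const std::array<int, NumReduceDim>& reduceDims)
{
    int reduceFlag = 0;
    for(int d : reduceDims)
        reduceFlag |= 1 << d; // flag the bits for the reduced dimensions

    std::array<int, Rank - NumReduceDim> invariantDims{};
    int dim = 0;
    for(int i = 0; i < Rank; ++i)
        if((reduceFlag & (1 << i)) == 0)
            invariantDims[dim++] = i; // keep dimensions that are not reduced
    return invariantDims;
}

int main()
{
    // Rank = 4, reduceDims = {1, 2}  ->  invariant dimensions {0, 3}
    const auto inv = invariant_dims<4, 2>({1, 2});
    std::printf("%d %d\n", inv[0], inv[1]); // prints: 0 3
}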
@@ -81,7 +81,7 @@ def parse_logfile(logfile):
StrideA=[]
StrideB=[]
StrideC=[]
-if 'perf_gemm' in logfile:
+if 'perf_gemm.log' in logfile:
for line in open(logfile):
if 'Best Perf' in line:
lst=line.split()
@@ -120,14 +120,14 @@ def parse_logfile(logfile):
res = [x for _,x in sorted(zip(tests,tflops))]
#sorted_kernels = [x for _,x in sorted(zip(tests,kernels))]
test_list=list(range(1,len(tests)+1))
-#parse conv_fwd performance tests:
-elif 'conv_fwd' in logfile:
+#parse conv_fwd and conv_bwd performance tests:
+elif 'conv_fwd' in logfile or 'conv_bwd_data' in logfile:
for line in open(logfile):
if 'tflops:' in line:
lst=line.split()
res.append(lst[1])
#parse all other performance tests:
-elif 'resnet50' in logfile or 'batched_gemm' in logfile or 'grouped_gemm' in logfile or 'conv_bwd_data' in logfile or 'gemm_bilinear' in logfile or 'reduction' in logfile:
+elif 'resnet50' in logfile or 'batched_gemm' in logfile or 'grouped_gemm' in logfile or 'gemm_bilinear' in logfile or 'reduction' in logfile:
for line in open(logfile):
if 'Best Perf' in line:
lst=line.split()
@@ -149,7 +149,7 @@ def store_new_test_result(table_name, test_results, testlist, branch_name, node_
df=pd.DataFrame(data=[params],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Environment','Datetime'])
df_add=pd.DataFrame(data=[test_results],columns=testlist)
df=pd.concat([df,df_add],axis=1)
-print("new test results dataframe:",df)
+#print("new test results dataframe:",df)
df.to_sql(table_name,connection,if_exists='append',index=False)
return 0
@@ -165,7 +165,7 @@ def compare_test_to_baseline(baseline,test,testlist):
print("test # ",i,"shows regression by {:.3f}%".format(
(float(test[i])-base_list[i])/base_list[i]*100))
regression=1
-ave_perf=ave_perf+float(test[i])/base_list[i]
+if base_list[i]>0: ave_perf=ave_perf+float(test[i])/base_list[i]
if regression==0:
print("no regressions found")
ave_perf=ave_perf/len(base_list)
@@ -248,7 +248,7 @@ def main():
conn = sqlEngine.connect()
#save gemm performance tests:
-if 'perf_gemm' in filename:
+if 'perf_gemm.log' in filename:
#write the ck_gemm_test_params table only needed once the test set changes
#post_test_params(test_list,conn)
for i in range(1,len(results)+1):
...
@@ -41,7 +41,7 @@ add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
-add_subdirectory(batched_gemm_masking_scale_softmax_gemm_permute)
+add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)
add_subdirectory(convnd_fwd)
...
@@ -2,3 +2,14 @@ add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
target_link_libraries(test_batched_gemm_fp16 PRIVATE utility)
target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance)
+add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
+target_link_libraries(test_batched_gemm_fp32 PRIVATE utility)
+target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance)
+add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
+target_link_libraries(test_batched_gemm_bf16 PRIVATE utility)
+target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance)
+add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
+target_link_libraries(test_batched_gemm_int8 PRIVATE utility)
+target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/include/profile_batched_gemm_impl.hpp"
namespace {
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM bf16: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/include/profile_batched_gemm_impl.hpp"
namespace {
using ADataType = float;
using BDataType = float;
using CDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM fp32: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/include/profile_batched_gemm_impl.hpp"
namespace {
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
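// All four A/B layout combinations with a row-major C; strides and batch strides assume
// densely packed tensors.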
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM int8: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
add_custom_target(test_batched_gemm_masking_scale_softmax_gemm_permute)
add_gtest_executable(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp)
target_link_libraries(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_masking_scale_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_masking_scale_softmax_gemm_permute_fp16)
@@ -9,9 +9,13 @@ class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm<Tuple>
 {
 };
 
+using Masked = std::true_type;
+using NoMask = std::false_type;
+
 // clang-format off
 using KernelTypes = ::testing::Types<
-    std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>
+    std::tuple<F16, F16, F16, F16, Row, Col, Row, Row, NoMask>,
+    std::tuple<F16, F16, F16, F16, Row, Col, Row, Row, Masked>
     >;
 // clang-format on
@@ -120,7 +124,6 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK)
 using ck::tensor_operation::device::GemmSpecialization;
 
-// TODO: enable KPadding tests when it is implemented
 TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
 {
     int P = 120; // requires padding
@@ -152,12 +155,12 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch)
     // IsSupported(M, N, K, O)
     // clang-format off
     EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
-    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
+    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
     // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
-    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
-    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
+    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
+    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
     // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
-    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
+    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
     // clang-format on
 }
@@ -169,6 +172,5 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, AdhocTest)
         {1020, 1020, 64, 128, 24},
         {576, 576, 64, 64, 24},
     };
-    this->bench_ = true;
     this->Run();
 }
@@ -20,14 +20,15 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 template <typename Tuple>
 struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
 {
     using ADataType   = std::tuple_element_t<0, Tuple>;
     using B0DataType  = std::tuple_element_t<1, Tuple>;
     using B1DataType  = std::tuple_element_t<2, Tuple>;
     using CDataType   = std::tuple_element_t<3, Tuple>;
     using ALayout     = std::tuple_element_t<4, Tuple>;
     using B0Layout    = std::tuple_element_t<5, Tuple>;
     using B1Layout    = std::tuple_element_t<6, Tuple>;
     using CLayout     = std::tuple_element_t<7, Tuple>;
+    using MaskingType = std::tuple_element_t<8, Tuple>;
 
     std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
                                               {256, 256, 128, 128, 4},
@@ -54,7 +55,8 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
                                                                    ALayout,
                                                                    B0Layout,
                                                                    B1Layout,
-                                                                   CLayout>(
+                                                                   CLayout,
+                                                                   MaskingType::value>(
             verify_, 1, false, bench_, M, N, K, O, BatchCount);
         EXPECT_TRUE(pass);
add_custom_target(test_batched_gemm_softmax_gemm_permute)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)