Unverified Commit 19a08d65 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-471

parents 2056491f afdfef74
......@@ -15,14 +15,14 @@ namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 3, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 4, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 1, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 2, 1, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 3, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 4, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 1, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 2, 1, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 3, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<F64, F64, F64, 4, 3, ReduceMin, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 4, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<F64, F64, F64, 4, 4, ReduceMin, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 1, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<F64, F64, F64, 4, 1, ReduceMin, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 2, 1, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<F64, F64, F64, 2, 1, ReduceMin, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 3, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<F64, F64, F64, 4, 3, ReduceMin, PassThrough, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 4, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<F64, F64, F64, 4, 4, ReduceMin, PassThrough, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 1, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<F64, F64, F64, 4, 1, ReduceMin, PassThrough, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 2, 1, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<F64, F64, F64, 2, 1, ReduceMin, PassThrough, PassThrough, false, true>>&);
// clang-format on
} // namespace instance
......
......@@ -15,10 +15,10 @@ namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 3, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<4, 3, UnarySquare, UnarySqrt>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 4, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<4, 4, UnarySquare, UnarySqrt>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 1, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<4, 1, UnarySquare, UnarySqrt>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 2, 1, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<2, 1, UnarySquare, UnarySqrt>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 3, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<F64, F64, F64, 4, 3, ReduceAdd, UnarySquare, UnarySqrt, false, false>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 4, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<F64, F64, F64, 4, 4, ReduceAdd, UnarySquare, UnarySqrt, false, false>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 4, 1, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<F64, F64, F64, 4, 1, ReduceAdd, UnarySquare, UnarySqrt, false, false>>&);
extern template void add_device_reduce_instance_threadwise<F64, F64, F64, 2, 1, ReduceAdd, UnarySquare, UnarySqrt, false, false>(std::vector<DeviceReducePtr<F64, F64, F64, 2, 1, ReduceAdd, UnarySquare, UnarySqrt, false, false>>&);
// clang-format on
} // namespace instance
......
......@@ -15,10 +15,10 @@ namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 3, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 4, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 2, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 3, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I32, I8, 4, 3, ReduceAdd, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 4, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I32, I8, 4, 4, ReduceAdd, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I32, I8, 4, 1, ReduceAdd, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 2, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I32, I8, 2, 1, ReduceAdd, PassThrough, PassThrough, false, false>>&);
// clang-format on
} // namespace instance
......
......@@ -15,10 +15,10 @@ namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 3, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, UnaryDivide>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 4, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, UnaryDivide>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, UnaryDivide>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 2, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, UnaryDivide>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 3, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<I8, I32, I8, 4, 3, ReduceAdd, PassThrough, UnaryDivide, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 4, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<I8, I32, I8, 4, 4, ReduceAdd, PassThrough, UnaryDivide, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 4, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<I8, I32, I8, 4, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I32, I8, 2, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<I8, I32, I8, 2, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>>&);
// clang-format on
} // namespace instance
......
......@@ -15,14 +15,14 @@ namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, UnaryAbs, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, UnaryAbs, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, UnaryAbs, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, UnaryAbs, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 3, UnaryAbs, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 4, UnaryAbs, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 1, UnaryAbs, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<2, 1, UnaryAbs, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
// clang-format on
} // namespace instance
......
......@@ -15,14 +15,14 @@ namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 3, ReduceMax, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 4, ReduceMax, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 1, ReduceMax, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMax, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 2, 1, ReduceMax, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 3, ReduceMax, PassThrough, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 4, ReduceMax, PassThrough, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 1, ReduceMax, PassThrough, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMax, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 2, 1, ReduceMax, PassThrough, PassThrough, false, true>>&);
// clang-format on
} // namespace instance
......
......@@ -15,14 +15,14 @@ namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 3, ReduceMin, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 4, ReduceMin, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 1, ReduceMin, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMin, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<I8, I8, I8, 2, 1, ReduceMin, PassThrough, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 3, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 3, ReduceMin, PassThrough, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 4, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 4, ReduceMin, PassThrough, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 4, 1, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 4, 1, ReduceMin, PassThrough, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_threadwise<I8, I8, I8, 2, 1, ReduceMin, PassThrough, PassThrough, false, true>(std::vector<DeviceReducePtr<I8, I8, I8, 2, 1, ReduceMin, PassThrough, PassThrough, false, true>>&);
// clang-format on
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp"
#include "ck/utility/reduction_operator.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOp,
typename AccElementwiseOp,
bool PropagateNan,
bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOp,
AccElementwiseOp,
PropagateNan,
OutputIndex>>
{
using DeviceOp = DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOp,
AccElementwiseOp,
PropagateNan,
OutputIndex>;
using DeviceOpPtr = DeviceReducePtr<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOp,
AccElementwiseOp,
PropagateNan,
OutputIndex>;
static auto GetInstances()
{
std::vector<DeviceOpPtr> op_ptrs;
constexpr bool out_support_atomic_add =
ck::reduce::InMemoryDataOperationSupportedOnDataType<
InMemoryDataOperationEnum::AtomicAdd,
OutDataType>::value;
constexpr bool op_support_atomic_add =
std::is_same<ReduceOperation, ReduceAdd>::value &&
(std::is_same<AccElementwiseOp, PassThrough>::value ||
std::is_same<AccElementwiseOp, UnaryDivide>::value);
constexpr bool use_atomic_add = (out_support_atomic_add && op_support_atomic_add);
add_device_reduce_instance_threadwise<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOp,
AccElementwiseOp,
PropagateNan,
OutputIndex>(op_ptrs);
add_device_reduce_instance_blockwise<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOp,
AccElementwiseOp,
PropagateNan,
OutputIndex>(op_ptrs);
if constexpr(use_atomic_add)
{
add_device_reduce_instance_multiblock_atomic_add<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOp,
AccElementwiseOp,
PropagateNan,
OutputIndex>(op_ptrs);
};
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <array>
#include <functional>
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
template <int NDim>
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
std::vector<std::array<size_t, NDim>>& indexes)
{
static_assert(NDim >= 1, "NDim >= 1 is required to use this function!");
if constexpr(NDim == 1)
{
for(size_t i = 0; i < dimLengths[0]; i++)
{
std::array<size_t, 1> index{i};
indexes.push_back(index);
};
}
else
{
std::array<size_t, NDim - 1> partial_dim_lengths;
for(int i = 0; i < NDim - 1; i++)
partial_dim_lengths[i] = dimLengths[i + 1];
std::vector<std::array<size_t, NDim - 1>> partial_indexes;
get_all_indexes<NDim - 1>(partial_dim_lengths, partial_indexes);
for(size_t i = 0; i < dimLengths[0]; i++)
for(const auto& index : partial_indexes)
{
std::array<size_t, NDim> extIndex;
extIndex[0] = i;
for(int k = 0; k < NDim - 1; k++)
extIndex[k + 1] = index[k];
indexes.push_back(extIndex);
};
};
};
template <int NDim>
static size_t get_offset_from_index(const std::array<size_t, NDim>& strides,
const std::array<size_t, NDim>& index)
{
size_t offset = 0;
for(int i = 0; i < NDim; i++)
offset += strides[i] * index[i];
return (offset);
};
template <int NDim>
static size_t get_offset_from_index(const std::vector<size_t>& strides,
const std::array<size_t, NDim>& index)
{
size_t offset = 0;
for(int i = 0; i < NDim; i++)
offset += strides[i] * index[i];
return (offset);
};
template <typename InDataType,
typename AccDataType,
typename OutDataType,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
int Rank,
int NumReduceDim,
bool PropagateNan,
bool OutputIndex>
struct ReductionHost
{
using IndexDataType = int32_t;
static constexpr int NumInvariantDim = Rank - NumReduceDim;
std::vector<size_t> outStrides;
IndexDataType divider;
std::array<size_t, NumReduceDim> reduceLengths;
std::array<size_t, NumReduceDim> reduceStrides;
std::array<size_t, NumInvariantDim> invariantLengths;
std::array<size_t, NumInvariantDim> invariantStrides;
std::vector<std::array<size_t, NumReduceDim>> reduce_dim_indexes;
std::vector<std::array<size_t, NumInvariantDim>> invariant_dim_indexes;
ReductionHost(HostTensorDescriptor& inDesc,
HostTensorDescriptor& outDesc,
const std::array<int, NumInvariantDim> invariantDims,
const std::array<int, NumReduceDim> reduceDims)
{
// this->outLengths = to_int_vector(outDesc.GetLengths());
this->outStrides = outDesc.GetStrides();
int product = 1;
for(int i = 0; i < NumReduceDim; i++)
{
reduceLengths[i] = inDesc.GetLengths()[reduceDims[i]];
reduceStrides[i] = inDesc.GetStrides()[reduceDims[i]];
product *= inDesc.GetLengths()[reduceDims[i]];
};
divider = product;
for(int i = 0; i < NumInvariantDim; i++)
{
invariantLengths[i] = inDesc.GetLengths()[invariantDims[i]];
invariantStrides[i] = inDesc.GetStrides()[invariantDims[i]];
};
reduce_dim_indexes.clear();
get_all_indexes<NumReduceDim>(reduceLengths, reduce_dim_indexes);
if constexpr(NumInvariantDim > 0)
{
invariant_dim_indexes.clear();
get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
};
};
void Run(float alpha,
const InDataType* in_data,
float beta,
OutDataType* out_data,
IndexDataType* out_indices,
InElementwiseOperation in_elementwise_op,
AccElementwiseOperation acc_elementwise_op)
{
if constexpr(OutputIndex)
{
RunImpl_with_index(
alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op);
}
else
{
RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op);
};
};
void RunImpl_with_index(float alpha,
const InDataType* in_data,
float beta,
OutDataType* out_data,
IndexDataType* out_indices,
InElementwiseOperation in_elementwise_op,
AccElementwiseOperation acc_elementwise_op)
{
using ck::float_equal_one;
using ck::float_equal_zero;
using ck::type_convert;
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
{
auto offset_reduce =
get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);
in_elementwise_op(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
};
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
if(!float_equal_zero{}(beta))
accuVal += type_convert<AccDataType>(out_data[0]) * type_convert<AccDataType>(beta);
out_data[0] = type_convert<OutDataType>(accuVal);
out_indices[0] = accuIndex;
}
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
{
auto offset_reduce =
get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
auto currVal =
type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);
in_elementwise_op(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
};
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
auto dst_offset =
get_offset_from_index<NumInvariantDim>(outStrides, invariant_index);
if(!float_equal_zero{}(beta))
accuVal += type_convert<AccDataType>(out_data[dst_offset]) *
type_convert<AccDataType>(beta);
out_data[dst_offset] = type_convert<OutDataType>(accuVal);
out_indices[dst_offset] = accuIndex;
};
std::size_t num_thread = 1;
std::size_t work_per_thread =
(invariant_dim_indexes.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end =
std::min((it + 1) * work_per_thread, invariant_dim_indexes.size());
auto f = [=] {
for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
{
thread_reduce_func(invariant_dim_indexes[iw]);
}
};
threads[it] = joinable_thread(f);
}
};
};
void RunImpl_no_index(float alpha,
const InDataType* in_data,
float beta,
OutDataType* out_data,
InElementwiseOperation in_elementwise_op,
AccElementwiseOperation acc_elementwise_op)
{
using ck::float_equal_one;
using ck::float_equal_zero;
using ck::type_convert;
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(const auto& reduce_index : reduce_dim_indexes)
{
auto offset_reduce =
get_offset_from_index<NumReduceDim>(reduceStrides, reduce_index);
auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
};
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
if(!float_equal_zero{}(beta))
accuVal += type_convert<AccDataType>(out_data[0]) * type_convert<AccDataType>(beta);
out_data[0] = type_convert<OutDataType>(accuVal);
}
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
for(const auto& reduce_index : reduce_dim_indexes)
{
auto offset_reduce =
get_offset_from_index<NumReduceDim>(reduceStrides, reduce_index);
auto currVal =
type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
};
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
auto dst_offset =
get_offset_from_index<NumInvariantDim>(outStrides, invariant_index);
if(!float_equal_zero{}(beta))
accuVal += type_convert<AccDataType>(out_data[dst_offset]) *
type_convert<AccDataType>(beta);
out_data[dst_offset] = type_convert<OutDataType>(accuVal);
};
std::size_t num_thread = 1;
std::size_t work_per_thread =
(invariant_dim_indexes.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end =
std::min((it + 1) * work_per_thread, invariant_dim_indexes.size());
auto f = [=] {
for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
{
thread_reduce_func(invariant_dim_indexes[iw]);
}
};
threads[it] = joinable_thread(f);
}
};
};
};
......@@ -396,7 +396,7 @@ struct Tensor
}
case 6: {
auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
(*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4, i5);
(*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
};
make_ParallelTensorFunctor(f,
mDesc.GetLengths()[0],
......
add_instance_library(device_batched_gemm_bias_permute_instance
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using F16_Tuple = ck::Tuple<F16>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
// A[g0, m0, m1, k0] * B[g0, n0, n1, n2, k0] + D[g0, m0, m1, n0, n1, n2] = E[g0, n0, m0, n0, n1, m1]
// m/n/n/n are the fast changing dimension for A/B/D/E
using device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance =
std::tuple<
// clang-format off
//############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
//M1 faster dim
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceBatchedContractionMultipleD_Xdl_CShuffle< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, ABSpec, ABSpec, DESpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
void add_device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance(
std::vector<std::unique_ptr<DeviceBatchedContractionMultipleD<1,
2,
3,
1,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,4 +7,8 @@ add_instance_library(device_batchnorm_instance
device_batchnorm_backward_f32_instance.cpp
device_batchnorm_backward_bf16_instance.cpp
device_batchnorm_backward_f64_instance.cpp
device_batchnorm_infer_f16_instance.cpp
device_batchnorm_infer_f32_instance.cpp
device_batchnorm_infer_bf16_instance.cpp
device_batchnorm_infer_f64_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using Normalize = ck::tensor_operation::element_wise::NormalizeInInfer;
// clang-format off
template <index_t Rank>
using device_batchnorm_infer_bf16_instances =
std::tuple <
// Tuple<XDataType, MeanDataType, VarDataType, ScaleDataType, BiasDataType>, Tuple<YDataType>, NormalizeOp, Rank, MPerThread, Sequence<XVectorSize, MeanDataType, VarDataType, ScaleVectorSize, BiasVectorSize>, Sequence<YVectorSize>
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 2, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 2, Sequence<2, 1, 1, 1, 1>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 2, Sequence<1, 2, 2, 2, 2>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 2, Sequence<2, 2, 2, 2, 2>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 4, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 4, Sequence<2, 1, 1, 1, 1>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 4, Sequence<1, 2, 2, 2, 2>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 4, Sequence<2, 2, 2, 2, 2>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 4, Sequence<4, 1, 1, 1, 1>, Sequence<4> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 4, Sequence<1, 4, 4, 4, 4>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 4, Sequence<4, 2, 2, 2, 2>, Sequence<4> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 4, Sequence<2, 4, 4, 4, 4>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, Rank, 4, Sequence<4, 4, 4, 4, 4>, Sequence<4> >
>;
// clang-format on
void add_device_batchnorm_infer_rank_4_bf16_instances(
std::vector<std::unique_ptr<
DeviceElementwise<Tuple<BF16, F32, F32, BF16, BF16>, Tuple<BF16>, Normalize, 4>>>&
instances)
{
add_device_operation_instances(instances, device_batchnorm_infer_bf16_instances<4>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Normalize = ck::tensor_operation::element_wise::NormalizeInInfer;
// clang-format off
template <index_t Rank>
using device_batchnorm_infer_f16_instances =
std::tuple <
// Tuple<XDataType, MeanDataType, VarDataType, ScaleDataType, BiasDataType>, Tuple<YDataType>, NormalizeOp, Rank, MPerThread, Sequence<XVectorSize, MeanDataType, VarDataType, ScaleVectorSize, BiasVectorSize>, Sequence<YVectorSize>
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 2, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 2, Sequence<2, 1, 1, 1, 1>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 2, Sequence<1, 2, 2, 2, 2>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 2, Sequence<2, 2, 2, 2, 2>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 4, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 4, Sequence<2, 1, 1, 1, 1>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 4, Sequence<1, 2, 2, 2, 2>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 4, Sequence<2, 2, 2, 2, 2>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 4, Sequence<4, 1, 1, 1, 1>, Sequence<4> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 4, Sequence<1, 4, 4, 4, 4>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 4, Sequence<4, 2, 2, 2, 2>, Sequence<4> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 4, Sequence<2, 4, 4, 4, 4>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, Rank, 4, Sequence<4, 4, 4, 4, 4>, Sequence<4> >
>;
// clang-format on
void add_device_batchnorm_infer_rank_4_f16_instances(
std::vector<std::unique_ptr<
DeviceElementwise<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 4>>>& instances)
{
add_device_operation_instances(instances, device_batchnorm_infer_f16_instances<4>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F32 = float;
using Normalize = ck::tensor_operation::element_wise::NormalizeInInfer;
// clang-format off
template <index_t Rank>
using device_batchnorm_infer_f32_instances =
std::tuple <
// Tuple<XDataType, MeanDataType, VarDataType, ScaleDataType, BiasDataType>, Tuple<YDataType>, NormalizeOp, Rank, MPerThread, Sequence<XVectorSize, MeanDataType, VarDataType, ScaleVectorSize, BiasVectorSize>, Sequence<YVectorSize>
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 2, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 2, Sequence<2, 1, 1, 1, 1>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 2, Sequence<1, 2, 2, 2, 2>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 2, Sequence<2, 2, 2, 2, 2>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 4, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 4, Sequence<2, 1, 1, 1, 1>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 4, Sequence<1, 2, 2, 2, 2>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 4, Sequence<2, 2, 2, 2, 2>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 4, Sequence<4, 1, 1, 1, 1>, Sequence<4> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 4, Sequence<1, 4, 4, 4, 4>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 4, Sequence<4, 2, 2, 2, 2>, Sequence<4> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 4, Sequence<2, 4, 4, 4, 4>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, Rank, 4, Sequence<4, 4, 4, 4, 4>, Sequence<4> >
>;
// clang-format on
void add_device_batchnorm_infer_rank_4_f32_instances(
std::vector<std::unique_ptr<
DeviceElementwise<Tuple<F32, F32, F32, F32, F32>, Tuple<F32>, Normalize, 4>>>& instances)
{
add_device_operation_instances(instances, device_batchnorm_infer_f32_instances<4>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F64 = double;
using Normalize = ck::tensor_operation::element_wise::NormalizeInInfer;
// clang-format off
template <index_t Rank>
using device_batchnorm_infer_f64_instances =
std::tuple <
// Tuple<XDataType, MeanDataType, VarDataType, ScaleDataType, BiasDataType>, Tuple<YDataType>, NormalizeOp, Rank, MPerThread, Sequence<XVectorSize, MeanDataType, VarDataType, ScaleVectorSize, BiasVectorSize>, Sequence<YVectorSize>
DeviceElementwiseImpl<Tuple<F64, F64, F64, F64, F64>, Tuple<F64>, Normalize, Rank, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F64, F64, F64, F64, F64>, Tuple<F64>, Normalize, Rank, 2, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F64, F64, F64, F64, F64>, Tuple<F64>, Normalize, Rank, 2, Sequence<2, 1, 1, 1, 1>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F64, F64, F64, F64, F64>, Tuple<F64>, Normalize, Rank, 2, Sequence<1, 2, 2, 2, 2>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F64, F64, F64, F64, F64>, Tuple<F64>, Normalize, Rank, 2, Sequence<2, 2, 2, 2, 2>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F64, F64, F64, F64, F64>, Tuple<F64>, Normalize, Rank, 4, Sequence<1, 1, 1, 1, 1>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F64, F64, F64, F64, F64>, Tuple<F64>, Normalize, Rank, 4, Sequence<2, 1, 1, 1, 1>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F64, F64, F64, F64, F64>, Tuple<F64>, Normalize, Rank, 4, Sequence<1, 2, 2, 2, 2>, Sequence<1> >,
DeviceElementwiseImpl<Tuple<F64, F64, F64, F64, F64>, Tuple<F64>, Normalize, Rank, 4, Sequence<2, 2, 2, 2, 2>, Sequence<2> >
>;
// clang-format on
void add_device_batchnorm_infer_rank_4_f64_instances(
std::vector<std::unique_ptr<
DeviceElementwise<Tuple<F64, F64, F64, F64, F64>, Tuple<F64>, Normalize, 4>>>& instances)
{
add_device_operation_instances(instances, device_batchnorm_infer_f64_instances<4>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
......@@ -28,15 +28,15 @@ using Normalize = ck::tensor_operation::element_wise::Normalize;
using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances = std::tuple<
// clang-format off
//###################|<in, mean, square_mean, gamma, beta>| <out>| functor| NDim| MPerThread| <in, mean, square_mean, gamma, beta ScalarPerVector>| <out ScalarPerVector>|
DeviceElementwise<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 8, Sequence<8, 1, 1, 8, 8>, Sequence<8> >,
DeviceElementwise<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 4, Sequence<4, 1, 1, 4, 4>, Sequence<4> >,
DeviceElementwise<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 2, Sequence<2, 1, 1, 2, 2>, Sequence<2> >,
DeviceElementwise<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> >
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 8, Sequence<8, 1, 1, 8, 8>, Sequence<8> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 4, Sequence<4, 1, 1, 4, 4>, Sequence<4> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 2, Sequence<2, 1, 1, 2, 2>, Sequence<2> >,
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> >
// clang-format on
>;
void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
std::vector<DeviceElementwiseBasePtr<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2>>&
std::vector<DeviceElementwisePtr<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2>>&
instances)
{
add_device_operation_instances(
......
add_instance_library(device_gemm_add_multiply_instance
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using F16_Tuple = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Tuple = ck::Tuple<Row, Row>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
// no padding
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// M/N/K Padding
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment