Commit 77a60235 authored by Astha Rai's avatar Astha Rai
Browse files

implemented client ex with device_elementwise.hpp and device_elementwise_3d_impl.hpp

parent a2ddbd2b
......@@ -15,187 +15,37 @@ namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using device_transpose_f16_instances = std::tuple<
// clang-format off FOR 16, 32, 16, 32, 16
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
8,
8,
ck::Sequence<8>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
8,
8,
ck::Sequence<8>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
8,
8,
ck::Sequence<1>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
8,
8,
ck::Sequence<1>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
1,
1,
ck::Sequence<1>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
1,
1,
ck::Sequence<8>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
4,
4,
ck::Sequence<1>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
4,
4,
ck::Sequence<8>,
ck::Sequence<8>>
// FOR 16, 32, 16, 32, 16
// clang-format off
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<8>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<1>, ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 1, 1, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 1, 1, ck::Sequence<8>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 4, 4, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 4, 4, ck::Sequence<8>, ck::Sequence<8>>
// clang-format on
>;
using device_transpose_f32_instances = std::tuple<
// clang-format off // for 16, 8, 16, 32, 8 -> test with instances for fp16
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
4,
4,
ck::Sequence<1>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
4,
4,
ck::Sequence<8>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
4,
4,
ck::Sequence<8>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
8,
4,
ck::Sequence<8>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
8,
8,
ck::Sequence<8>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
8,
8,
ck::Sequence<4>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
8,
8,
ck::Sequence<4>,
ck::Sequence<4>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
8,
8,
ck::Sequence<8>,
ck::Sequence<4>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
4,
8,
ck::Sequence<8>,
ck::Sequence<8>>,
// for 16, 8, 16, 32, 8 -> test with instances for fp16
// clang-format off
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<8>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 4, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<4>, ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<8>, ck::Sequence<4>>,
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 8, ck::Sequence<8>, ck::Sequence<8>>
// clang-format on
>;
......
......@@ -6,7 +6,7 @@
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
......@@ -15,11 +15,18 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_device_transpose_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise3dImpl<F16, F16, NCDHW, 3>>>& instances);
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 5>>>&
instances);
void add_device_transpose_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise3dImpl<F32, F32, NCDHW, 3>>>& instances);
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 5>>>&
instances);
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
......@@ -27,19 +34,10 @@ template <typename InDataTypeTuple,
index_t NumDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceElementwise3dImpl<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>>
DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>>
{
using DeviceOp = DeviceElementwise3dImpl<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m, // choose how to set dims
NumDim_n,
NumDim_k,
MPerThread,
NPerThread,
KPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
using DeviceOp =
DeviceElementwise3dImpl<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>;
static auto GetInstances()
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/transpose/device_transpose_instance.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
......@@ -9,21 +11,27 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_device_transpose_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise3dImpl<F16, F16, NCDHW, 3>>>& instances)
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 5>>>&
instances)
{
#ifdef CK_ENABLE_FP16
add_device_operation_instances(instances, device_transpose_f16_instances<F16, F16, NCDHW, 3>{});
add_device_operation_instances(instances, device_transpose_f16_instances{});
#else
ignore = instances;
#endif
}
void add_device_transpose_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise3dImpl<F32, F32, NCDHW, 3>>>& instances)
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 5>>>&
instances)
{
#ifdef CK_ENABLE_FP32
add_device_operation_instances(instances, device_transpose_f32_instances<F32, F32, NCDHW, 3>{});
add_device_operation_instances(instances, device_transpose_f32_instances{});
#else
ignore = instances;
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment