Commit 77a60235 authored by Astha Rai's avatar Astha Rai
Browse files

implemented client ex with device_elementwise.hpp and device_elementwise_3d_impl.hpp

parent a2ddbd2b
...@@ -15,187 +15,37 @@ namespace instance { ...@@ -15,187 +15,37 @@ namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using device_transpose_f16_instances = std::tuple< using device_transpose_f16_instances = std::tuple<
// clang-format off FOR 16, 32, 16, 32, 16 // FOR 16, 32, 16, 32, 16
DeviceElementwise3dImpl<ck::Tuple<F16>, // clang-format off
ck::Tuple<F16>, DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<8>, ck::Sequence<8>>,
2, DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<8>, ck::Sequence<1>>,
2, DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<1>, ck::Sequence<8>>,
1, DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<1>, ck::Sequence<1>>,
8, DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 1, 1, ck::Sequence<1>, ck::Sequence<1>>,
8, DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 1, 1, ck::Sequence<8>, ck::Sequence<1>>,
8, DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 4, 4, ck::Sequence<1>, ck::Sequence<1>>,
ck::Sequence<8>, DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 4, 4, ck::Sequence<8>, ck::Sequence<8>>
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
8,
8,
ck::Sequence<8>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
8,
8,
ck::Sequence<1>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
8,
8,
ck::Sequence<1>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
1,
1,
ck::Sequence<1>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
1,
1,
ck::Sequence<8>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
4,
4,
ck::Sequence<1>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>,
ck::Tuple<F16>,
2,
2,
1,
8,
4,
4,
ck::Sequence<8>,
ck::Sequence<8>>
// clang-format on // clang-format on
>; >;
using device_transpose_f32_instances = std::tuple< using device_transpose_f32_instances = std::tuple<
// clang-format off // for 16, 8, 16, 32, 8 -> test with instances for fp16 // for 16, 8, 16, 32, 8 -> test with instances for fp16
DeviceElementwise3dImpl<ck::Tuple<F32>, // clang-format off
ck::Tuple<F32>, DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<1>, ck::Sequence<1>>,
2, DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<8>, ck::Sequence<1>>,
2, DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<8>, ck::Sequence<8>>,
1, DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 4, ck::Sequence<8>, ck::Sequence<8>>,
4, DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<8>, ck::Sequence<8>>,
4, DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<4>, ck::Sequence<8>>,
4, DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<4>, ck::Sequence<4>>,
ck::Sequence<1>, DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<8>, ck::Sequence<4>>,
ck::Sequence<1>>, DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 8, ck::Sequence<8>, ck::Sequence<8>>
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
4,
4,
ck::Sequence<8>,
ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
4,
4,
ck::Sequence<8>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
8,
4,
ck::Sequence<8>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
8,
8,
ck::Sequence<8>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
8,
8,
ck::Sequence<4>,
ck::Sequence<8>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
8,
8,
ck::Sequence<4>,
ck::Sequence<4>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
8,
8,
ck::Sequence<8>,
ck::Sequence<4>>,
DeviceElementwise3dImpl<ck::Tuple<F32>,
ck::Tuple<F32>,
2,
2,
1,
4,
4,
8,
ck::Sequence<8>,
ck::Sequence<8>>,
// clang-format on // clang-format on
>; >;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -15,11 +15,18 @@ namespace tensor_operation { ...@@ -15,11 +15,18 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_device_transpose_f16_instances( void add_device_transpose_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise3dImpl<F16, F16, NCDHW, 3>>>& instances); std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 5>>>&
instances);
void add_device_transpose_f32_instances( void add_device_transpose_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise3dImpl<F32, F32, NCDHW, 3>>>& instances); std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 5>>>&
instances);
template <typename InDataTypeTuple, template <typename InDataTypeTuple,
typename OutDataTypeTuple, typename OutDataTypeTuple,
...@@ -27,19 +34,10 @@ template <typename InDataTypeTuple, ...@@ -27,19 +34,10 @@ template <typename InDataTypeTuple,
index_t NumDim> index_t NumDim>
struct DeviceOperationInstanceFactory< struct DeviceOperationInstanceFactory<
ck::tensor_operation::device:: ck::tensor_operation::device::
DeviceElementwise3dImpl<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>> DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>>
{ {
using DeviceOp = DeviceElementwise3dImpl<InDataTypeTuple, using DeviceOp =
OutDataTypeTuple, DeviceElementwise3dImpl<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>;
ElementwiseOperation,
NumDim_m, // choose how to set dims
NumDim_n,
NumDim_k,
MPerThread,
NPerThread,
KPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
static auto GetInstances() static auto GetInstances()
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/transpose/device_transpose_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/transpose/device_transpose_instance.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -9,21 +11,27 @@ namespace tensor_operation { ...@@ -9,21 +11,27 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_device_transpose_f16_instances( void add_device_transpose_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise3dImpl<F16, F16, NCDHW, 3>>>& instances) std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 5>>>&
instances)
{ {
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
add_device_operation_instances(instances, device_transpose_f16_instances<F16, F16, NCDHW, 3>{}); add_device_operation_instances(instances, device_transpose_f16_instances{});
#else #else
ignore = instances; ignore = instances;
#endif #endif
} }
void add_device_transpose_f32_instances( void add_device_transpose_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise3dImpl<F32, F32, NCDHW, 3>>>& instances) std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 5>>>&
instances)
{ {
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
add_device_operation_instances(instances, device_transpose_f32_instances<F32, F32, NCDHW, 3>{}); add_device_operation_instances(instances, device_transpose_f32_instances{});
#else #else
ignore = instances; ignore = instances;
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment