"vscode:/vscode.git/clone" did not exist on "c5877261905a36238d8162c817478cff9bc77bc4"
Unverified Commit 370efa6c authored by ltqin, committed by GitHub

batched_gemm + multiple_d + gemm + multiple_d (#394)



* refactor

* start

* add device gemm file

* add BatchStrideD0

* add StrideD0

* add gridwise file

* add d0 parameters to gridwise gemm

* add c layout transformer

* add d0 threadwise copy

* init kernel

* init kernel

* regular code

* nm desc put to out

* kernel parameters cannot be passed by reference

* host add bias+gelu

* run right for bias+gelu

* change AddFastGelu into another file

* interface add d1 bias parameters

* add d1 parameter to argument

* add d1 parameter to gridwise

* first complete code, not verified

* change gelu to relu and fix GetElementSpaceSize bug

* add instance

* start add to ckprofiler

* ckprofiler finish code

* change input parameter for ckProfiler

* fix host bias+gelu bug

* show help for ckProfiler

* fix bug where kernel launch ignored parameters

* add padding and fix related bug

* multiple d0

* add dynamic d0_element_op

* change profiler and instance to multiple d0

* example has 2 d0

* remove some unused comments

* give the 2 d0 tensors their own parameters

* change d element_op name

* change class name(multiple_d)

* fix bug

* fix bug where a file could not be found

* update profiler

* refactor

* update profiler

* clean

* revert example change

* add gon layout

* optimize parameter for gno

* add gon to gemm+gemm

* change help text for input parameters

* change to GemmPadder_v2

* use ForEach

* fix gb_per_sec
Co-authored-by: Chao Liu <lc.roy86@gmail.com>
Co-authored-by: ltqin <letaoqin@amd.com>
parent b22ebd44
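Editor's sketch (not part of the commit): a naive host reference for the fused computation this PR adds, namely Gemm0 with a fused D0 add and ReLU followed by Gemm1 with a fused D1 add, as exercised by the new profiler and instances in the diff below. A single D0 and D1 tensor and float accumulation are assumed; all names here are illustrative.

// Fused chain: E1 = Add(Relu(A0 * B0^T + D0) * B1, D1), batched over G.
#include <algorithm>
#include <cstddef>
#include <vector>

void reference_batched_gemm_add_relu_gemm_add(
    int G, int M, int N, int K, int O,
    const std::vector<float>& a0, // [G, M, K]
    const std::vector<float>& b0, // [G, N, K]
    const std::vector<float>& d0, // [G, M, N]
    const std::vector<float>& b1, // [G, N, O]
    const std::vector<float>& d1, // [G, M, O]
    std::vector<float>& e1)       // [G, M, O]
{
    std::vector<float> e0(std::size_t(G) * M * N);
    for(int g = 0; g < G; ++g)
        for(int m = 0; m < M; ++m)
        {
            // Gemm0 + CDE0 (AddRelu): E0 = Relu(A0 * B0^T + D0)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += a0[(std::size_t(g) * M + m) * K + k] *
                           b0[(std::size_t(g) * N + n) * K + k];
                acc += d0[(std::size_t(g) * M + m) * N + n];
                e0[(std::size_t(g) * M + m) * N + n] = std::max(acc, 0.f);
            }
            // Gemm1 + CDE1 (Add): E1 = E0 * B1 + D1
            for(int o = 0; o < O; ++o)
            {
                float acc = 0.f;
                for(int n = 0; n < N; ++n)
                    acc += e0[(std::size_t(g) * M + m) * N + n] *
                           b1[(std::size_t(g) * N + n) * O + o];
                e1[(std::size_t(g) * M + m) * O + o] = acc + d1[(std::size_t(g) * M + m) * O + o];
            }
        }
}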
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
@@ -52,4 +52,5 @@ add_subdirectory(33_multiple_reduce)
add_subdirectory(34_batchnorm)
add_subdirectory(35_splitK_gemm)
add_subdirectory(36_sparse_embedding)
add_subdirectory(37_batched_gemm_add_add_relu_gemm_add)
add_subdirectory(41_grouped_conv_conv_fwd)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename D0sDataType,
typename B1DataType,
typename D1sDataType,
typename E1DataType,
typename A0ElementwiseOperation,
typename B0ElementwiseOperation,
typename CDE0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation>
struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
{
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
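// Editor's note (not part of the commit): logical problem dimensions, inferred from
// the profiler code added later in this diff.
//   Gemm0: A0 (M x K) * B0 (K x N), combined with NumD0Tensor D0 tensors (M x N)
//          through CDE0ElementwiseOperation.
//   Gemm1: Gemm0 output (M x N) * B1 (N x O), combined with NumD1Tensor D1 tensors
//          (M x O) through CDE1ElementwiseOperation, producing E1 (M x O).
//   Every tensor carries a leading Batch dimension addressed via the BatchStride* arguments.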
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a0,
const void* p_b0,
std::array<const void*, NumD0Tensor> p_d0s,
const void* p_b1,
std::array<const void*, NumD1Tensor> p_d1s,
void* p_e1,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
ck::index_t StrideA0,
ck::index_t StrideB0,
std::array<ck::index_t, NumD0Tensor> StrideD0s,
ck::index_t StrideB1,
std::array<ck::index_t, NumD1Tensor> StrideD1s,
ck::index_t StrideE1,
ck::index_t BatchStrideA0,
ck::index_t BatchStrideB0,
std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
ck::index_t BatchStrideB1,
std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
ck::index_t BatchStrideE1,
A0ElementwiseOperation a0_element_op,
B0ElementwiseOperation b0_element_op,
CDE0ElementwiseOperation cde0_element_op,
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -218,6 +218,165 @@ struct GemmPadder_v2
KPerTileType KPerTile_;
};
// M/N/KPerTileType could be index_t or Number<>
template <bool PadM,
bool PadN,
bool PadK,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType>
struct MatrixPadder_v2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
template <typename ADesc_MRaw_KRaw>
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(PadM && PadK)
{
// pad both M and K
return transform_tensor_descriptor(a_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(PadM && (!PadK))
{
// pad M, but not K
return transform_tensor_descriptor(
a_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr((!PadM) && PadK)
{
// pad K, but not M
return transform_tensor_descriptor(
a_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or K
return a_desc_mraw_kraw;
}
}
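// Hedged worked example (editor's addition, not part of the commit), assuming
// MPerTile_ = 128 and KPerTile_ = 64 with a raw A descriptor of 1000 x 60:
//   M = integer_divide_ceil(1000, 128) * 128 = 1024, MPad = 24
//   K = integer_divide_ceil(  60,  64) *  64 =   64, KPad =  4
// With PadM = PadK = true the descriptor is right-padded to 1024 x 64; with both
// false it is returned unchanged.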
template <typename BDesc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(PadN && PadK)
{
// pad both N and K
return transform_tensor_descriptor(b_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(PadN && (!PadK))
{
// pad N, but not K
return transform_tensor_descriptor(
b_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr((!PadN) && PadK)
{
// pad K, but not N
return transform_tensor_descriptor(
b_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad N or K
return b_desc_nraw_kraw;
}
}
template <typename CDesc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(PadM && PadN)
{
// pad M and N
return transform_tensor_descriptor(c_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(PadM && (!PadN))
{
// pad M, but not N
return transform_tensor_descriptor(
c_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr((!PadM) && PadN)
{
// pad N, but not M
return transform_tensor_descriptor(
c_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_desc_mraw_nraw;
}
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -28,6 +28,13 @@ struct Add
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const half_t& x1) const
{
y = x0 + type_convert<float>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
@@ -172,6 +179,14 @@ struct AddRelu
const float a = x0 + x1;
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
};
template <>
__host__ __device__ constexpr void
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
{
const float a = x0 + type_convert<float>(x1);
y = a > 0.0f ? a : 0.0f;
};
};
struct AddHardswish
@@ -210,6 +225,46 @@ struct AddHardswish
};
};
// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
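// Editor's note (not part of the commit): the exp form above is the tanh formula
// from the comment rewritten via tanh(u/2) = 2/(1 + exp(-u)) - 1. With
// u = 2*sqrt(2/pi)*(x + 0.044715*x^3) = 2*x*(0.797885 + 0.035677*x^2),
// 0.5*(1 + tanh(u/2)) = 0.5 + 0.5*(2/(1 + exp(-u)) - 1), which is exactly `cdf`.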
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D>);
const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
e = type_convert<E>(y);
}
template <typename D>
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
{
static_assert(is_valid_param_type_v<D>);
e = GetFastGeLU(c + type_convert<float>(d));
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
@@ -211,6 +211,27 @@ struct FastGelu
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
}
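// Editor's note (not part of the commit): 0.70710678118f is 1/sqrt(2), so this
// evaluates the exact erf form y = 0.5*x*(1 + erf(x/sqrt(2))) from the comment above.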
template <>
__host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y,
const ck::half_t& x) const
{
y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x))));
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Row,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>&
instances);
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Col,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>&
instances);
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename D0sDataType,
typename B1DataType,
typename D1sDataType,
typename E1DataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
A0DataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
E1DataType,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>
{
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
A0DataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
E1DataType,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<A0DataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<E1DataType, half_t>)
{
if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<E1Layout, Row>)
{
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
}
else if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Col> && is_same_v<E1Layout, Row>)
{
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -16,6 +16,7 @@ add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_add_relu_gemm_add)
add_subdirectory(grouped_gemm)
add_subdirectory(contraction_scale)
add_subdirectory(contraction_bilinear)
@@ -42,6 +43,7 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_batched_gemm_add_relu_gemm_add_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance>
$<TARGET_OBJECTS:device_contraction_scale_instance>
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Row,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
std::tuple<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Col,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -12,6 +12,8 @@ set(PROFILER_SOURCE
src/profile_gemm_add_add_fastgelu.cpp
src/profile_gemm_reduce.cpp
src/profile_batched_gemm.cpp
src/profile_batched_gemm_gemm.cpp
src/profile_batched_gemm_add_relu_gemm_add.cpp
src/profile_batched_gemm_reduce.cpp
src/profile_grouped_gemm.cpp
src/profile_conv_fwd.cpp
@@ -35,6 +37,8 @@ target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
namespace ck {
namespace profiler {
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename D0sDataType,
typename B1DataType,
typename D1sDataType,
typename E1DataType>
bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int O,
int BatchCount = 1,
int StrideA0 = -1,
int StrideB0 = -1,
int StrideD0 = -1,
int StrideB1 = -1,
int StrideD1 = -1,
int StrideE1 = -1,
int BatchStrideA0 = -1,
int BatchStrideB0 = -1,
int BatchStrideD0 = -1,
int BatchStrideB1 = -1,
int BatchStrideD1 = -1,
int BatchStrideE1 = -1)
{
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using B1ElementOp = PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
using D0DataType = remove_cvref_t<tuple_element_t<0, D0sDataType>>;
using D0Layout = remove_cvref_t<tuple_element_t<0, D0sLayout>>;
using D1DataType = remove_cvref_t<tuple_element_t<0, D1sDataType>>;
using D1Layout = remove_cvref_t<tuple_element_t<0, D1sLayout>>;
// for reference
using RefAcc0DataType = float;
using RefAcc1DataType = float;
bool pass = true;
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;
StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideD0 = (StrideD0 < 0) ? DefaultStrideD0 : StrideD0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;
const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideD0 = (ck::is_same_v<D0Layout, Col> ? N : M) * StrideD0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;
BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideD0 = BatchStrideD0 < 0 ? DefaultBatchStrideD0 : BatchStrideD0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
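// Editor's note (not part of the commit): a negative stride argument means "use the
// packed default". The leading-dimension stride is the contiguous extent of the
// layout (e.g. K for a row-major M x K matrix), and the batch stride is that stride
// times the remaining extent, i.e. one full matrix per batch.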
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
// E_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<A0DataType> a0_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<D0DataType> d0_g_m_n(
f_host_tensor_descriptor(BatchCount, M, N, StrideD0, BatchStrideD0, D0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<D1DataType> d1_g_m_o(
f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
Tensor<E1DataType> e1_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
Tensor<E1DataType> e1_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
// Host verification: Output of Gemm0 is input A of Gemm1
Tensor<RefAcc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<RefAcc0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<RefAcc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "d1_g_m_o: " << d1_g_m_o.mDesc << std::endl;
std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
d0_g_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 3});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
break;
default:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_g_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
DeviceMem d0_g_m_n_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
e1_g_m_o_device_result.mDesc.GetElementSize());
a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
d0_g_m_n_device_buf.ToDevice(d0_g_m_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
auto a0_element_op = A0ElementOp{};
auto b0_element_op = B0ElementOp{};
auto cde0_element_op = CDE0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto cde1_element_op = CDE1ElementOp{};
using DeviceOp =
tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
A0DataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
E1DataType,
A0ElementOp,
B0ElementOp,
CDE0ElementOp,
B1ElementOp,
CDE1ElementOp>;
// get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
if(do_verification)
{
// Ref Gemm0
using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<A0DataType,
B0DataType,
RefAcc0DataType,
RefAcc0DataType,
A0ElementOp,
B0ElementOp,
PassThrough>;
// Ref Gemm1
using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm<RefAcc0DataType,
B1DataType,
RefAcc1DataType,
RefAcc1DataType,
PassThrough,
B1ElementOp,
PassThrough>;
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// cde0_elementwise
e0_g_m_n.ForEach(
[&](auto&, auto idx) { cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d0_g_m_n(idx)); });
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
// cde1_elementwise
e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
});
}
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d0_g_m_n_device_buf.GetDeviceBuffer()},
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
std::array<ck::index_t, 1>{StrideD0},
StrideB1,
std::array<ck::index_t, 1>{StrideD1},
StrideE1,
BatchStrideA0,
BatchStrideB0,
std::array<ck::index_t, 1>{BatchStrideD0},
BatchStrideB1,
std::array<ck::index_t, 1>{BatchStrideD1},
BatchStrideE1,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string op_name = op_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype =
(sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D0DataType) * N +
sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O + sizeof(D1DataType) * O) *
BatchCount;
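// Editor's note (not part of the commit): ave_time is in milliseconds, so
// flop / 1e9 / ms equals flop / 1e12 / s (TFLOPS) and bytes / 1e6 / ms equals
// bytes / 1e9 / s (GB/s).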
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
pass = pass & ck::utils::check_err(e1_g_m_o_device_result.mData,
e1_g_m_o_host_result.mData);
if(do_log)
{
LogRangeAsType<float>(
std::cout << "e1_g_m_o_host_result : ", e1_g_m_o_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "e1_g_m_o_device_result : ", e1_g_m_o_device_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/include/profile_batched_gemm_add_relu_gemm_add_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
int profile_batched_gemm_add_relu_gemm_add(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_NK_MN_NO_MO_MO, // 0
MK_NK_MN_ON_MO_MO, // 1
};
enum struct GemmDataType
{
F32_F32_F32_F32_F32_F32, // 0
F16_F16_F16_F16_F16_F16, // 1
};
GemmDataType data_type = GemmDataType::F16_F16_F16_F16_F16_F16;
GemmMatrixLayout layout = GemmMatrixLayout::MK_NK_MN_NO_MO_MO;
bool do_verification = true;
int init_method = 1;
bool do_log = 0;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA0 = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideD0 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideD1 = -1;
ck::index_t StrideE1 = -1;
ck::index_t BatchStrideA0 = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideD0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideD1 = -1;
ck::index_t BatchStrideE1 = -1;
if(argc == 8)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
}
else if(argc == 13)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
M = std::stoi(argv[8]);
N = std::stoi(argv[9]);
K = std::stoi(argv[10]);
O = std::stoi(argv[11]);
BatchCount = std::stoi(argv[12]);
}
else if(argc == 25)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
M = std::stoi(argv[8]);
N = std::stoi(argv[9]);
K = std::stoi(argv[10]);
O = std::stoi(argv[11]);
BatchCount = std::stoi(argv[12]);
StrideA0 = std::stoi(argv[13]);
StrideB0 = std::stoi(argv[14]);
StrideD0 = std::stoi(argv[15]);
StrideB1 = std::stoi(argv[16]);
StrideD1 = std::stoi(argv[17]);
StrideE1 = std::stoi(argv[18]);
BatchStrideA0 = std::stoi(argv[19]);
BatchStrideB0 = std::stoi(argv[20]);
BatchStrideD0 = std::stoi(argv[21]);
BatchStrideB1 = std::stoi(argv[22]);
BatchStrideD1 = std::stoi(argv[23]);
BatchStrideE1 = std::stoi(argv[24]);
}
else
{
printf("arg1: tensor operation (batched_gemm_add_relu_gemm_add: "
"Batched_GEMM+Add+Relu+Gemm+Add)\n");
printf("arg2: data type (1: fp16)\n");
printf("arg3: matrix layout (0: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[n, o] + D1[m, o] "
"= E1[m, o]; 1: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[o, n] + D1[m, o] = "
"E1[m, o];)\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 12: M, N, K, O, Batch\n");
printf("arg13 to 18: StrideA0, StrideB0, StrideD0, StrideB1, StrideD1, StrideE1\n");
printf("arg19 to 24: BatchStrideA0, BatchStrideB0, BatchStrideD0, BatchStrideB1, "
"BatchStrideD1, BatchStrideE1 \n");
exit(1);
}
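// Editor's note (not part of the commit): a hedged example invocation matching the
// argc == 13 branch above:
//   ckProfiler batched_gemm_add_relu_gemm_add 1 0 1 1 0 1 1024 1024 64 128 4
// (fp16, layout 0, verify, integer init, no log, time kernel, then M N K O Batch).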
if(data_type == GemmDataType::F16_F16_F16_F16_F16_F16 &&
layout == GemmMatrixLayout::MK_NK_MN_NO_MO_MO)
{
ck::profiler::profile_batched_gemm_add_relu_gemm_add_impl<Row, // A0Layout,
Col, // B0Layout,
ck::Tuple<Row>, // D0sLayout,
Row, // B1Layout,
ck::Tuple<Row>, // D1sLayout,
Row, // E1Layout,
F16, // A0DataType,
F16, // B0DataType,
ck::Tuple<F16>, // D0DataType,
F16, // B1DataType,
ck::Tuple<F16>, // D1sDataType
F16> // E1DataType,
(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
StrideD0,
StrideB1,
StrideD1,
StrideE1,
BatchStrideA0,
BatchStrideB0,
BatchStrideD0,
BatchStrideB1,
BatchStrideD1,
BatchStrideE1);
}
else if(data_type == GemmDataType::F16_F16_F16_F16_F16_F16 &&
layout == GemmMatrixLayout::MK_NK_MN_ON_MO_MO)
{
ck::profiler::profile_batched_gemm_add_relu_gemm_add_impl<Row, // A0Layout,
Col, // B0Layout,
ck::Tuple<Row>, // D0sLayout,
Col, // B1Layout,
ck::Tuple<Row>, // D1sLayout,
Row, // E1Layout,
F16, // A0DataType,
F16, // B0DataType,
ck::Tuple<F16>, // D0DataType,
F16, // B1DataType,
ck::Tuple<F16>, // D1sDataType
F16> // E1DataType,
(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
StrideD0,
StrideB1,
StrideD1,
StrideE1,
BatchStrideA0,
BatchStrideB0,
BatchStrideD0,
BatchStrideB1,
BatchStrideD1,
BatchStrideE1);
}
else
{
throw std::runtime_error("wrong! this data_type & layout is not implemented");
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
int profile_batched_gemm_gemm(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_NK_NO_MO, // 0
MK_NK_ON_MO, // 1
};
enum struct GemmDataType
{
F32_F32_F32_F32, // 0
F16_F16_F16_F16, // 1
};
GemmDataType data_type = GemmDataType::F16_F16_F16_F16;
GemmMatrixLayout layout = GemmMatrixLayout::MK_NK_NO_MO;
bool do_verification = true;
int init_method = 1;
bool do_log = 0;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA0 = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideE1 = -1;
ck::index_t BatchStrideA0 = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideE1 = -1;
if(argc == 8)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
}
else if(argc == 13)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
M = std::stoi(argv[8]);
N = std::stoi(argv[9]);
K = std::stoi(argv[10]);
O = std::stoi(argv[11]);
BatchCount = std::stoi(argv[12]);
}
else if(argc == 21)
{
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
do_verification = std::stoi(argv[4]);
init_method = std::stoi(argv[5]);
do_log = std::stoi(argv[6]);
time_kernel = std::stoi(argv[7]);
M = std::stoi(argv[8]);
N = std::stoi(argv[9]);
K = std::stoi(argv[10]);
O = std::stoi(argv[11]);
BatchCount = std::stoi(argv[12]);
StrideA0 = std::stoi(argv[13]);
StrideB0 = std::stoi(argv[14]);
StrideB1 = std::stoi(argv[15]);
StrideE1 = std::stoi(argv[16]);
BatchStrideA0 = std::stoi(argv[17]);
BatchStrideB0 = std::stoi(argv[18]);
BatchStrideB1 = std::stoi(argv[19]);
BatchStrideE1 = std::stoi(argv[20]);
}
else
{
printf("arg1: tensor operation (batched_gemm_gemm: Batched_GEMM+Gemm)\n");
printf("arg2: data type (1: fp16)\n");
printf("arg3: matrix layout (0: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[n, o] + D1[m, o] "
"= E1[m, o]; 1: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[o, n] + D1[m, o] = E1[m, "
"o];)\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 12: M, N, K, O, Batch\n");
printf("arg13 to 16: StrideA0, StrideB0, StrideB1, StrideE1\n");
printf("arg17 to 20: BatchStrideA0, BatchStrideB0, BatchStrideB1, BatchStrideE1 \n");
exit(1);
}
if(data_type == GemmDataType::F16_F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_NO_MO)
{
ck::profiler::profile_batched_gemm_gemm_impl<F16, // A0DataType,
F16, // B0DataType,
F16, // B1DataType,
F16, // E1DataType,
Row, // A0Layout,
Col, // B0Layout,
Row, // B1Layout,
Row> // E1Layout,
(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
StrideB1,
StrideE1,
BatchStrideA0,
BatchStrideB0,
BatchStrideB1,
BatchStrideE1);
}
else if(data_type == GemmDataType::F16_F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_ON_MO)
{
ck::profiler::profile_batched_gemm_gemm_impl<F16, // A0DataType,
F16, // B0DataType,
F16, // B1DataType,
F16, // E1DataType,
Row, // A0Layout,
Col, // B0Layout,
Col, // B1Layout,
Row> // E1Layout,
(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
StrideB1,
StrideE1,
BatchStrideA0,
BatchStrideB0,
BatchStrideB1,
BatchStrideE1);
}
else
{
throw std::runtime_error("wrong! this data_type & layout is not implemented");
}
return 0;
}
@@ -10,6 +10,8 @@ int profile_gemm_add_add_fastgelu(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_gemm_bias_add_reduce(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_batched_gemm_gemm(int, char*[]);
int profile_batched_gemm_add_relu_gemm_add(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]);
int profile_grouped_gemm(int, char*[]);
int profile_conv_fwd(int, char*[]);
@@ -32,6 +34,8 @@ static void print_helper_message()
" gemm_reduce: GEMM+Reduce\n"
" gemm_bias_add_reduce: GEMM+Bias+Add+Reduce\n"
" batched_gemm: Batched GEMM\n"
" batched_gemm_gemm: Batched+GEMM+GEMM\n"
" batched_gemm_add_relu_gemm_add: Batched+GEMM+bias+gelu+GEMM+bias\n"
" batched_gemm_reduce: Batched GEMM+Reduce\n"
" grouped_gemm: Grouped GEMM\n"
" conv_fwd: Convolution Forward\n"
@@ -80,6 +84,14 @@ int main(int argc, char* argv[])
{
return profile_batched_gemm(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm_gemm") == 0)
{
return profile_batched_gemm_gemm(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm_add_relu_gemm_add") == 0)
{
return profile_batched_gemm_add_relu_gemm_add(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
{
return profile_batched_gemm_reduce(argc, argv);