"vscode:/vscode.git/clone" did not exist on "55028a24d3c851ab71028ed91e7fb235f51569b5"
Commit aea62819 authored by Chaitanya Inumella

Rebase branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into contraction_hipTENSOR
parents 75af5450 75ab874e
@@ -6,8 +6,8 @@
#include <iostream>
#include <sstream>
-#include "ck/device_utility/device_prop.hpp"
-#include "ck/device_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp"
...
@@ -14,8 +14,8 @@
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
-#include "ck/device_utility/device_prop.hpp"
-#include "ck/device_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
...
@@ -6,8 +6,8 @@
#include <iostream>
#include <vector>
-#include "ck/device_utility/device_prop.hpp"
-#include "ck/device_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// M/N/KPerTileType could be index_t or Number<>
template <GemmSpecialization GemmSpec,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType>
struct MatrixPadder
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
template <typename ADesc_MRaw_KRaw>
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
return transform_tensor_descriptor(a_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
return transform_tensor_descriptor(
a_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
return transform_tensor_descriptor(
a_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or K
return a_desc_mraw_kraw;
}
}
template <typename BDesc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
return transform_tensor_descriptor(b_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
return transform_tensor_descriptor(
b_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
return transform_tensor_descriptor(
b_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad N or K
return b_desc_nraw_kraw;
}
}
template <typename CDesc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_desc_mraw_nraw;
}
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
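For orientation, a minimal host-side sketch of how MatrixPadder is typically instantiated; the tile sizes and the packed descriptor below are illustrative assumptions, not part of the commit.

// Hypothetical usage: round an A[MRaw, KRaw] descriptor up to multiples of the block tile.
using ck::tensor_operation::device::GemmSpecialization;
using ck::tensor_operation::device::MatrixPadder;

constexpr ck::index_t MPerBlock = 256;
constexpr ck::index_t NPerBlock = 128;
constexpr ck::index_t KPerBlock = 32;

const auto a_desc_mraw_kraw =
    ck::make_naive_tensor_descriptor_packed(ck::make_tuple(1000, 70));

const MatrixPadder<GemmSpecialization::MNKPadding, ck::index_t, ck::index_t, ck::index_t>
    padder{MPerBlock, NPerBlock, KPerBlock};

// M is right-padded from 1000 to 1024 and K from 70 to 96 (next multiples of 256 and 32).
const auto a_desc_m_k = padder.PadADescriptor_M_K(a_desc_mraw_kraw);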
@@ -25,41 +25,146 @@ struct ColumnMajor : public BaseTensorLayout
namespace convolution {
-// 1D Conv
// input tensor
// packed NCW/NCHW/NCDHW
struct NCW : public BaseTensorLayout
{
static constexpr const char* name = "NCW";
};
struct NCHW : public BaseTensorLayout
{
static constexpr const char* name = "NCHW";
};
struct NCDHW : public BaseTensorLayout
{
static constexpr const char* name = "NCDHW";
};
// packed GNCW/GNCHW/GNCDHW
struct GNCW : public BaseTensorLayout
{
static constexpr const char* name = "GNCW";
};
struct GNCHW : public BaseTensorLayout
{
static constexpr const char* name = "GNCHW";
};
struct GNCDHW : public BaseTensorLayout
{
static constexpr const char* name = "GNCDHW";
};
// input tensor
// packed NWC/NHWC/NDHWC
struct NWC : public BaseTensorLayout
{
static constexpr const char* name = "NWC";
};
-struct KXC : public BaseTensorLayout
struct NHWC : public BaseTensorLayout
{
-static constexpr const char* name = "KXC";
static constexpr const char* name = "NHWC";
};
-struct NWK : public BaseTensorLayout
struct NDHWC : public BaseTensorLayout
{
-static constexpr const char* name = "NWK";
static constexpr const char* name = "NDHWC";
};
-struct NCW : public BaseTensorLayout
// input tensor
// packed GNWC/GNHWC/GNDHWC
struct GNWC : public BaseTensorLayout
{
-static constexpr const char* name = "NCW";
static constexpr const char* name = "GNWC";
};
struct GNHWC : public BaseTensorLayout
{
static constexpr const char* name = "GNHWC";
};
struct GNDHWC : public BaseTensorLayout
{
static constexpr const char* name = "GNDHWC";
};
// input tensor
// packed GNWC/GNHWC/GNDHWC
struct NWGC : public BaseTensorLayout
{
static constexpr const char* name = "NWGC";
};
struct NHWGC : public BaseTensorLayout
{
static constexpr const char* name = "NHWGC";
};
struct NDHWGC : public BaseTensorLayout
{
static constexpr const char* name = "NDHWGC";
};
// input tensor
// strided layout
struct G_NW_C : public BaseTensorLayout
{
static constexpr const char* name = "G_NW_C";
};
struct G_NHW_C : public BaseTensorLayout
{
static constexpr const char* name = "G_NHW_C";
};
struct G_NDHW_C : public BaseTensorLayout
{
static constexpr const char* name = "G_NDHW_C";
};
// weight tensor
// packed KCX/KCYX/KCZYX
struct KCX : public BaseTensorLayout
{
static constexpr const char* name = "KCX";
};
-struct NKW : public BaseTensorLayout
struct KCYX : public BaseTensorLayout
{
-static constexpr const char* name = "NKW";
static constexpr const char* name = "KCYX";
};
-// 2D Conv
-struct NHWC : public BaseTensorLayout
struct KCZYX : public BaseTensorLayout
{
-static constexpr const char* name = "NHWC";
static constexpr const char* name = "KCZYX";
};
// weight tensor
// packed KCX/KCYX/KCZYX
struct GKCX : public BaseTensorLayout
{
static constexpr const char* name = "GKCX";
};
struct GKCYX : public BaseTensorLayout
{
static constexpr const char* name = "GKCYX";
};
struct GKCZYX : public BaseTensorLayout
{
static constexpr const char* name = "GKCZYX";
};
// weight tensor
// packed KXC/KYXC/KZYXC
struct KXC : public BaseTensorLayout
{
static constexpr const char* name = "KXC";
};
struct KYXC : public BaseTensorLayout
@@ -67,19 +172,67 @@ struct KYXC : public BaseTensorLayout
static constexpr const char* name = "KYXC";
};
-struct NHWK : public BaseTensorLayout
struct KZYXC : public BaseTensorLayout
{
-static constexpr const char* name = "NHWK";
static constexpr const char* name = "KZYXC";
};
-struct NCHW : public BaseTensorLayout
// weight tensor
// packed GKXC/GKYXC/GKZYXC
struct GKXC : public BaseTensorLayout
{
-static constexpr const char* name = "NCHW";
static constexpr const char* name = "GKXC";
};
-struct KCYX : public BaseTensorLayout
struct GKYXC : public BaseTensorLayout
{
-static constexpr const char* name = "KCYX";
static constexpr const char* name = "GKYXC";
};
struct GKZYXC : public BaseTensorLayout
{
static constexpr const char* name = "GKZYXC";
};
// weight tensor
// packed KXGC/KYXGC/KZYXGC
struct KXGC : public BaseTensorLayout
{
static constexpr const char* name = "KXGC";
};
struct KYXGC : public BaseTensorLayout
{
static constexpr const char* name = "KYXGC";
};
struct KZYXGC : public BaseTensorLayout
{
static constexpr const char* name = "KZYXGC";
};
// weight tensor
// strided
struct G_K_X_C : public BaseTensorLayout
{
static constexpr const char* name = "G_K_X_C";
};
struct G_K_YX_C : public BaseTensorLayout
{
static constexpr const char* name = "G_K_YX_C";
};
struct G_K_ZYX_C : public BaseTensorLayout
{
static constexpr const char* name = "G_K_ZYX_C";
};
// output tensor
// packed NKW/NKHW/NKDHW
struct NKW : public BaseTensorLayout
{
static constexpr const char* name = "NKW";
};
struct NKHW : public BaseTensorLayout
@@ -87,34 +240,94 @@ struct NKHW : public BaseTensorLayout
static constexpr const char* name = "NKHW";
};
-// 3D Conv
-struct NDHWC : public BaseTensorLayout
struct NKDHW : public BaseTensorLayout
{
-static constexpr const char* name = "NDHWC";
static constexpr const char* name = "NKDHW";
};
-struct KZYXC : public BaseTensorLayout
// output tensor
// packed GNKW/GNKHW/GNKDHW
struct GNKW : public BaseTensorLayout
{
-static constexpr const char* name = "KZYXC";
static constexpr const char* name = "GNKW";
};
struct GNKHW : public BaseTensorLayout
{
static constexpr const char* name = "GNKHW";
};
struct GNKDHW : public BaseTensorLayout
{
static constexpr const char* name = "GNKDHW";
};
// output tensor
// packed NWK/NHWK/NDHWK
struct NWK : public BaseTensorLayout
{
static constexpr const char* name = "NWK";
};
struct NHWK : public BaseTensorLayout
{
static constexpr const char* name = "NHWK";
};
struct NDHWK : public BaseTensorLayout
{
static constexpr const char* name = "NDHWK";
};
-struct NCDHW : public BaseTensorLayout
// output tensor
// packed GNWK/GNHWK/GNDHWK
struct GNWK : public BaseTensorLayout
{
-static constexpr const char* name = "NCDHW";
static constexpr const char* name = "GNWK";
};
-struct KCZYX : public BaseTensorLayout
struct GNHWK : public BaseTensorLayout
{
-static constexpr const char* name = "KCZYX";
static constexpr const char* name = "GNHWK";
};
-struct NKDHW : public BaseTensorLayout
struct GNDHWK : public BaseTensorLayout
{
-static constexpr const char* name = "NKDHW";
static constexpr const char* name = "GNDHWK";
};
// output tensor
// packed NWGK/NHWGK/NDHWGK
struct NWGK : public BaseTensorLayout
{
static constexpr const char* name = "NWGK";
};
struct NHWGK : public BaseTensorLayout
{
static constexpr const char* name = "NHWGK";
};
struct NDHWGK : public BaseTensorLayout
{
static constexpr const char* name = "NDHWGK";
};
// output tensor
// strided layout
struct G_NW_K : public BaseTensorLayout
{
static constexpr const char* name = "G_NW_K";
};
struct G_NHW_K : public BaseTensorLayout
{
static constexpr const char* name = "G_NHW_K";
};
struct G_NDHW_K : public BaseTensorLayout
{
static constexpr const char* name = "G_NDHW_K";
};
} // namespace convolution
...
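As a point of reference, a tiny sketch (not from the commit; the include path is an assumption) of how these layout tags are consumed, here just reading the compile-time name of a grouped input layout.

#include <iostream>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" // assumed location of these tags

int main()
{
    using InLayout = ck::tensor_layout::convolution::NHWGC;
    // Device convolution operations are parameterized on such tag types;
    // here we only print the tag's name string.
    std::cout << InLayout::name << std::endl; // prints "NHWGC"
    return 0;
}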
@@ -51,6 +51,13 @@ struct Add
const float y_tmp = x1_tmp + x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
y = x0 + x1;
};
};
struct Subtract
@@ -88,6 +95,13 @@ struct Subtract
const float y_tmp = x1_tmp - x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
y = x0 - x1;
};
};
struct Bilinear
@@ -104,6 +118,13 @@ struct Bilinear
y = alpha_ * x0 + beta_ * x1;
};
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = type_convert<half_t>(alpha_) * x0 + type_convert<half_t>(beta_) * x1;
};
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
@@ -117,12 +138,12 @@ struct Bilinear
struct AddRelu
{
-template <typename T>
-__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ constexpr void
-operator()<float>(float& y, const float& x0, const float& x1) const
operator()<float, float, float>(float& y, const float& x0, const float& x1) const
{
const float a = x0 + x1;
y = a > 0.0f ? a : 0.0f;
@@ -130,7 +151,7 @@ struct AddRelu
template <>
__host__ __device__ constexpr void
-operator()<double>(double& y, const double& x0, const double& x1) const
operator()<double, double, double>(double& y, const double& x0, const double& x1) const
{
const double a = x0 + x1;
y = a > 0.0 ? a : 0.0;
@@ -138,11 +159,19 @@ struct AddRelu
template <>
__host__ __device__ constexpr void
-operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
const half_t a = x0 + x1;
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
};
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
{
const float a = x0 + x1;
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
};
};
struct AddHardswish
...
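A small host-side illustration (not from the commit) of what the widened template parameters enable, using AddRelu from above with matching and mixed argument types; the function wrapper and header choice are assumptions.

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" // aggregate header used elsewhere in this commit

inline void add_relu_examples()
{
    ck::tensor_operation::element_wise::AddRelu add_relu{};

    float y_f = 0.f;
    add_relu(y_f, 1.5f, -2.0f); // <float, float, float>: 1.5 + (-2.0) < 0, so y_f == 0.0f

    ck::half_t y_h{};
    const ck::half_t x1 = ck::type_convert<ck::half_t>(0.25f);
    add_relu(y_h, 0.75f, x1); // <half_t, float, half_t>: the mixed-type overload added above
}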
@@ -12,16 +12,65 @@ namespace element_wise {
struct PassThrough
{
-template <typename T>
-__host__ __device__ void operator()(T& y, const T& x) const
-{
-static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-is_same<T, half_t>::value || is_same<T, bhalf_t>::value ||
-is_same<T, int32_t>::value || is_same<T, int8_t>::value,
-"Data type is not supported by this operation!");
-y = x;
-};
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
y = type_convert<int8_t>(x);
}
};
struct UnaryConvert
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
y = type_convert<Y>(x);
}
};
struct Scale
...
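Similarly, a brief illustrative fragment (not from the commit) exercising the rewritten PassThrough and the new UnaryConvert; the int8_t-from-int32_t specialization goes through type_convert.

#include <cstdint>
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

inline void pass_through_examples()
{
    ck::tensor_operation::element_wise::PassThrough pass{};
    ck::tensor_operation::element_wise::UnaryConvert convert{};

    float yf = 0.f;
    pass(yf, 3.0f); // operator()<float, float>: plain copy

    int8_t y8 = 0;
    pass(y8, int32_t{100}); // operator()<int8_t, int32_t>: converted via type_convert<int8_t>

    double yd = 0.0;
    convert(yd, 42); // UnaryConvert: yd = type_convert<double>(42)
}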
@@ -18,25 +18,26 @@
namespace ck {
// GEMM:
-// input : A[AK0, M, AK1]
-// input : B[AK0, N, AK1]
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
-template <typename FloatAB,
-typename FloatGemmAcc,
-typename FloatCShuffle,
template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
-typename FloatE,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
-typename AGridDesc_AK0_M_AK1,
-typename BGridDesc_BK0_N_BK1,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
@@ -70,7 +71,7 @@ template <typename FloatAB,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
-struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
struct GridwiseGemmMultipleD_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -84,10 +85,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
-static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
-static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -97,7 +98,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
-make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
@@ -105,7 +106,7 @@
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
-make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
@@ -160,31 +161,123 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
-sizeof(FloatAB),
-c_block_size * sizeof(FloatCShuffle));
sizeof(ABDataType),
c_block_size * sizeof(CShuffleDataType));
}
// A desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// E desc for destination in blockwise copy
template <typename EGridDescriptor_M_N>
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const EGridDescriptor_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap>
-__host__ __device__ static constexpr bool
-CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
-const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
-const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
-const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
@@ -194,12 +287,23 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
return false;
}
// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
@@ -210,60 +314,39 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
-__host__ __device__ static constexpr auto
-MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
-{
-const auto M = e_grid_desc_m_n.GetLength(I0);
-const auto N = e_grid_desc_m_n.GetLength(I1);
-const auto MBlock = M / MPerBlock;
-const auto NBlock = N / NPerBlock;
-const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
-e_grid_desc_m_n,
-make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
-make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
-make_tuple(Sequence<0>{}, Sequence<1>{}),
-make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-return e_grid_desc_mblock_mperblock_nblock_nperblock;
-}
-// return block_id to E matrix tile idx (m0, n0) mapping
-__host__ __device__ static constexpr auto
-MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
-{
-return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
-e_grid_desc_m_n);
-}
using DefaultAGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer());
-template <bool HasMainKBlockLoop, typename Block2ETileMap>
-__device__ static void
-Run(const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
-FloatE* __restrict__ p_e_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-const StaticallyIndexedArray<EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-NumDTensor>&
-ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different
-// type from E
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -316,11 +399,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
-Sequence<AK0, MPerBlock, AK1>,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
-FloatAB,
-FloatAB,
ABDataType,
ABDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
@@ -347,11 +430,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
-Sequence<BK0, NPerBlock, BK1>,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
-FloatAB,
-FloatAB,
ABDataType,
ABDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
@@ -379,13 +462,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
-constexpr index_t KPack = math::max(
-math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
-FloatAB,
-FloatGemmAcc,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
@@ -402,10 +486,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
@@ -466,7 +550,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-static_cast<FloatCShuffle*>(p_shared),
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
@@ -518,8 +602,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
-ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
-FloatCShuffle,
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
@@ -576,8 +660,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
-decltype(container_concat(make_tuple(FloatCShuffle{}), DsDataType{})),
-Tuple<FloatE>,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
...
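For intuition about the new M/K-based descriptors, a standalone arithmetic sketch (plain C++, no CK types; illustrative only, and row-major A is an assumption) of the unmerge that MakeDefaultAGridDescriptor_AK0_M_AK1 applies, where K is split into (AK0, AK1) with k = ak0 * AK1 + ak1.

// Map a coordinate of the [AK0, M, AK1] view back to the flat offset of a
// row-major A[M, K] tensor; assumes K is divisible by AK1.
inline int a_mk_offset(int m, int ak0, int ak1, int K, int AK1)
{
    const int k = ak0 * AK1 + ak1; // make_unmerge_transform(make_tuple(AK0, AK1))
    return m * K + k;              // row-major A[M, K]: stride of m is K
}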
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_K>
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_K gamma_grid_desc_k,
const GridDesc_K beta_grid_desc_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_k,
beta_grid_desc_k,
y_grid_desc_m_k,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
acc_elementwise_op);
};
// Y = LayerNorm(X, Beta, Gamma)
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
bool SweepOnce>
struct GridwiseLayernorm_mk_to_mk
{
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseSumReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
true>;
using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Add,
true>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_K& gamma_grid_desc_k,
const GridDesc_K& beta_grid_desc_k,
const GridDesc_M_K& y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{
if constexpr(SweepOnce)
{
num_k_block_tile_iteration = 1;
}
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>& x_square_thread_buf = y_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
mean_square_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_value_buf =
mean_square_thread_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_K = Sequence<KThreadSliceSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_k =
make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
AccDataType,
GridDesc_K,
decltype(thread_buffer_desc_k),
ThreadBufferLengths_K,
Sequence<0>,
0,
GammaSrcVectorSize,
1,
true>(
gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType,
GridDesc_K,
decltype(thread_buffer_desc_k),
ThreadBufferLengths_K,
Sequence<0>,
0,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
decltype(thread_buffer_desc_m_k),
GridDesc_M_K,
AccElementwiseOperation,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
YDstVectorDim,
YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
y_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
acc_elementwise_op);
// Copy x from Cache
// one pass: fwd, second pass: bwd
constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
constexpr auto thread_copy_fwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
// E(x), E[x^2], var(x)
int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
index_t reducedTiles = 0;
do
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
x_square_thread_buf(Number<offset_m_k>{}) =
x_thread_buf(Number<offset_m_k>{}) * x_thread_buf(Number<offset_m_k>{});
});
});
ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf);
ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
++reducedTiles;
} while(reducedTiles < num_k_block_tile_iteration);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
// var(x) = E[x^2] - E[x]^2
var_value_buf(I) =
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
});
// y = (x - E[x]) / sqrt(var[x] + epsilon)
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
auto thread_copy_tail_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
reducedTiles = 0;
do
{
if constexpr(!SweepOnce)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
}
threadwise_gamma_load.Run(gamma_grid_desc_k,
gamma_global_val_buf,
thread_buffer_desc_k,
make_tuple(I0),
gamma_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
// normalize
y_thread_buf(Number<offset_m_k>{}) =
(x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
sqrt(var_value_buf(iM) + epsilon);
// gamma
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
});
});
threadwise_beta_load.Run(beta_grid_desc_k,
beta_global_val_buf,
thread_buffer_desc_k,
make_tuple(I0),
beta_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
// beta
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
});
});
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf,
y_grid_desc_m_k,
y_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
++reducedTiles;
} while(reducedTiles < num_k_block_tile_iteration);
}
};
} // namespace ck
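As a plain-C++ reference for the math above (E[x], E[x^2], var(x) = E[x^2] - E[x]^2, then normalize, scale by gamma, shift by beta), an illustrative host-side sketch that is not part of the commit.

#include <cmath>
#include <cstddef>
#include <vector>

// One row of LayerNorm: y = gamma * (x - E[x]) / sqrt(var(x) + epsilon) + beta
std::vector<float> layernorm_row(const std::vector<float>& x,
                                 const std::vector<float>& gamma,
                                 const std::vector<float>& beta,
                                 float epsilon)
{
    float mean = 0.f, mean_sq = 0.f;
    for(float v : x)
    {
        mean += v;
        mean_sq += v * v;
    }
    mean /= x.size();
    mean_sq /= x.size();
    const float var = mean_sq - mean * mean; // var(x) = E[x^2] - E[x]^2

    std::vector<float> y(x.size());
    for(std::size_t k = 0; k < x.size(); ++k)
        y[k] = gamma[k] * (x[k] - mean) / std::sqrt(var + epsilon) + beta[k];
    return y;
}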
@@ -250,8 +250,10 @@ struct GridwiseSoftmax_mk_to_mk
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
-static_for<0, MThreadSliceSize, 1>{}(
-[&](auto I) { BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
block_sync_lds();
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
@@ -303,9 +305,10 @@ struct GridwiseSoftmax_mk_to_mk
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
block_sync_lds(); // wait for reading being complete before writing to LDS
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
-// block_sync_lds();
block_sync_lds();
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
...
@@ -4,6 +4,7 @@
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "enable_if.hpp"
#include "c_style_pointer_cast.hpp"
#include "amd_buffer_addressing.hpp"
...
@@ -21,6 +21,8 @@ struct TupleElementKey
template <typename Key, typename Data>
struct TupleElementKeyData
{
using DataType = Data;
#if 0 // workaround compiler complaint about implicitly-deleted default constructor
__host__ __device__ constexpr TupleElementKeyData() = default;
#else
@@ -34,29 +36,40 @@ struct TupleElementKeyData
{
}
-Data mData;
DataType mData;
};
// for read access of tuple element
template <typename Key, typename Data>
__host__ __device__ constexpr const Data&
-get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
get_tuple_element_data_reference(const TupleElementKeyData<Key, Data>& x)
{
return static_cast<const Data&>(x.mData);
}
// for write access of tuple element
template <typename Key, typename Data>
-__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData<Key, Data>& x)
__host__ __device__ constexpr Data&
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>& x)
{
return x.mData;
}
// TODO: not sure the use of reference is correct
template <typename Key, typename Data>
-__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData<Key, Data>&& x)
__host__ __device__ constexpr Data&&
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
{
return static_cast<Data&&>(x.mData);
}
// for inferring the type of a tuple element
template <typename Key, typename Data>
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
return x.mData;
}
template <typename Indices, typename... Xs> template <typename Indices, typename... Xs>
struct TupleImpl; struct TupleImpl;
...@@ -87,13 +100,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I ...@@ -87,13 +100,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
template <index_t I> template <index_t I>
__host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
{ {
return get_tuple_element_data<TupleElementKey<I>>(*this); return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
} }
template <index_t I> template <index_t I>
__host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>) __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
{ {
return get_tuple_element_data<TupleElementKey<I>>(*this); return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
} }
}; };
...@@ -185,7 +198,8 @@ struct Tuple<> ...@@ -185,7 +198,8 @@ struct Tuple<>
template <index_t I, typename TTuple> template <index_t I, typename TTuple>
struct tuple_element struct tuple_element
{ {
using type = decltype(TTuple{}.At(Number<I>{})); // type should keep the cv/ref qualifier of original tuple element
using type = decltype(detail::get_tuple_element_data<detail::TupleElementKey<I>>(TTuple{}));
}; };
template <index_t I, typename TTuple> template <index_t I, typename TTuple>
......
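For context on the renaming above: get_tuple_element_data_reference keeps the reference-returning accessors, while the by-value get_tuple_element_data exists only to be used inside decltype, so tuple_element reports the stored Data type rather than the reference returned by At. A small self-contained sketch of the same key-tagged multiple-inheritance storage (plain C++14; Key, Leaf, get_ref, get_value, and element_t are illustrative names, not CK's):

#include <cstddef>
#include <type_traits>
#include <utility>

template <std::size_t I>
struct Key
{
};

template <typename K, typename Data>
struct Leaf
{
    Data data;
};

// read access: Data is deduced from the unique Leaf base matching Key K
template <typename K, typename Data>
constexpr const Data& get_ref(const Leaf<K, Data>& leaf)
{
    return leaf.data;
}

// meant for unevaluated contexts (decltype) only: recovers Data by value
template <typename K, typename Data>
constexpr Data get_value(const Leaf<K, Data>& leaf)
{
    return leaf.data;
}

template <typename Seq, typename... Ts>
struct TupleImpl;

template <std::size_t... Is, typename... Ts>
struct TupleImpl<std::index_sequence<Is...>, Ts...> : Leaf<Key<Is>, Ts>...
{
};

template <typename... Ts>
using MyTuple = TupleImpl<std::make_index_sequence<sizeof...(Ts)>, Ts...>;

// element type trait in the spirit of tuple_element above: plain Data, no added reference
template <std::size_t I, typename Tup>
using element_t = decltype(get_value<Key<I>>(std::declval<Tup>()));

int main()
{
    MyTuple<int, double> t{};
    static_assert(std::is_same<element_t<0, decltype(t)>, int>::value, "plain int, not int&");
    static_assert(std::is_same<element_t<1, decltype(t)>, double>::value, "plain double");
    return get_ref<Key<0>>(t); // reads the int element through its key
}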
add_subdirectory(src/tensor_operation_instance/gpu) add_subdirectory(src/tensor_operation_instance/gpu)
add_subdirectory(src/host_tensor)
add_subdirectory(src/utility) add_subdirectory(src/utility)
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -6,8 +6,9 @@ ...@@ -6,8 +6,9 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -91,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator ...@@ -91,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator
v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag; v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag;
} }
arg.c_m_n_real_(m, n) = v_c_real; arg.c_m_n_real_(m, n) = ck::type_convert<CDataType>(v_c_real);
}; };
auto f_mk_kn_mn_imag = [&](auto m, auto n) { auto f_mk_kn_mn_imag = [&](auto m, auto n) {
...@@ -107,7 +108,7 @@ struct ReferenceCGemm : public device::BaseOperator ...@@ -107,7 +108,7 @@ struct ReferenceCGemm : public device::BaseOperator
v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real; v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real;
} }
arg.c_m_n_imag_(m, n) = v_c_imag; arg.c_m_n_imag_(m, n) = ck::type_convert<CDataType>(v_c_imag);
}; };
make_ParallelTensorFunctor(f_mk_kn_mn_real, make_ParallelTensorFunctor(f_mk_kn_mn_real,
......
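For reference, the converted stores above follow the usual split real/imaginary accumulation: multiply-accumulate in float and convert to CDataType once per output element. A minimal stand-alone sketch (plain C++, with static_cast standing in for ck::type_convert; the function name is illustrative):

#include <cstddef>
#include <vector>

// one output element of C = A * B for complex data kept as separate real/imag planes
template <typename CDataType>
void cgemm_ref_point(const std::vector<float>& a_re, const std::vector<float>& a_im,
                     const std::vector<float>& b_re, const std::vector<float>& b_im,
                     CDataType& c_re, CDataType& c_im)
{
    float acc_re = 0.f;
    float acc_im = 0.f;
    for(std::size_t k = 0; k < a_re.size(); ++k)
    {
        // (a_re + i*a_im) * (b_re + i*b_im)
        acc_re += a_re[k] * b_re[k] - a_im[k] * b_im[k];
        acc_im += a_re[k] * b_im[k] + a_im[k] * b_re[k];
    }
    c_re = static_cast<CDataType>(acc_re); // single conversion per element,
    c_im = static_cast<CDataType>(acc_im); // mirroring type_convert<CDataType> above
}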
...@@ -8,22 +8,24 @@ ...@@ -8,22 +8,24 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] // input descriptor in [G, N, C, Di, Hi, Wi] order
template <typename InDataType, // weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Do, Ho, Wo] order
// physical layout is irrelevant
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename AccDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2, typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdData : public device::BaseOperator struct ReferenceConvBwdData : public device::BaseOperator
{ {
// Argument // Argument
...@@ -73,36 +75,45 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -73,36 +75,45 @@ struct ReferenceConvBwdData : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if constexpr(NumDimSpatial == 1) if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{ {
auto f_ncw = [&](auto n, auto c, auto wi) { throw std::runtime_error("wrong! inconsistent dimension");
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; }
std::size_t X = arg.weight_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[2]; if constexpr(NDimSpatial == 1)
{
auto f_ncw = [&](auto g, auto n, auto c, auto wi) {
std::size_t K = arg.weight_.GetLengths()[1];
std::size_t X = arg.weight_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[3];
AccDataType v_acc = 0; float v_acc = 0;
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = ck::type_convert<ck::long_index_t>(wi) + auto w_tmp = static_cast<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]);
if(w_tmp % arg.conv_strides_[0] == 0) if(w_tmp % arg.conv_strides_[0] == 0)
{ {
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) / auto wo = static_cast<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo) if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; float v_out = 0;
AccDataType v_wei = 0; float v_wei = 0;
arg.out_element_op_( arg.out_element_op_(
v_out, v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x))); v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -110,66 +121,72 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -110,66 +121,72 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
arg.in_element_op_(v_acc, v_acc); float v_in;
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
arg.input_.mDesc.GetLengths()[0], arg.input_.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2])( arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { auto f_nchw = [&](auto g, auto n, auto c, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Y = arg.weight_.mDesc.GetLengths()[2]; std::size_t Y = arg.weight_.GetLengths()[3];
std::size_t X = arg.weight_.mDesc.GetLengths()[3]; std::size_t X = arg.weight_.GetLengths()[4];
std::size_t Ho = arg.output_.mDesc.GetLengths()[2]; std::size_t Ho = arg.output_.GetLengths()[3];
std::size_t Wo = arg.output_.mDesc.GetLengths()[3]; std::size_t Wo = arg.output_.GetLengths()[4];
AccDataType v_acc = 0; float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
{ {
auto h_tmp = ck::type_convert<ck::long_index_t>(hi) + auto h_tmp = static_cast<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]);
if(h_tmp % arg.conv_strides_[0] == 0) if(h_tmp % arg.conv_strides_[0] == 0)
{ {
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) / auto ho = static_cast<ck::long_index_t>(h_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho) if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = auto w_tmp =
ck::type_convert<ck::long_index_t>(wi) + static_cast<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(x * static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]);
arg.conv_dilations_[1]);
if(w_tmp % arg.conv_strides_[1] == 0) if(w_tmp % arg.conv_strides_[1] == 0)
{ {
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) / auto wo =
ck::type_convert<ck::long_index_t>( static_cast<ck::long_index_t>(w_tmp) /
arg.conv_strides_[1]); static_cast<ck::long_index_t>(arg.conv_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo) if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; float v_out = 0;
AccDataType v_wei = 0; float v_wei = 0;
arg.out_element_op_(
v_out,
ck::type_convert<float>(
arg.output_(g, n, k, ho, wo)));
arg.out_element_op_(v_out, arg.wei_element_op_(
ck::type_convert<AccDataType>( v_wei,
arg.output_(n, k, ho, wo))); ck::type_convert<float>(
arg.wei_element_op_(v_wei, arg.weight_(g, k, c, y, x)));
ck::type_convert<AccDataType>(
arg.weight_(k, c, y, x)));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -180,90 +197,91 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -180,90 +197,91 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
AccDataType v_in; float v_in;
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
arg.input_.mDesc.GetLengths()[0], arg.input_.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2], arg.input_.GetLengths()[2],
arg.input_.mDesc.GetLengths()[3])( arg.input_.GetLengths()[3],
arg.input_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) { auto f_ncdhw = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Z = arg.weight_.mDesc.GetLengths()[2]; std::size_t Z = arg.weight_.GetLengths()[3];
std::size_t Y = arg.weight_.mDesc.GetLengths()[3]; std::size_t Y = arg.weight_.GetLengths()[4];
std::size_t X = arg.weight_.mDesc.GetLengths()[4]; std::size_t X = arg.weight_.GetLengths()[5];
std::size_t Do = arg.output_.mDesc.GetLengths()[2]; std::size_t Do = arg.output_.GetLengths()[3];
std::size_t Ho = arg.output_.mDesc.GetLengths()[3]; std::size_t Ho = arg.output_.GetLengths()[4];
std::size_t Wo = arg.output_.mDesc.GetLengths()[4]; std::size_t Wo = arg.output_.GetLengths()[5];
AccDataType v_acc = 0; float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z) for(std::size_t z = 0; z < Z; ++z)
{ {
auto d_tmp = ck::type_convert<ck::long_index_t>(di) + auto d_tmp = static_cast<ck::long_index_t>(di) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]);
if(d_tmp % arg.conv_strides_[0] == 0) if(d_tmp % arg.conv_strides_[0] == 0)
{ {
auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) / auto do_ = static_cast<ck::long_index_t>(d_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do) if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{ {
for(std::size_t y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
{ {
auto h_tmp = auto h_tmp =
ck::type_convert<ck::long_index_t>(hi) + static_cast<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(y * static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]);
arg.conv_dilations_[1]);
if(h_tmp % arg.conv_strides_[1] == 0) if(h_tmp % arg.conv_strides_[1] == 0)
{ {
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) / auto ho =
ck::type_convert<ck::long_index_t>( static_cast<ck::long_index_t>(h_tmp) /
arg.conv_strides_[1]); static_cast<ck::long_index_t>(arg.conv_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho) if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = auto w_tmp = static_cast<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(wi) + static_cast<ck::long_index_t>(
ck::type_convert<ck::long_index_t>( arg.in_left_pads_[2]) -
arg.in_left_pads_[2]) - static_cast<ck::long_index_t>(
ck::type_convert<ck::long_index_t>( x * arg.conv_dilations_[2]);
x * arg.conv_dilations_[2]);
if(w_tmp % arg.conv_strides_[2] == 0) if(w_tmp % arg.conv_strides_[2] == 0)
{ {
auto wo = auto wo = static_cast<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(w_tmp) / static_cast<ck::long_index_t>(
ck::type_convert<ck::long_index_t>( arg.conv_strides_[2]);
arg.conv_strides_[2]);
if(wo >= 0 && if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo) ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; float v_out = 0;
AccDataType v_wei = 0; float v_wei = 0;
arg.out_element_op_( arg.out_element_op_(
v_out, v_out,
ck::type_convert<AccDataType>( ck::type_convert<float>(arg.output_(
arg.output_( g, n, k, do_, ho, wo)));
n, k, do_, ho, wo)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei,
ck::type_convert<AccDataType>( ck::type_convert<float>(
arg.weight_(k, c, z, y, x))); arg.weight_(g, k, c, z, y, x)));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -277,17 +295,20 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -277,17 +295,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
AccDataType v_in; float v_in;
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
arg.input_.mDesc.GetLengths()[0], arg.input_.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2], arg.input_.GetLengths()[2],
arg.input_.mDesc.GetLengths()[3], arg.input_.GetLengths()[3],
arg.input_.mDesc.GetLengths()[4])( arg.input_.GetLengths()[4],
arg.input_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
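The inner loops above invert the forward index map: input coordinate wi receives a contribution from filter tap x only when wi + left_pad - x * dilation is non-negative, divisible by the stride, and maps to an output coordinate inside the gradient tensor. A tiny stand-alone sketch of that arithmetic for one spatial dimension (the concrete values are only for illustration):

#include <cstdio>

int main()
{
    const long wi = 5, left_pad = 1, stride = 2, dilation = 1;
    const long filter_len = 3, out_len = 8;

    for(long x = 0; x < filter_len; ++x)
    {
        const long w_tmp = wi + left_pad - x * dilation;
        if(w_tmp >= 0 && w_tmp % stride == 0)
        {
            const long wo = w_tmp / stride;
            if(wo < out_len)
                std::printf("filter tap x=%ld reads output gradient at wo=%ld\n", x, wo);
        }
    }
    // prints x=0 -> wo=3 and x=2 -> wo=2; x=1 is skipped because w_tmp=5 is not
    // divisible by the stride of 2
    return 0;
}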
...@@ -7,21 +7,25 @@ ...@@ -7,21 +7,25 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] // input descriptor in [G, N, C, Di, Hi, Wi] order
template <typename InDataType, // weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Do, Ho, Wo] order
// physical layout is irrelevant
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2, typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator struct ReferenceConvBwdWeight : public device::BaseOperator
{ {
// Argument // Argument
...@@ -71,156 +75,162 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -71,156 +75,162 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if constexpr(NumDimSpatial == 1) if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{ {
constexpr auto I0 = Number<0>{}; throw std::runtime_error("wrong! inconsistent dimension");
auto f_kcx = [&](auto k, auto c, auto x) { }
if constexpr(NDimSpatial == 1)
{
auto f_kcx = [&](auto g, auto k, auto c, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo) for(std::size_t wo = 0; wo < arg.output_.GetLengths()[3]; ++wo)
{ {
auto wi = auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_(v_out, arg.out_element_op_(
ck::type_convert<float>(arg.output_(n, k, wo))); v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, wi))); arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
} }
} }
float v_wei; float v_wei;
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei); arg.weight_(g, k, c, x) = ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kcx, make_ParallelTensorFunctor(f_kcx,
arg.weight_.mDesc.GetLengths()[0], arg.weight_.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1], arg.weight_.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2])( arg.weight_.GetLengths()[2],
arg.weight_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
constexpr auto I0 = Number<0>{}; auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho) for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
{ {
auto hi = auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) + static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo) for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
arg.conv_dilations_[I1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
arg.input_.mDesc.GetLengths()[3])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_( arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo))); v_out,
ck::type_convert<float>(arg.output_(g, n, k, ho, wo)));
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
} }
} }
} }
float v_wei; float v_wei;
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei); arg.weight_(g, k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kcyx, make_ParallelTensorFunctor(f_kcyx,
arg.weight_.mDesc.GetLengths()[0], arg.weight_.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1], arg.weight_.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2], arg.weight_.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3])( arg.weight_.GetLengths()[3],
arg.weight_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
constexpr auto I0 = Number<0>{}; auto f_kczyx = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_) for(std::size_t do_ = 0; do_ < arg.output_.GetLengths()[3]; ++do_)
{ {
auto di = auto di = static_cast<ck::long_index_t>(do_ * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) + static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]); for(std::size_t ho = 0; ho < arg.output_.GetLengths()[4]; ++ho)
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
{ {
auto hi = auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(y * static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
arg.conv_dilations_[I1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]); for(std::size_t wo = 0; wo < arg.output_.GetLengths()[5]; ++wo)
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4];
++wo)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
arg.conv_strides_[I2]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>( static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
x * arg.conv_dilations_[I2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] && arg.input_.GetLengths()[3] &&
hi >= 0 && hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] && arg.input_.GetLengths()[4] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4]) arg.input_.GetLengths()[5])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_(v_out, arg.out_element_op_(v_out,
ck::type_convert<float>( ck::type_convert<float>(
arg.output_(n, k, do_, ho, wo))); arg.output_(g, n, k, do_, ho, wo)));
arg.in_element_op_(
v_in, arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi))); ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
...@@ -228,19 +238,21 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -228,19 +238,21 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
} }
} }
} }
float v_wei; float v_wei;
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei); arg.weight_(g, k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kczyx, make_ParallelTensorFunctor(f_kczyx,
arg.weight_.mDesc.GetLengths()[0], arg.weight_.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1], arg.weight_.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2], arg.weight_.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3], arg.weight_.GetLengths()[3],
arg.weight_.mDesc.GetLengths()[4])( arg.weight_.GetLengths()[4],
arg.weight_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -17,9 +17,10 @@ namespace host { ...@@ -17,9 +17,10 @@ namespace host {
// //
// @brief Reference implementation for forward convolution. // @brief Reference implementation for forward convolution.
// //
// @paragraph Supports both NCHW as well as NHWC formats (and their respective // @paragraph
// counterparts for weight and output) as long as tensor descriptor // Tensor descriptor in GNCHW/GKCYX/GNKHW dimensional order
// lengths is in NCHW. // Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layouts
// as long as the dimensions in the tensor descriptor are in GNCHW order
// //
// @tparam InDataType Input tensor data type. // @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type. // @tparam WeiDataType Weights tensor data type.
...@@ -28,16 +29,20 @@ namespace host { ...@@ -28,16 +29,20 @@ namespace host {
// operation. // operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise // @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation. // operation.
// @tparam NumDimSpatial Number of spatial dimensions. // @tparam NDimSpatial Number of spatial dimensions.
// //
template <ck::index_t NDimSpatial, // input descriptor in [G, N, C, Di, Hi, Wi] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Do, Ho, Wo] order
// physical layout is irrelevant
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2, typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator struct ReferenceConvFwd : public device::BaseOperator
{ {
// Argument // Argument
...@@ -86,29 +91,37 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -86,29 +91,37 @@ struct ReferenceConvFwd : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if constexpr(NumDimSpatial == 1) if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{ {
auto f_ncw = [&](auto n, auto k, auto wo) { throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
{
auto func = [&](auto g, auto n, auto k, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
{ {
auto wi = auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
arg.in_element_op_(v_in, arg.in_element_op_(
ck::type_convert<float>(arg.input_(n, c, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
arg.wei_element_op_(v_wei,
ck::type_convert<float>(arg.weight_(k, c, x))); arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
...@@ -118,50 +131,53 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -118,50 +131,53 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(func,
arg.output_.mDesc.GetLengths()[0], arg.output_.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], arg.output_.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2])( arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y) for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y)
{ {
auto hi = auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
arg.input_.mDesc.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x))); v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
} }
...@@ -171,64 +187,65 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -171,64 +187,65 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(func,
arg.output_.mDesc.GetLengths()[0], arg.output_.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], arg.output_.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2], arg.output_.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3])( arg.output_.GetLengths()[3],
arg.output_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z) for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z)
{ {
auto di = auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y)
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
{ {
auto hi = auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
arg.conv_strides_[2]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(x * static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] && arg.input_.GetLengths()[3] &&
hi >= 0 && hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] && arg.input_.GetLengths()[4] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4]) arg.input_.GetLengths()[5])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
arg.in_element_op_( arg.in_element_op_(v_in,
v_in, ck::type_convert<float>(
ck::type_convert<float>(arg.input_(n, c, di, hi, wi))); arg.input_(g, n, c, di, hi, wi)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei,
ck::type_convert<float>(arg.weight_(k, c, z, y, x))); ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
} }
...@@ -239,15 +256,17 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -239,15 +256,17 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(func,
arg.output_.mDesc.GetLengths()[0], arg.output_.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], arg.output_.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2], arg.output_.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3], arg.output_.GetLengths()[3],
arg.output_.mDesc.GetLengths()[4])( arg.output_.GetLengths()[4],
arg.output_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -267,7 +286,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -267,7 +286,10 @@ struct ReferenceConvFwd : public device::BaseOperator
return true; return true;
} }
bool IsSupportedArgument(const device::BaseArgument*) override { return true; } bool IsSupportedArgument(const device::BaseArgument*) override
{
return NDimSpatial >= 1 && NDimSpatial <= 3;
}
static auto MakeArgument(const Tensor<InDataType>& input, static auto MakeArgument(const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
......
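The bounds checks in the forward and weight-gradient references above imply the usual relation between input and output spatial lengths: the last dilated filter window must still fit inside the padded input. A hedged helper (not part of CK; name and signature are assumptions for illustration) computing that length for one spatial dimension:

#include <cstddef>

inline std::size_t conv_out_length(std::size_t in_len,
                                   std::size_t filter_len,
                                   std::size_t stride,
                                   std::size_t dilation,
                                   std::size_t left_pad,
                                   std::size_t right_pad)
{
    const std::size_t effective_filter = dilation * (filter_len - 1) + 1;
    return (in_len + left_pad + right_pad - effective_filter) / stride + 1;
}

// example: in_len=32, filter_len=3, stride=2, dilation=1, left/right pads=1
// -> effective filter 3, (32 + 2 - 3) / 2 + 1 = 16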
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......