"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "23220fe5c16e38575d186ec62bdca2c3f16951b4"
Commit dd29eb09 authored by Chao Liu's avatar Chao Liu
Browse files

gemm/conv activation fusion example

parent ac0d8066
......@@ -38,9 +38,43 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
// Element-wise ReLU functor: y = max(x, 0).
// Used as an activation fused into the GEMM epilogue.
struct Relu
{
    // float in/out overload
    __host__ __device__ constexpr void operator()(float& y, const float& x) const
    {
        y = x > 0 ? x : 0;
    }

    // half-precision in/out overload
    __host__ __device__ constexpr void operator()(ck::half_t& y, const ck::half_t& x) const
    {
        y = x > 0 ? x : 0;
    }
};
// Element-wise Hardswish functor: y = x * clamp(x + 3, 0, 6) / 6.
// Fix: the original scaled by the truncated literal 0.166667f, which loses
// precision versus the exact single-precision value of 1/6; use a computed
// constant instead. The (b > 0) * ... bool-arithmetic clamp is rewritten as
// an explicit clamp for the same result with clearer intent.
struct Hardswish
{
    // float in/out overload
    __host__ __device__ constexpr void operator()(float& y, const float& x) const
    {
        const float shifted = x + float{3};
        // clamp(shifted, 0, 6): zero below 0, capped at 6 above.
        const float clamped = shifted < float{0} ? float{0}
                                                 : (shifted > float{6} ? float{6} : shifted);
        y = x * clamped * (float{1} / float{6});
    }

    // half-precision in/out overload: compute in float for accuracy,
    // then narrow to half on assignment (matches the float overload).
    __host__ __device__ constexpr void operator()(ck::half_t& y, const ck::half_t& x) const
    {
        const float a       = x;
        const float shifted = a + float{3};
        const float clamped = shifted < float{0} ? float{0}
                                                 : (shifted > float{6} ? float{6} : shifted);
        y = a * clamped * (float{1} / float{6});
    }
};
using AElementOp = Relu;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = Hardswish;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
......
......@@ -25,9 +25,43 @@ using AccDataType = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
// Element-wise ReLU functor: y = max(x, 0).
// Applied as a fused activation on the operator output.
struct Relu
{
    // float in/out overload
    __host__ __device__ constexpr void operator()(float& y, const float& x) const
    {
        y = x > 0 ? x : 0;
    }

    // half-precision in/out overload
    __host__ __device__ constexpr void operator()(ck::half_t& y, const ck::half_t& x) const
    {
        y = x > 0 ? x : 0;
    }
};
// Element-wise Hardswish functor: y = x * clamp(x + 3, 0, 6) / 6.
// Fix: the original scaled by the truncated literal 0.166667f, which loses
// precision versus the exact single-precision value of 1/6; use a computed
// constant instead. The (b > 0) * ... bool-arithmetic clamp is rewritten as
// an explicit clamp for the same result with clearer intent.
struct Hardswish
{
    // float in/out overload
    __host__ __device__ constexpr void operator()(float& y, const float& x) const
    {
        const float shifted = x + float{3};
        // clamp(shifted, 0, 6): zero below 0, capped at 6 above.
        const float clamped = shifted < float{0} ? float{0}
                                                 : (shifted > float{6} ? float{6} : shifted);
        y = x * clamped * (float{1} / float{6});
    }

    // half-precision in/out overload: compute in float for accuracy,
    // then narrow to half on assignment (matches the float overload).
    __host__ __device__ constexpr void operator()(ck::half_t& y, const ck::half_t& x) const
    {
        const float a       = x;
        const float shifted = a + float{3};
        const float clamped = shifted < float{0} ? float{0}
                                                 : (shifted > float{6} ? float{6} : shifted);
        y = a * clamped * (float{1} / float{6});
    }
};
using InElementOp = Relu;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = Hardswish;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
......
......@@ -29,9 +29,43 @@ using InLayout = ck::tensor_layout::convolution::NHWC;
using WeiLayout = ck::tensor_layout::convolution::KYXC;
using OutLayout = ck::tensor_layout::convolution::NHWK;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
// Element-wise ReLU functor: y = max(x, 0).
// Applied as a fused activation on the convolution input path.
struct Relu
{
    // float in/out overload
    __host__ __device__ constexpr void operator()(float& y, const float& x) const
    {
        y = x > 0 ? x : 0;
    }

    // half-precision in/out overload
    __host__ __device__ constexpr void operator()(ck::half_t& y, const ck::half_t& x) const
    {
        y = x > 0 ? x : 0;
    }
};
// Element-wise Hardswish functor: y = x * clamp(x + 3, 0, 6) / 6.
// Fix: the original scaled by the truncated literal 0.166667f, which loses
// precision versus the exact single-precision value of 1/6; use a computed
// constant instead. The (b > 0) * ... bool-arithmetic clamp is rewritten as
// an explicit clamp for the same result with clearer intent.
struct Hardswish
{
    // float in/out overload
    __host__ __device__ constexpr void operator()(float& y, const float& x) const
    {
        const float shifted = x + float{3};
        // clamp(shifted, 0, 6): zero below 0, capped at 6 above.
        const float clamped = shifted < float{0} ? float{0}
                                                 : (shifted > float{6} ? float{6} : shifted);
        y = x * clamped * (float{1} / float{6});
    }

    // half-precision in/out overload: compute in float for accuracy,
    // then narrow to half on assignment (matches the float overload).
    __host__ __device__ constexpr void operator()(ck::half_t& y, const ck::half_t& x) const
    {
        const float a       = x;
        const float shifted = a + float{3};
        const float clamped = shifted < float{0} ? float{0}
                                                 : (shifted > float{6} ? float{6} : shifted);
        y = a * clamped * (float{1} / float{6});
    }
};
using InElementOp = Relu;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = Hardswish;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
......
......@@ -72,8 +72,13 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
using ReferenceConvBwdWeightInstance = ck::tensor_operation::host::
ReferenceConvBwdWeight<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>;
using ReferenceConvBwdWeightInstance =
ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
int main(int argc, char* argv[])
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment