"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "a71270c0946c3d44a31beed5e5da441bd94f895f"
Commit 0b997ce4 authored by Chao Liu

adding conv multiple D

parent 69d323de
@@ -16,7 +16,7 @@ using S = ck::Sequence<Is...>;
 using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
 using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
-using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
+using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert;

 static constexpr auto ConvFwdDefault =
     ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
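Context for the hunk above (not part of the commit): the output element-wise operation is applied to each accumulated value just before it is written to the output tensor, so switching it from PassThrough to UnaryConvert makes that final store perform the accumulator-to-output type conversion. A rough, self-contained sketch of the pattern; the functor, loop, and types below are illustrative stand-ins, not the kernel's actual epilogue:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Illustrative stand-in for ck::tensor_operation::element_wise::UnaryConvert.
    struct UnaryConvertSketch
    {
        template <typename Y, typename X>
        void operator()(Y& y, const X& x) const { y = static_cast<Y>(x); }
    };

    // The epilogue applies the output element-wise op to every accumulator value.
    template <typename OutElementOp, typename OutT, typename AccT>
    void store_results(std::vector<OutT>& out, const std::vector<AccT>& acc, OutElementOp out_op)
    {
        for(std::size_t i = 0; i < acc.size(); ++i) { out_op(out[i], acc[i]); }
    }

    int main()
    {
        std::vector<float> acc = {0.5f, 1.5f, 2.5f}; // accumulator values (AccDataType)
        std::vector<std::int8_t> out(acc.size());    // output buffer (OutDataType)
        store_results(out, acc, UnaryConvertSketch{}); // conversion happens on store
        return 0;
    }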
@@ -48,18 +48,18 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
     2,              // ABlockTransferSrcVectorDim
     8,              // ABlockTransferSrcScalarPerVector
     8,              // ABlockTransferDstScalarPerVector_K1
-    true,           // ABlockLdsAddExtraM
+    true,           // ABlockLdsExtraM
     S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_K0_N_K1
     S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
     S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
     2,              // BBlockTransferSrcVectorDim
     8,              // BBlockTransferSrcScalarPerVector
     8,              // BBlockTransferDstScalarPerVector_K1
-    true,           // BBlockLdsAddExtraN
+    true,           // BBlockLdsExtraN
     7,              // CThreadTransferSrcDstVectorDim
     1>;             // CThreadTransferDstScalarPerVector
 #else
-using CShuffleDataType = float;
+using CShuffleDataType = ck::half_t;

 template <ck::index_t NDimSpatial>
 using DeviceConvNDFwdInstance =
@@ -69,37 +69,40 @@ using DeviceConvNDFwdInstance =
     WeiDataType,        //
     AccDataType,        //
     CShuffleDataType,   //
-    ck::Tuple<>,
+    ck::Tuple<>,        //
     OutDataType,        //
     InElementOp,    // Input Elementwise Operation
     WeiElementOp,   // Weights Elementwise Operation
     OutElementOp,   // Output Elementwise Operation
     ConvFwdDefault, // ConvForwardSpecialization
+    1,              //
     256,            // BlockSize
     128,            // MPerBlock
     256,            // NPerBlock
-    4,              // K0PerBlock
+    32,             // KPerBlock
     8,              // K1
     32,             // MPerXdl
     32,             // NPerXdl
     2,              // MXdlPerWave
     4,              // NXdlPerWave
     S<4, 64, 1>,    // ABlockTransferThreadClusterLengths_K0_M_K1
     S<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
     S<1, 0, 2>,     // ABlockTransferSrcAccessOrder
     2,              // ABlockTransferSrcVectorDim
     8,              // ABlockTransferSrcScalarPerVector
     8,              // ABlockTransferDstScalarPerVector_K1
-    true,           // ABlockLdsAddExtraM
+    1,              // ABlockLdsExtraM
     S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_K0_N_K1
     S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
     S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
     2,              // BBlockTransferSrcVectorDim
     8,              // BBlockTransferSrcScalarPerVector
     8,              // BBlockTransferDstScalarPerVector_K1
-    true,           // BBlockLdsAddExtraN
-    7,              // CThreadTransferSrcDstVectorDim
-    1>;             // CThreadTransferDstScalarPerVector
+    1,              // BBlockLdsExtraN
+    1,
+    1,
+    S<1, 32, 1, 8>,
+    8>;
 #endif

 int main(int argc, char* argv[])
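Note (not part of the commit): the alias above is parameterized only on the number of spatial dimensions, so concrete device-op instances are obtained by fixing NDimSpatial. A minimal sketch, assuming it sits in the same example translation unit; the alias names below are illustrative:

    // Concrete forward-convolution instances built from the alias above.
    using Conv1dFwdInstance = DeviceConvNDFwdInstance<1>; // 1-D
    using Conv2dFwdInstance = DeviceConvNDFwdInstance<2>; // 2-D
    using Conv3dFwdInstance = DeviceConvNDFwdInstance<3>; // 3-D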
...
@@ -618,18 +618,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
                 arg.block_2_etile_map_);
         };

-        float ave_time = 0;
+        float avg_time = 0;

         if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
         {
-            ave_time = launch_kernel(integral_constant<bool, true>{});
+            avg_time = launch_kernel(integral_constant<bool, true>{});
         }
         else
         {
-            ave_time = launch_kernel(integral_constant<bool, false>{});
+            avg_time = launch_kernel(integral_constant<bool, false>{});
         }

-        return ave_time;
+        return avg_time;
     }

     // polymorphic
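The surrounding code (only the variable name changes in this hunk) selects a kernel variant at compile time: CalculateHasMainKBlockLoop(K) produces a runtime bool, and each branch passes a different integral_constant tag so that launch_kernel instantiates the matching kernel template. A minimal standalone sketch of that dispatch idiom, using std::integral_constant in place of CK's own integral_constant and a stand-in run_kernel function:

    #include <cstdio>
    #include <type_traits>

    // Stand-in for the kernel: the boolean template parameter would normally
    // select between code paths with and without the main K-block loop.
    template <bool HasMainKBlockLoop>
    float run_kernel()
    {
        std::printf("HasMainKBlockLoop = %d\n", HasMainKBlockLoop);
        return 0.0f; // would return the measured average kernel time
    }

    int main()
    {
        const bool has_main_loop = true; // runtime value (CalculateHasMainKBlockLoop in CK)

        // Generic lambda: the tag's ::value is a compile-time constant inside the body,
        // so it can be forwarded as a template argument.
        auto launch_kernel = [&](auto has_main_k_block_loop) {
            constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
            return run_kernel<has_loop>();
        };

        float avg_time = 0;
        if(has_main_loop) { avg_time = launch_kernel(std::integral_constant<bool, true>{}); }
        else              { avg_time = launch_kernel(std::integral_constant<bool, false>{}); }

        return static_cast<int>(avg_time);
    }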
...
@@ -12,16 +12,47 @@ namespace element_wise {
 struct PassThrough
 {
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, bhalf_t>::value ||
-                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-
-        y = x;
-    };
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<double, double>(double& y, const double& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
+    {
+        y = x;
+    }
 };
+
+struct UnaryConvert
+{
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const
+    {
+        y = type_convert<Y>(x);
+    }
+};

 struct Scale
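A small standalone usage sketch of the two functors reworked and introduced above. Here half_t/bhalf_t and ck::type_convert are replaced by plain host types and static_cast so the snippet compiles on its own; the real functors live in ck::tensor_operation::element_wise as shown in the hunk:

    #include <cassert>
    #include <cstdint>

    // Host-only stand-ins mirroring the structs in the hunk above.
    struct PassThroughSketch
    {
        template <typename Y, typename X>
        void operator()(Y& y, const X& x) const { y = x; } // same-type copy in practice
    };

    struct UnaryConvertSketch
    {
        template <typename Y, typename X>
        void operator()(Y& y, const X& x) const { y = static_cast<Y>(x); } // type_convert stand-in
    };

    int main()
    {
        float a = 3.25f;

        float copied = 0.0f;
        PassThroughSketch{}(copied, a);     // PassThrough: y = x, matching types

        std::int32_t converted = 0;
        UnaryConvertSketch{}(converted, a); // UnaryConvert: y = type_convert<Y>(x)

        assert(copied == 3.25f && converted == 3);
        return 0;
    }

Presumably the explicit same-type specializations keep PassThrough from converting between mismatched types by accident, while UnaryConvert makes any conversion an explicit choice at the call site.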
...