"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9d2fc6b53528bf0aa19a49f5b98848abac98275c"
Unverified commit fd3d907a, authored by Chao Liu, committed by GitHub

fix ReLU formula (#61)

* fix relu

* clean up

* clean up
parent 41cdd380
@@ -25,115 +25,76 @@ struct PassThrough
 struct Relu
 {
-    float alpha = 0.1;
-
+    // ReLU
     template <typename T>
     __host__ __device__ constexpr T operator()(T v) const
     {
-        T tmp = alpha * v;
-        return tmp > 0 ? tmp : 0;
+        return v > 0 ? v : 0;
     }
 };
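The point of the fix: the old functor computed max(alpha * v, 0), a ReLU whose positive side is scaled by alpha, which is neither the standard ReLU nor a leaky ReLU (a leaky ReLU would keep the alpha slope on the negative side instead). A minimal standalone sketch of the three formulas, plain C++ with no CK types (function names are mine):

#include <cassert>

// old_relu: what the code computed before this commit, max(alpha * v, 0),
// i.e. a ReLU scaled by alpha on the positive side.
float old_relu(float v, float alpha = 0.1f) { return alpha * v > 0 ? alpha * v : 0; }

// new_relu: the standard ReLU this commit switches to, max(v, 0).
float new_relu(float v) { return v > 0 ? v : 0; }

// leaky_relu: for comparison, what the alpha member suggested but never did.
float leaky_relu(float v, float alpha = 0.1f) { return v > 0 ? v : alpha * v; }

int main()
{
    assert(old_relu(2.0f) == 0.2f); // scaled where it should be identity
    assert(new_relu(2.0f) == 2.0f);
    assert(new_relu(-2.0f) == 0.0f);
    assert(leaky_relu(-2.0f) == -0.2f);
    return 0;
}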
-template <typename ADataType,
-          typename BDataType,
-          typename CDataType,
-          typename ALayout,
-          typename BLayout,
-          typename CLayout,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
-struct DeviceGemmInstance;
-
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
-struct DeviceGemmInstance<ck::half_t,
-                          ck::half_t,
-                          ck::half_t,
-                          ck::tensor_layout::gemm::RowMajor,
-                          ck::tensor_layout::gemm::ColumnMajor,
-                          ck::tensor_layout::gemm::RowMajor,
-                          AElementwiseOperation,
-                          BElementwiseOperation,
-                          CElementwiseOperation>
-{
-    using F16 = ck::half_t;
-    using F32 = float;
-    using Row = ck::tensor_layout::gemm::RowMajor;
-    using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-    template <ck::index_t... Is>
-    using S = ck::Sequence<Is...>;
-
-    using AOp = AElementwiseOperation;
-    using BOp = BElementwiseOperation;
-    using COp = CElementwiseOperation;
-
-    // Compilation parameters for NT problem
-    // clang-format off
-    using type =
-        ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
-    // clang-format on
-};
-
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
-struct DeviceGemmInstance<float,
-                          float,
-                          float,
-                          ck::tensor_layout::gemm::RowMajor,
-                          ck::tensor_layout::gemm::ColumnMajor,
-                          ck::tensor_layout::gemm::RowMajor,
-                          AElementwiseOperation,
-                          BElementwiseOperation,
-                          CElementwiseOperation>
-{
-    using F32 = float;
-    using Row = ck::tensor_layout::gemm::RowMajor;
-    using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-    template <ck::index_t... Is>
-    using S = ck::Sequence<Is...>;
-
-    using AOp = AElementwiseOperation;
-    using BOp = BElementwiseOperation;
-    using COp = CElementwiseOperation;
-
-    // Compilation parameters for NT problem
-    // clang-format off
-    using type =
-        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
-    // clang-format on
-};
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using ADataType   = ck::half_t;
+using BDataType   = ck::half_t;
+using CDataType   = ck::half_t;
+using AccDataType = float;
+
+using ALayout = ck::tensor_layout::gemm::RowMajor;
+using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+using AOp = PassThrough;
+using BOp = PassThrough;
+using COp = Relu;
+
+// Compilation parameters for NT problem
+// clang-format off
+using DeviceGemmInstance =
+    //#########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
    //#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
    //#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
+    ck::tensor_operation::device::DeviceGemmXdl< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
+// clang-format on
+
+template <typename AType,
+          typename BType,
+          typename CType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+static void host_verify(const Tensor<AType>& a_m_k,
+                        const Tensor<BType>& b_k_n,
+                        Tensor<CType>& c_m_n,
+                        const AElementwiseOperation& a_element_op,
+                        const BElementwiseOperation& b_element_op,
+                        const CElementwiseOperation& c_element_op)
+{
+    auto f_mk_kn_mn = [&](auto m, auto n) {
+        const int K = a_m_k.mDesc.GetLengths()[1];
+
+        double v = 0;
+
+        for(int k = 0; k < K; ++k)
+        {
+            v += static_cast<double>(a_element_op(a_m_k(m, k))) *
+                 static_cast<double>(b_element_op(b_k_n(k, n)));
+        }
+
+        c_m_n(m, n) = c_element_op(v);
+    };
+
+    make_ParallelTensorFunctor(f_mk_kn_mn,
+                               c_m_n.mDesc.GetLengths()[0],
+                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
+}
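The new host_verify is the standard reference-GEMM check: apply the A and B elementwise ops on load, accumulate each dot product in double for accuracy, apply the C elementwise op once per output element, and parallelize over the M x N output. A standalone sketch of the same idea using plain std::vector in place of CK's Tensor (reference_gemm and its signature are illustrative, not CK API):

#include <cassert>
#include <cstddef>
#include <vector>

// Naive reference GEMM with elementwise ops, mirroring host_verify above.
template <typename AOp, typename BOp, typename COp>
void reference_gemm(const std::vector<float>& a, // M x K, row-major
                    const std::vector<float>& b, // K x N, column-major
                    std::vector<float>& c,       // M x N, row-major
                    std::size_t M,
                    std::size_t N,
                    std::size_t K,
                    AOp a_op,
                    BOp b_op,
                    COp c_op)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            double v = 0; // accumulate in double, as host_verify does

            for(std::size_t k = 0; k < K; ++k)
            {
                v += static_cast<double>(a_op(a[m * K + k])) *
                     static_cast<double>(b_op(b[n * K + k]));
            }

            c[m * N + n] = c_op(static_cast<float>(v));
        }
    }
}

int main()
{
    const std::size_t M = 2, N = 2, K = 3;
    const std::vector<float> a = {1, 2, 3, 4, 5, 6}; // rows {1,2,3}, {4,5,6}
    const std::vector<float> b = {1, 0, 1, 0, 1, 0}; // columns {1,0,1}, {0,1,0}
    std::vector<float> c(M * N);

    auto pass = [](float x) { return x; };
    auto relu = [](float x) { return x > 0 ? x : 0.0f; };

    reference_gemm(a, b, c, M, N, K, pass, pass, relu);
    assert(c[0] == 4 && c[1] == 2 && c[2] == 10 && c[3] == 5);
    return 0;
}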
 int main(int argc, char* argv[])
 {
-    if(argc != 4)
-    {
-        printf("arg1: verification (0=no, 1=yes)\n");
-        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
-        exit(0);
-    }
-
-    const bool do_verification = std::stoi(argv[1]);
-    const int init_method      = std::stoi(argv[2]);
-    const int nrepeat          = std::stoi(argv[3]);
+    // defaults; optionally overridden from the command line below
+    bool do_verification = false;
+    int init_method      = 0;
+    int nrepeat          = 5;
 
     // GEMM shape
     ck::index_t M = 3840;
@@ -144,15 +105,34 @@ int main(int argc, char* argv[])
     ck::index_t StrideB = 4096;
     ck::index_t StrideC = 4096;
 
-    // matrix data type
-    using ADataType = ck::half_t;
-    using BDataType = ck::half_t;
-    using CDataType = ck::half_t;
-
-    // matrix layout
-    using ALayout = ck::tensor_layout::gemm::RowMajor;
-    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-    using CLayout = ck::tensor_layout::gemm::RowMajor;
+    if(argc == 4)
+    {
+        // three args set only the flags; argv[4..6] do not exist in this case
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
+    }
+    else if(argc == 10)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideC = std::stoi(argv[9]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
+        exit(0);
+    }
 
     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
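For reference, the three-argument form sets only the flags, and the nine-argument form also overrides the problem shape. A hypothetical invocation (the binary name is assumed) could be: ./gemm_xdl 1 1 5 3840 4096 4096 4096 4096 4096, meaning verify the result, initialize with integer values, time 5 kernel runs, and set M, N, K, StrideA, StrideB, StrideC explicitly.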
@@ -198,16 +178,7 @@ int main(int argc, char* argv[])
     c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
 
     // do GEMM
-    auto gemm = typename DeviceGemmInstance<ADataType,
-                                            BDataType,
-                                            CDataType,
-                                            ALayout,
-                                            BLayout,
-                                            CLayout,
-                                            PassThrough,
-                                            PassThrough,
-                                            Relu>::type{};
+    auto gemm     = DeviceGemmInstance{};
     auto invoker  = gemm.MakeInvoker();
     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                       static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
@@ -218,9 +189,9 @@ int main(int argc, char* argv[])
                                       StrideA,
                                       StrideB,
                                       StrideC,
-                                      PassThrough{},
-                                      PassThrough{},
-                                      Relu{});
+                                      AOp{},
+                                      BOp{},
+                                      COp{});
 
     if(!gemm.IsSupportedArgument(argument))
     {
@@ -233,7 +204,7 @@ int main(int argc, char* argv[])
     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype =
-        sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N;
+        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
 
     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
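On the accounting: the B byte count is corrected from K * M to K * N (B is K x N), flop counts one multiply plus one add per MAC, and since ave_time is the averaged kernel time in milliseconds, flop / 1e9 / ave_time is GFLOP per millisecond, which equals TFLOP/s. A quick standalone sanity check with the example's default sizes (fp16 assumed for A, B, C; the 1 ms kernel time is made up):

#include <cstddef>
#include <cstdio>

int main()
{
    // defaults from this example: M = 3840, N = 4096, K = 4096
    const std::size_t M = 3840, N = 4096, K = 4096;

    const std::size_t flop  = std::size_t(2) * M * N * K;              // 1 mul + 1 add per MAC
    const std::size_t bytes = 2 * (M * K) + 2 * (K * N) + 2 * (M * N); // sizeof(half) == 2

    const double ave_time_ms = 1.0; // hypothetical kernel time

    std::printf("%.1f TFLOPS, %.1f GB/s\n",
                flop / 1.e9 / ave_time_ms,   // GFLOP per millisecond == TFLOP/s
                bytes / 1.e6 / ave_time_ms); // MB per millisecond == GB/s
    return 0;
}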
@@ -246,7 +217,7 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, Relu{});
+        host_verify(a_m_k, b_k_n, c_m_n_host_result, AOp{}, BOp{}, COp{});
 
         check_error(c_m_n_host_result, c_m_n_device_result);
     }
......
@@ -20,10 +20,42 @@
 // 0 in the "n" dimension
 // assume C1 and C have same layout C
 
+// v0 is from A * B, v1 is from C0 (bias), v2 is from C1 (residual)
+struct BiasReluAdd
+{
+    template <typename T1, typename T2>
+    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    {
+        float b = v0 + v1;
+        float c = b > 0 ? b : 0;
+        float d = c + v2;
+
+        return d;
+    }
+
+    template <typename T1, typename T2>
+    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    {
+#if 0
+        float a = v1 + v0;
+        float b = max(a, float(0));
+        float c = b + v2;
+
+        return c;
+#else
+        // select instead of max: v0 + v1 + v2 when v0 + v1 > 0, else v2
+        float a = v1 + v2;
+        float c = (v0 > -v1) ? a + v0 : v2;
+
+        return c;
+#endif
+    }
+};
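The #else branch of the device path folds the ReLU into a select rather than a max: since relu(v0 + v1) + v2 equals v0 + v1 + v2 whenever v0 + v1 > 0 and just v2 otherwise, comparing v0 > -v1 picks between the two. A standalone host-side check of that equivalence (helper names are mine; exact float equality is fine here because the chosen test values and their sums are exactly representable):

#include <cassert>

// reference: the straightforward relu(v0 + v1) + v2
float reference(float v0, float v1, float v2)
{
    float b = v0 + v1;
    return (b > 0 ? b : 0.0f) + v2;
}

// device_variant: the #else branch above; (v0 > -v1) is (v0 + v1 > 0),
// so the result is v0 + v1 + v2 when the ReLU passes and v2 when it clamps
float device_variant(float v0, float v1, float v2)
{
    float a = v1 + v2;
    return (v0 > -v1) ? a + v0 : v2;
}

int main()
{
    for(float v0 : {-2.0f, -0.5f, 0.0f, 1.0f})
        for(float v1 : {-1.0f, 0.0f, 3.0f})
            for(float v2 : {-1.0f, 2.0f})
                assert(reference(v0, v1, v2) == device_variant(v0, v1, v2));
    return 0;
}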
 // v0 is from A * B
 // v1 is from C0
 // v2 is from C1
-struct BiasReluAdd
+struct BiasLeakyReluAdd
 {
     template <typename T1, typename T2>
     __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
@@ -51,7 +83,7 @@ struct BiasReluAdd
     }
 };
 
-struct BiasRelu
+struct BiasLeakyRelu
 {
     template <typename T1, typename T2>
     __host__ constexpr float operator()(float v0, T1 v1, T2) const
@@ -99,7 +131,7 @@ struct BiasAdd
     }
 #elif 0
     float alpha = 0.1;
     float beta  = 0.2;
     float gamma = 0.3;
     // wrong result
......
@@ -23,7 +23,7 @@ struct PassThrough
     }
 };
 
-struct BiasReluAdd
+struct BiasLeakyReluAdd
 {
     template <typename T1, typename T2>
     __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
@@ -97,7 +97,39 @@ struct BiasReluAdd
     }
 };
 
-struct BiasRelu
+struct BiasReluAdd
+{
+    template <typename T1, typename T2>
+    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    {
+        float b = v0 + v1;
+        float c = b > 0 ? b : 0;
+        float d = c + v2;
+
+        return d;
+    }
+
+    template <typename T1, typename T2>
+    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    {
+#if 0
+        float a = v1 + v0;
+        float b = max(a, float(0));
+        float c = b + v2;
+
+        return c;
+#else
+        // same select trick as above: v0 + v1 + v2 when v0 + v1 > 0, else v2
+        float a = v1 + v2;
+        float c = (v0 > -v1) ? a + v0 : v2;
+
+        return c;
+#endif
+    }
+};
+
+struct BiasLeakyRelu
 {
     template <typename T1, typename T2>
     __host__ constexpr float operator()(float v0, T1 v1, T2) const
@@ -377,6 +409,7 @@ int main(int argc, char* argv[])
     std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
                             sizeof(WeiDataType) * (K * C * Y * X) +
-                            sizeof(OutDataType) * (N * K * Ho * Wo);
+                            sizeof(OutDataType) * (K) +
+                            sizeof(OutDataType) * (N * K * Ho * Wo);
 
     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
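The added sizeof(OutDataType) * (K) term charges the kernel for reading the per-output-channel bias vector, one value per channel, which the byte count had previously omitted. For, say, K = 256 half-precision channels that is only 512 extra bytes next to the N * K * Ho * Wo output tensor, but the reported bandwidth is now computed from the complete memory traffic.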
......