Commit 08c9433e authored by Chao Liu's avatar Chao Liu
Browse files

fix relu

parent 41cdd380
......@@ -123,17 +123,9 @@ struct DeviceGemmInstance<float,
int main(int argc, char* argv[])
{
if(argc != 4)
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
exit(0);
}
const bool do_verification = std::stoi(argv[1]);
const int init_method = std::stoi(argv[2]);
const int nrepeat = std::stoi(argv[3]);
bool do_verification = 0;
int init_method = 0;
int nrepeat = 5;
// GEMM shape
ck::index_t M = 3840;
......@@ -144,6 +136,35 @@ int main(int argc, char* argv[])
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
if(argc == 4)
{
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
// matrix data type
using ADataType = ck::half_t;
using BDataType = ck::half_t;
......
......@@ -20,10 +20,42 @@
// 0 in the "n" dimension
// assume C1 and C have same layout C
struct BiasReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
float c = b + v2;
return c;
#else
float a = v1 + v2;
float b = v2;
float c = (v0 > -v1) ? a + v0 : v2;
return c;
#endif
}
};
// v0 is from A * B
// v1 is from C0
// v2 is from C1
struct BiasReluAdd
struct BiasLeakyReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
......@@ -51,7 +83,7 @@ struct BiasReluAdd
}
};
struct BiasRelu
struct BiasLeakyRelu
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2) const
......
......@@ -23,7 +23,7 @@ struct PassThrough
}
};
struct BiasReluAdd
struct BiasLeakyReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
......@@ -97,7 +97,39 @@ struct BiasReluAdd
}
};
struct BiasRelu
struct BiasReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
float c = b + v2;
return c;
#else
float a = v1 + v2;
float b = v2;
float c = (v0 > -v1) ? a + v0 : v2;
return c;
#endif
}
};
struct BiasLeakyRelu
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2) const
......@@ -377,6 +409,7 @@ int main(int argc, char* argv[])
std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
sizeof(WeiDataType) * (K * C * Y * X) +
sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K) +
sizeof(OutDataType) * (N * K * Ho * Wo);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment