Commit 08c9433e authored by Chao Liu's avatar Chao Liu
Browse files

fix relu

parent 41cdd380
...@@ -123,17 +123,9 @@ struct DeviceGemmInstance<float, ...@@ -123,17 +123,9 @@ struct DeviceGemmInstance<float,
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
if(argc != 4) bool do_verification = 0;
{ int init_method = 0;
printf("arg1: verification (0=no, 1=yes)\n"); int nrepeat = 5;
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
exit(0);
}
const bool do_verification = std::stoi(argv[1]);
const int init_method = std::stoi(argv[2]);
const int nrepeat = std::stoi(argv[3]);
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -144,6 +136,35 @@ int main(int argc, char* argv[]) ...@@ -144,6 +136,35 @@ int main(int argc, char* argv[])
ck::index_t StrideB = 4096; ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096; ck::index_t StrideC = 4096;
if(argc == 4)
{
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
// matrix data type // matrix data type
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::half_t; using BDataType = ck::half_t;
......
...@@ -20,10 +20,42 @@ ...@@ -20,10 +20,42 @@
// 0 in the "n" dimension // 0 in the "n" dimension
// assume C1 and C have same layout C // assume C1 and C have same layout C
struct BiasReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
float c = b + v2;
return c;
#else
float a = v1 + v2;
float b = v2;
float c = (v0 > -v1) ? a + v0 : v2;
return c;
#endif
}
};
// v0 is from A * B // v0 is from A * B
// v1 is from C0 // v1 is from C0
// v2 is from C1 // v2 is from C1
struct BiasReluAdd struct BiasLeakyReluAdd
{ {
template <typename T1, typename T2> template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
...@@ -51,7 +83,7 @@ struct BiasReluAdd ...@@ -51,7 +83,7 @@ struct BiasReluAdd
} }
}; };
struct BiasRelu struct BiasLeakyRelu
{ {
template <typename T1, typename T2> template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2) const __host__ constexpr float operator()(float v0, T1 v1, T2) const
...@@ -99,7 +131,7 @@ struct BiasAdd ...@@ -99,7 +131,7 @@ struct BiasAdd
} }
#elif 0 #elif 0
float alpha = 0.1; float alpha = 0.1;
float beta = 0.2; float beta = 0.2;
float gamma = 0.3; float gamma = 0.3;
// wrong result // wrong result
......
...@@ -23,7 +23,7 @@ struct PassThrough ...@@ -23,7 +23,7 @@ struct PassThrough
} }
}; };
struct BiasReluAdd struct BiasLeakyReluAdd
{ {
template <typename T1, typename T2> template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
...@@ -97,7 +97,39 @@ struct BiasReluAdd ...@@ -97,7 +97,39 @@ struct BiasReluAdd
} }
}; };
struct BiasRelu struct BiasReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
float c = b + v2;
return c;
#else
float a = v1 + v2;
float b = v2;
float c = (v0 > -v1) ? a + v0 : v2;
return c;
#endif
}
};
struct BiasLeakyRelu
{ {
template <typename T1, typename T2> template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2) const __host__ constexpr float operator()(float v0, T1 v1, T2) const
...@@ -377,6 +409,7 @@ int main(int argc, char* argv[]) ...@@ -377,6 +409,7 @@ int main(int argc, char* argv[])
std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
sizeof(WeiDataType) * (K * C * Y * X) + sizeof(WeiDataType) * (K * C * Y * X) +
sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K) +
sizeof(OutDataType) * (N * K * Ho * Wo); sizeof(OutDataType) * (N * K * Ho * Wo);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment