Commit a7361926 authored by Chao Liu's avatar Chao Liu
Browse files

clean up

parent 81b26528
...@@ -36,28 +36,37 @@ struct PassThrough ...@@ -36,28 +36,37 @@ struct PassThrough
// v2 is from bias vector // v2 is from bias vector
struct BiasAdd struct BiasAdd
{ {
#if 1 #if 0
// correct result // correct result
// no scratch memory, good VGPR allocation (59) // no scratch memory, good VGPR allocation (59)
// good perf (101Tflops) // good perf (101Tflops)
template <typename T1, typename T2> template <typename T1, typename T2>
__host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const __host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{ {
constexpr float alpha = 0.1;
constexpr float beta = 0.2;
constexpr float gamma = 0.3;
// compiler seems very volatile to the order of these calculation: // compiler seems very volatile to the order of these calculation:
// compiler is very eager to read AccVgpr (v0) out prematurely, resulting in register // compiler is very eager to read AccVgpr (v0) out prematurely, resulting in register
// over-allocation. Therefore, move v0 calculation to the very end // over-allocation. Therefore, move v0 calculation to the very end
float a = T1(0.2) * v1 + T2(0.3) * v2; float a = T1(beta) * v1 + T2(gamma) * v2;
float b = a + float(0.1) * v0; float b = a + float(alpha) * v0;
return b; return b;
} }
#elif 0 #elif 1
// correct result float alpha = 0.1;
// some scratch memory (68), large VGPR usage (126) float beta = 0.2;
// very little perf drop (101Tflops) float gamma = 0.3;
__host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
// wrong result
// lots of scratch memory
// huge perf drop
template <typename T1, typename T2>
__host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{ {
return float(0.1) * v0 + ck::half_t(0.2) * v1 + ck::half_t(0.3) * v2; return alpha * v0 + beta * v1 + gamma * v2;
} }
#elif 0 #elif 0
// correct result // correct result
...@@ -361,5 +370,7 @@ int main(int argc, char* argv[]) ...@@ -361,5 +370,7 @@ int main(int argc, char* argv[])
PassThrough{}, PassThrough{},
PassThrough{}, PassThrough{},
c_element_op); c_element_op);
check_error(c_m_n_host_result, c_m_n_device_result);
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment