Commit 6ff00ed4 authored by Chao Liu's avatar Chao Liu

tweak example

parent dcf48977
@@ -52,70 +52,13 @@ struct BiasReluAdd
}
};
// v0 is from A * B
// v1 is from C0
// v2 is from C1
struct BiasLeakyReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = max(b, float(0));
float d = alpha * (a + c);
return d;
}
};
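// Note: the __host__ and __device__ bodies above compute the same value,
// v2 + alpha * max(v0 + v1, 0); the __device__ path is simply refactored as
// alpha * (v2 / alpha + max(v0 + v1, 0)), presumably to steer the GPU backend
// toward an FMA-shaped instruction sequence. Also note that, despite the name,
// this is a scaled ReLU rather than a true leaky ReLU, which would produce
// alpha * a (not 0) for a negative pre-activation a.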
struct BiasLeakyRelu
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2) const
{
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
return c;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2) const
{
constexpr float alpha = 0.1;
float b = v1 + v0;
float c = max(b, float(0));
float d = alpha * c;
return d;
}
};
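// A minimal host-side sanity check (hypothetical, not part of this commit)
// that the refactored __device__ math in BiasLeakyReluAdd agrees with the
// straightforward __host__ formulation; assumes <algorithm> and <cmath> are
// included and that the function name is free.
inline bool VerifyBiasLeakyReluAddEquivalence()
{
    for(float v0 : {-2.f, -0.5f, 0.f, 0.5f, 2.f})
        for(float v1 : {-1.f, 0.f, 1.f})
            for(float v2 : {-1.f, 0.f, 1.f})
            {
                // host path, as implemented above
                const float ref = BiasLeakyReluAdd{}(v0, v1, v2);
                // device-path math, evaluated on the host
                const float dev = 0.1f * (v2 * 10.f + std::max(v0 + v1, 0.f));
                if(std::fabs(ref - dev) > 1e-5f)
                    return false;
            }
    return true;
}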
struct BiasAdd
{
#if 1
// correct result
// no scratch memory, good VGPR allocation (59)
// good perf (101 TFLOPS)
__host__ __device__ constexpr float operator()(float v0, ck::half_t v1, ck::half_t v2) const
{
constexpr float alpha = 0.1;
constexpr float beta = 0.2;
@@ -124,7 +67,7 @@ struct BiasAdd
// The compiler seems very sensitive to the order of these calculations:
// it is eager to read the AccVgpr (v0) out prematurely, resulting in register
// over-allocation. Therefore, move the v0 calculation to the very end.
float a = ck::half_t(beta) * v1 + ck::half_t(gamma) * v2;
float b = a + float(alpha) * v0;
return b;
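// In other words, the naive ordering
//     float b = float(alpha) * v0 + ck::half_t(beta) * v1 + ck::half_t(gamma) * v2;
// reads the MFMA accumulator v0 first, which (presumably) tempts the compiler
// into pulling many accumulator values into VGPRs before they are needed;
// folding v0 in last lets each accumulator be consumed as soon as it is read.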
@@ -151,7 +94,7 @@ struct BiasAdd
{
return 0.1 * v0 + 0.2 * v1 + 0.3 * v2;
}
#elif 0
// wrong result
// lots of scratch memory
// huge perf drop
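// (presumably the same premature-accumulator-read problem as above, but severe
// enough here that the over-allocated registers spill to scratch memory)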
@@ -215,16 +158,15 @@ static void host_verify(const Tensor<AType>& a_m_k,
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = a_m_k.mDesc.GetLengths()[1];
float acc = 0;
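// Accumulate in float rather than double, presumably so the host reference
// mirrors the float accumulator of the device GEMM; the per-element products
// below are still formed in double before being added.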
for(int k = 0; k < K; ++k)
{
acc += static_cast<double>(a_element_op(a_m_k(m, k))) *
       static_cast<double>(b_element_op(b_k_n(k, n)));
}
c_m_n(m, n) = c_element_op(acc, c0_m_n(m, n), c1_m_n(m, n));
};
make_ParallelTensorFunctor(f_mk_kn_mn,
@@ -249,9 +191,9 @@ int main(int argc, char* argv[])
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
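// i.e. this form is invoked as: ./example <do_verification> <init_method> <nrepeat>;
// M, N, K (and presumably the strides) keep their defaults unless the
// argc == 10 form below is used.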
}
else if(argc == 10)
{
@@ -337,7 +279,9 @@ int main(int argc, char* argv[])
c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
auto a_element_op = AOp{};
auto b_element_op = BOp{};
auto c_element_op = COp{};
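// AOp, BOp, and COp are assumed to be aliases defined alongside the
// DeviceGemmInstance (e.g. PassThrough for A and B, and one of the fused
// element-wise ops above, such as BiasReluAdd, for C).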
// do GEMM
auto gemm = DeviceGemmInstance{};
@@ -354,8 +298,8 @@ int main(int argc, char* argv[])
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
......