"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "949553db236cf78daa3bdfe1a966a61a8a01d22e"
Commit 6ff00ed4 authored by Chao Liu

tweak example

parent dcf48977
@@ -52,70 +52,13 @@ struct BiasReluAdd
     }
 };
 
-// v0 is from A * B
-// v1 is from C0
-// v2 is from C1
-struct BiasLeakyReluAdd
-{
-    template <typename T1, typename T2>
-    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
-    {
-        float a = v0 + v1;
-        float b = 0.1 * a;
-        float c = b > 0 ? b : 0;
-        float d = c + v2;
-        return d;
-    }
-
-    template <typename T1, typename T2>
-    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
-    {
-        constexpr float alpha     = 0.1;
-        constexpr float alpha_inv = 1.0 / alpha;
-
-        float a = v2 * alpha_inv;
-        float b = v1 + v0;
-        float c = max(b, float(0));
-        float d = alpha * (a + c);
-        return d;
-    }
-};
-
-struct BiasLeakyRelu
-{
-    template <typename T1, typename T2>
-    __host__ constexpr float operator()(float v0, T1 v1, T2) const
-    {
-        float a = v0 + v1;
-        float b = 0.1 * a;
-        float c = b > 0 ? b : 0;
-        return c;
-    }
-
-    template <typename T1, typename T2>
-    __device__ constexpr float operator()(float v0, T1 v1, T2) const
-    {
-        constexpr float alpha = 0.1;
-
-        float b = v1 + v0;
-        float c = max(b, float(0));
-        float d = alpha * c;
-        return d;
-    }
-};
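Note on the removed functors: both overloads of BiasLeakyReluAdd compute relu(0.1 * (v0 + v1)) + v2; the __device__ overload merely factors the constant out, as alpha * (v2 / alpha + max(v0 + v1, 0)), so the epilogue maps onto multiply-add friendly instructions. A minimal standalone sketch (hypothetical file, not part of the commit) checking that the two formulations agree:

// equivalence_check.cpp (hypothetical): for alpha > 0,
//   max(alpha * x, 0) + v2 == alpha * (v2 / alpha + max(x, 0))
// so the host and device overloads of the removed functor compute the same value.
#include <algorithm>
#include <cassert>
#include <cmath>

int main()
{
    constexpr float alpha = 0.1f;
    for(float v0 : {-2.0f, -0.5f, 0.0f, 0.7f, 3.0f})
        for(float v1 : {-1.0f, 0.0f, 2.0f})
            for(float v2 : {-1.0f, 0.0f, 1.5f})
            {
                // host overload: relu applied after scaling by alpha, then add v2
                float host_val = std::max(alpha * (v0 + v1), 0.0f) + v2;
                // device overload: alpha factored out, v2 pre-divided by alpha
                float dev_val = alpha * (v2 / alpha + std::max(v0 + v1, 0.0f));
                assert(std::fabs(host_val - dev_val) < 1e-5f);
            }
    return 0;
}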
 struct BiasAdd
 {
 #if 1
     // correct result
     // no scratch memory, good VGPR allocation (59)
     // good perf (101Tflops)
-    template <typename T1, typename T2>
-    __host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    __host__ __device__ constexpr float operator()(float v0, ck::half_t v1, ck::half_t v2) const
     {
         constexpr float alpha = 0.1;
         constexpr float beta  = 0.2;
@@ -124,7 +67,7 @@ struct BiasAdd
         // compiler seems very volatile to the order of these calculation:
         // compiler is very eager to read AccVgpr (v0) out prematurely, resulting in register
         // over-allocation. Therefore, move v0 calculation to the very end
-        float a = T1(beta) * v1 + T2(gamma) * v2;
+        float a = ck::half_t(beta) * v1 + ck::half_t(gamma) * v2;
         float b = a + float(alpha) * v0;
         return b;
@@ -151,7 +94,7 @@ struct BiasAdd
     {
         return 0.1 * v0 + 0.2 * v1 + 0.3 * v2;
     }
-#elif 1
+#elif 0
     // wrong result
     // lots of scratch memory
     // huge perf drop
@@ -215,16 +158,15 @@ static void host_verify(const Tensor<AType>& a_m_k,
     auto f_mk_kn_mn = [&](auto m, auto n) {
         const int K = a_m_k.mDesc.GetLengths()[1];
-        double v = 0;
+        float acc = 0;
         for(int k = 0; k < K; ++k)
         {
-            v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
-                 static_cast<const double>(b_element_op(b_k_n(k, n)));
+            acc += static_cast<const double>(a_element_op(a_m_k(m, k))) *
+                   static_cast<const double>(b_element_op(b_k_n(k, n)));
         }
-        c_m_n(m, n) = c_element_op(
-            v, static_cast<const double>(c0_m_n(m, n)), static_cast<const double>(c1_m_n(m, n)));
+        c_m_n(m, n) = c_element_op(acc, c0_m_n(m, n), c1_m_n(m, n));
     };
     make_ParallelTensorFunctor(f_mk_kn_mn,
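The make_ParallelTensorFunctor call is cut off by the hunk; in composable_kernel host examples it is typically completed along these lines (a sketch under that assumption, not taken from this diff):

// dispatch f_mk_kn_mn over the M x N output grid on all host threads
make_ParallelTensorFunctor(f_mk_kn_mn,
                           c_m_n.mDesc.GetLengths()[0],  // M
                           c_m_n.mDesc.GetLengths()[1])( // N
    std::thread::hardware_concurrency());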
@@ -249,9 +191,9 @@ int main(int argc, char* argv[])
     if(argc == 4)
     {
-        M = std::stoi(argv[4]);
-        N = std::stoi(argv[5]);
-        K = std::stoi(argv[6]);
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {
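With the corrected parsing, the argc == 4 branch now reads the three runtime flags from argv[1..3] instead of indexing past the end of argv; the ten-argument form presumably also supplies the problem sizes and strides. A hypothetical invocation (binary name assumed):

// e.g.: ./example_gemm_bias_relu_add 1 1 5
//   -> do_verification = 1, init_method = 1, nrepeat = 5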
@@ -337,7 +279,9 @@ int main(int argc, char* argv[])
     c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
     c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
-    auto c_element_op = BiasReluAdd{};
+    auto a_element_op = AOp{};
+    auto b_element_op = BOp{};
+    auto c_element_op = COp{};
     // do GEMM
     auto gemm = DeviceGemmInstance{};
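AOp, BOp, and COp are not defined in this hunk; judging from the MakeArgument call below, they are presumably aliases introduced elsewhere in the commit, roughly along these lines (hypothetical reconstruction):

// hypothetical aliases, consistent with the argument list below
using AOp = PassThrough;  // A is consumed unmodified
using BOp = PassThrough;  // B is consumed unmodified
using COp = BiasReluAdd;  // or BiasAdd: the fused epilogue functor under test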
@@ -354,8 +298,8 @@ int main(int argc, char* argv[])
                                       StrideA,
                                       StrideB,
                                       StrideC,
-                                      PassThrough{},
-                                      PassThrough{},
+                                      a_element_op,
+                                      b_element_op,
                                       c_element_op);
 
     if(!gemm.IsSupportedArgument(argument))
...
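The diff ends at the support check; composable_kernel example programs typically continue by building an invoker and timing the kernel (a sketch under that assumption, not part of this commit):

// typical continuation in composable_kernel examples (sketch)
auto invoker   = gemm.MakeInvoker();
float ave_time = invoker.Run(argument, nrepeat);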