Commit 83f89182 authored by Shucai Xiao
Browse files

Back up the latest mul_add implementation

parent 67903751
...@@ -90,15 +90,33 @@ void mul_add(hipStream_t stream, ...@@ -90,15 +90,33 @@ void mul_add(hipStream_t stream,
const argument& arg3) const argument& arg3)
{ {
auto sr = result.get_shape(); auto sr = result.get_shape();
auto s1 = arg1.get_shape();
auto s2 = arg2.get_shape(); auto s2 = arg2.get_shape();
auto s3 = arg3.get_shape(); auto s3 = arg3.get_shape();
auto type = sr.type();
hip_visit_all(result, arg1, arg2, arg3, sr)([&](auto r, auto i1, auto i2, auto i3, auto dsr) { if(type == sr.type())
gs_launch(stream, sr.elements())([=](auto i) __device__ { {
hip_visit_all(result, arg1, arg2, arg3, sr, s1, s2, s3)([&](auto r, auto i1, auto i2, auto i3,
auto dsr, auto ds1, auto ds2, auto ds3) {
__half2* rp = reinterpret_cast<__half2*>(r.data());
__half2* i1p = reinterpret_cast<__half2*>(i1.data());
__half2* i2p = reinterpret_cast<__half2*>(i2.data());
__half2* i3p = reinterpret_cast<__half2*>(i3.data());
gs_launch(stream, sr.elements() / 2)([=](auto i) __device__ {
auto idx = dsr.multi(i); auto idx = dsr.multi(i);
r[i] = i1[i] * i2[idx] + i3[idx]; auto idx1 = ds1.index(idx);
auto idx2 = ds2.index(idx);
auto idx3 = ds3.index(idx);
rp[i] = __hadd2(__hmul2(i1p[idx1], i2p[idx2]), i3p[idx3]);
}); });
}); });
}
else
{
nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
__device__ { return a * x + b; });
}
} }
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment