Commit f50bcff2 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

backup additional changes

parent 702412b1
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
__global__ void add_kernel(__half* a, __half* b, __half* r, int n)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n)
{
r[tid] = a[tid] + b[tid%768];
}
}
void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
auto s2 = arg2.get_shape();
if (s2.element_space() == 768 and s2.type() == shape::half_type)
{
auto elem_num = s2.elements();
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
add_kernel<<<block_num, block_size>>>(reinterpret_cast<__half*>(arg1.data()),
reinterpret_cast<__half*>(arg2.data()),
reinterpret_cast<__half*>(result.data()), elem_num);
}
else
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
}
}
void add(hipStream_t stream,
......
......@@ -13,8 +13,12 @@ void contiguous_nonstandard(hipStream_t stream, const argument& result, const ar
shape s{result.get_shape().type(), result.get_shape().lens()};
visit_all(result, arg)([&](auto output_v, auto input_v) {
hip_visit_views(output_v, input_v, s)([&](auto output, auto input, auto standard_shape) {
mi_gs_launch(stream,
standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; });
gs_launch(stream, s.elements())([=](auto i) __device__ {
auto idx = standard_shape.multi(i);
output[idx] = input[idx];
});
// mi_gs_launch(stream,
// standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; });
});
});
}
......@@ -22,31 +26,31 @@ void contiguous_nonstandard(hipStream_t stream, const argument& result, const ar
void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg)
{
index_int nelements = result.get_shape().elements();
auto type = result.get_shape().type();
if (type == shape::half_type)
{
visit_all(result, arg)([&](auto output_v, auto input_v) {
const auto* input = device_cast(input_v.data());
auto* output = device_cast(output_v.data());
const __half2* input2 = reinterpret_cast<__half2*>(input_v.data());
__half2* output2 = reinterpret_cast<__half2*>(output_v.data());
gs_launch(stream, nelements / 2)([=](auto i) __device__ {
output2[i] = input2[i];
if (i == 0 and (nelements % 2) == 1)
{
output[nelements - 1] = input[nelements - 1];
}
});
});
}
else
{
// auto type = result.get_shape().type();
// if (type == shape::half_type)
// {
// visit_all(result, arg)([&](auto output_v, auto input_v) {
// const auto* input = device_cast(input_v.data());
// auto* output = device_cast(output_v.data());
// const __half2* input2 = reinterpret_cast<__half2*>(input_v.data());
// __half2* output2 = reinterpret_cast<__half2*>(output_v.data());
// gs_launch(stream, nelements / 2)([=](auto i) __device__ {
// output2[i] = input2[i];
// if (i == 0 and (nelements % 2) == 1)
// {
// output[nelements - 1] = input[nelements - 1];
// }
// });
// });
// }
// else
// {
visit_all(result, arg)([&](auto output_v, auto input_v) {
const auto* input = device_cast(input_v.data());
auto* output = device_cast(output_v.data());
gs_launch(stream, nelements)([=](auto i) __device__ { output[i] = input[i]; });
});
}
// }
}
void contiguous(hipStream_t stream, const argument& result, const argument& arg)
......
......@@ -59,7 +59,8 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal)
assert(s.elements() > 0);
index_int n = s.elements();
index_int groups = (n + nlocal - 1) / nlocal;
index_int nglobal = std::min<index_int>(128, groups) * nlocal;
// change the max group num to 1 Million
index_int nglobal = std::min<index_int>((1 << 20), groups) * nlocal;
assert(groups > 0);
assert(nglobal > 0);
......
#include <migraphx/gpu/device/mul.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
__global__ void mul_kernel(__half* a, __half* b, __half* r, int n)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n)
{
r[tid] = a[tid] * b[tid%768];
}
}
void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
auto s2 = arg2.get_shape();
if (s2.element_space() == 768 and s2.type() == shape::half_type)
{
auto elem_num = s2.elements();
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
mul_kernel<<<block_num, block_size>>>(reinterpret_cast<__half*>(arg1.data()),
reinterpret_cast<__half*>(arg2.data()),
reinterpret_cast<__half*>(result.data()), elem_num);
}
else
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
}
}
void mul(hipStream_t stream,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment