Commit a0c4afbf authored by Paul's avatar Paul
Browse files

Add support for fp16 on the gpu

parent 34b44662
......@@ -3,6 +3,7 @@
#include <migraph/gpu/device/tensor.hpp>
#include <migraph/gpu/device/launch.hpp>
#include <migraph/gpu/device/types.hpp>
#include <migraph/functional.hpp>
#include <migraph/ranges.hpp>
......@@ -38,9 +39,9 @@ auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments..
visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = pack(
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, inputs.data())...);
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, device_cast(inputs.data()))...);
hip_tensor_descriptor<ndim> out_desc(output_shape);
auto* outp = output.data();
auto* outp = device_cast(output.data());
gs_launch(stream, output_shape.elements())([=](auto i) {
data([&](auto&&... ps) {
auto outidx = out_desc.multi(i);
......@@ -71,11 +72,11 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = as_vec4(input1.data());
auto* yp = as_vec4(input2.data());
auto* zp = as_vec4(input3.data());
auto* outp = as_vec4(output.data());
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto* xp = as_vec4(device_cast(input1.data()));
auto* yp = as_vec4(device_cast(input2.data()));
auto* zp = as_vec4(device_cast(input3.data()));
auto* outp = as_vec4(device_cast(output.data()));
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
......@@ -130,11 +131,11 @@ void trinary_broadcast_impl(hipStream_t stream,
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = input1.data();
auto* yp = input2.data();
auto* zp = input3.data();
auto* outp = output.data();
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto* xp = device_cast(input1.data());
auto* yp = device_cast(input2.data());
auto* zp = device_cast(input3.data());
auto* outp = device_cast(output.data());
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
......@@ -177,10 +178,10 @@ void binary_broadcast_vec_impl(
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = as_vec4(input1.data());
auto* yp = as_vec4(input2.data());
auto* outp = as_vec4(output.data());
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto* xp = as_vec4(device_cast(input1.data()));
auto* yp = as_vec4(device_cast(input2.data()));
auto* outp = as_vec4(device_cast(output.data()));
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
......@@ -230,10 +231,10 @@ void binary_broadcast_impl(
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = input1.data();
auto* yp = input2.data();
auto* outp = output.data();
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto* xp = device_cast(input1.data());
auto* yp = device_cast(input2.data());
auto* outp = device_cast(output.data());
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
......@@ -265,10 +266,10 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
// assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
const std::size_t vec_size = 4;
auto data = pack_vec4(inputs.data()...);
auto* outp = as_vec4(output.data());
auto data = pack_vec4(device_cast(inputs.data())...);
auto* outp = as_vec4(device_cast(output.data()));
gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) {
vec4<type> out = outp[i];
data(
......@@ -290,8 +291,8 @@ void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... a
// assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
auto data = pack(inputs.data()...);
auto* outp = output.data();
auto data = pack(device_cast(inputs.data())...);
auto* outp = device_cast(output.data());
gs_launch(stream, output_shape.elements())(
[=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
});
......
/*=============================================================================
Copyright (c) 2017 Paul Fultz II
types.hpp
Distributed under the Boost Software License, Version 1.0. (See accompanying
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_GPU_DEVICE_TYPES_HPP
#define MIGRAPH_GUARD_RTGLIB_GPU_DEVICE_TYPES_HPP
#include <migraph/half.hpp>
namespace migraph {
namespace gpu {
namespace device {
using gpu_half = __fp16;
namespace detail {
template<class T>
struct device_type
{
using type = T;
};
template<>
struct device_type<half>
{
using type = gpu_half;
};
template<class T>
struct host_type
{
using type = T;
};
template<>
struct device_type<gpu_half>
{
using type = half;
};
} // namespace detail
template<class T>
using host_type = typename detail::host_type<T>::type;
template<class T>
using device_type = typename detail::device_type<T>::type;
template<class T>
host_type<T> host_cast(T x)
{
return reinterpret_cast<host_type<T>>(x);
}
template<class T>
host_type<T>* host_cast(T* x)
{
return reinterpret_cast<host_type<T>*>(x);
}
template<class T>
device_type<T> device_cast(T x)
{
return reinterpret_cast<device_type<T>>(x);
}
template<class T>
device_type<T>* device_cast(T* x)
{
return reinterpret_cast<device_type<T>*>(x);
}
} // namespace device
} // namespace gpu
} // namespace migraph
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment