Commit 96e74d6e authored by Paul
Browse files

Formatting

parent a0c4afbf
...@@ -38,8 +38,8 @@ auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments.. ...@@ -38,8 +38,8 @@ auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments..
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) { visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = pack( auto data = pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, device_cast(inputs.data()))...); device_cast(inputs.data()))...);
hip_tensor_descriptor<ndim> out_desc(output_shape); hip_tensor_descriptor<ndim> out_desc(output_shape);
auto* outp = device_cast(output.data()); auto* outp = device_cast(output.data());
gs_launch(stream, output_shape.elements())([=](auto i) { gs_launch(stream, output_shape.elements())([=](auto i) {
......
...@@ -17,26 +17,25 @@ namespace device { ...@@ -17,26 +17,25 @@ namespace device {
using gpu_half = __fp16; using gpu_half = __fp16;
namespace detail { namespace detail {
template<class T> template <class T>
struct device_type struct device_type
{ {
using type = T; using type = T;
}; };
template<> template <>
struct device_type<half> struct device_type<half>
{ {
using type = gpu_half; using type = gpu_half;
}; };
template <class T>
template<class T>
struct host_type struct host_type
{ {
using type = T; using type = T;
}; };
template<> template <>
struct device_type<gpu_half> struct device_type<gpu_half>
{ {
using type = half; using type = half;
...@@ -44,31 +43,31 @@ struct device_type<gpu_half> ...@@ -44,31 +43,31 @@ struct device_type<gpu_half>
} // namespace detail } // namespace detail
template<class T> template <class T>
using host_type = typename detail::host_type<T>::type; using host_type = typename detail::host_type<T>::type;
template<class T> template <class T>
using device_type = typename detail::device_type<T>::type; using device_type = typename detail::device_type<T>::type;
template<class T> template <class T>
host_type<T> host_cast(T x) host_type<T> host_cast(T x)
{ {
return reinterpret_cast<host_type<T>>(x); return reinterpret_cast<host_type<T>>(x);
} }
template<class T> template <class T>
host_type<T>* host_cast(T* x) host_type<T>* host_cast(T* x)
{ {
return reinterpret_cast<host_type<T>*>(x); return reinterpret_cast<host_type<T>*>(x);
} }
template<class T> template <class T>
device_type<T> device_cast(T x) device_type<T> device_cast(T x)
{ {
return reinterpret_cast<device_type<T>>(x); return reinterpret_cast<device_type<T>>(x);
} }
template<class T> template <class T>
device_type<T>* device_cast(T* x) device_type<T>* device_cast(T* x)
{ {
return reinterpret_cast<device_type<T>*>(x); return reinterpret_cast<device_type<T>*>(x);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment