Commit 96e74d6e authored by Paul's avatar Paul
Browse files

Formatting

parent a0c4afbf
...@@ -38,8 +38,8 @@ auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments.. ...@@ -38,8 +38,8 @@ auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments..
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) { visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = pack( auto data = pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, device_cast(inputs.data()))...); device_cast(inputs.data()))...);
hip_tensor_descriptor<ndim> out_desc(output_shape); hip_tensor_descriptor<ndim> out_desc(output_shape);
auto* outp = device_cast(output.data()); auto* outp = device_cast(output.data());
gs_launch(stream, output_shape.elements())([=](auto i) { gs_launch(stream, output_shape.elements())([=](auto i) {
...@@ -266,7 +266,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments. ...@@ -266,7 +266,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
// assert(x.get_shape().elements() == y.get_shape().elements()); // assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
const std::size_t vec_size = 4; const std::size_t vec_size = 4;
auto data = pack_vec4(device_cast(inputs.data())...); auto data = pack_vec4(device_cast(inputs.data())...);
auto* outp = as_vec4(device_cast(output.data())); auto* outp = as_vec4(device_cast(output.data()));
......
...@@ -17,26 +17,25 @@ namespace device { ...@@ -17,26 +17,25 @@ namespace device {
using gpu_half = __fp16; using gpu_half = __fp16;
namespace detail { namespace detail {
template<class T> template <class T>
struct device_type struct device_type
{ {
using type = T; using type = T;
}; };
template<> template <>
struct device_type<half> struct device_type<half>
{ {
using type = gpu_half; using type = gpu_half;
}; };
template <class T>
template<class T>
struct host_type struct host_type
{ {
using type = T; using type = T;
}; };
template<> template <>
struct device_type<gpu_half> struct device_type<gpu_half>
{ {
using type = half; using type = half;
...@@ -44,31 +43,31 @@ struct device_type<gpu_half> ...@@ -44,31 +43,31 @@ struct device_type<gpu_half>
} // namespace detail } // namespace detail
template<class T> template <class T>
using host_type = typename detail::host_type<T>::type; using host_type = typename detail::host_type<T>::type;
template<class T> template <class T>
using device_type = typename detail::device_type<T>::type; using device_type = typename detail::device_type<T>::type;
template<class T> template <class T>
host_type<T> host_cast(T x) host_type<T> host_cast(T x)
{ {
return reinterpret_cast<host_type<T>>(x); return reinterpret_cast<host_type<T>>(x);
} }
template<class T> template <class T>
host_type<T>* host_cast(T* x) host_type<T>* host_cast(T* x)
{ {
return reinterpret_cast<host_type<T>*>(x); return reinterpret_cast<host_type<T>*>(x);
} }
template<class T> template <class T>
device_type<T> device_cast(T x) device_type<T> device_cast(T x)
{ {
return reinterpret_cast<device_type<T>>(x); return reinterpret_cast<device_type<T>>(x);
} }
template<class T> template <class T>
device_type<T>* device_cast(T* x) device_type<T>* device_cast(T* x)
{ {
return reinterpret_cast<device_type<T>*>(x); return reinterpret_cast<device_type<T>*>(x);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment