Commit 96e74d6e authored by Paul
Browse files

Formatting

parent a0c4afbf
......@@ -38,8 +38,8 @@ auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments..
const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = pack(
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, device_cast(inputs.data()))...);
auto data = pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
device_cast(inputs.data()))...);
hip_tensor_descriptor<ndim> out_desc(output_shape);
auto* outp = device_cast(output.data());
gs_launch(stream, output_shape.elements())([=](auto i) {
......@@ -266,7 +266,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
// assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
const std::size_t vec_size = 4;
auto data = pack_vec4(device_cast(inputs.data())...);
auto* outp = as_vec4(device_cast(output.data()));
......
......@@ -17,26 +17,25 @@ namespace device {
using gpu_half = __fp16;
namespace detail {
/// Primary template of the device-type trait: by default a type maps to
/// itself. Specializations override `type` for types that need a different
/// representation.
template <class T>
struct device_type
{
    using type = T;
};
template<>
template <>
struct device_type<half>
{
using type = gpu_half;
};
/// Primary template of the host-type trait: by default a type maps to
/// itself. Specializations override `type` for types that need a different
/// representation.
template <class T>
struct host_type
{
    using type = T;
};
template<>
template <>
struct device_type<gpu_half>
{
using type = half;
......@@ -44,31 +43,31 @@ struct device_type<gpu_half>
} // namespace detail
/// Shorthand for `detail::host_type<T>::type`.
template <class T>
using host_type = typename detail::host_type<T>::type;
/// Shorthand for `detail::device_type<T>::type`.
template <class T>
using device_type = typename detail::device_type<T>::type;
/// Returns `x` converted to `host_type<T>` by reinterpreting its object
/// representation.
///
/// The cast goes through a const reference because
/// `reinterpret_cast<U>(value)` is ill-formed for floating-point and class
/// types, even when `U` is the same type as the operand.
///
/// @param x value to convert (taken by copy)
/// @return the same bits viewed as `host_type<T>`
template <class T>
host_type<T> host_cast(T x)
{
    static_assert(sizeof(host_type<T>) == sizeof(T),
                  "host_cast requires types of identical size");
    return reinterpret_cast<const host_type<T>&>(x);
}
/// Pointer overload: reinterprets a `T*` as a `host_type<T>*`.
///
/// No data is copied or converted; only the pointee type changes.
///
/// @param x pointer to reinterpret
/// @return the same address typed as `host_type<T>*`
template <class T>
host_type<T>* host_cast(T* x)
{
    return reinterpret_cast<host_type<T>*>(x);
}
/// Returns `x` converted to `device_type<T>` by reinterpreting its object
/// representation.
///
/// The cast goes through a const reference because
/// `reinterpret_cast<U>(value)` is ill-formed for floating-point and class
/// types, even when `U` is the same type as the operand.
///
/// @param x value to convert (taken by copy)
/// @return the same bits viewed as `device_type<T>`
template <class T>
device_type<T> device_cast(T x)
{
    static_assert(sizeof(device_type<T>) == sizeof(T),
                  "device_cast requires types of identical size");
    return reinterpret_cast<const device_type<T>&>(x);
}
/// Pointer overload: reinterprets a `T*` as a `device_type<T>*`.
///
/// No data is copied or converted; only the pointee type changes.
///
/// @param x pointer to reinterpret
/// @return the same address typed as `device_type<T>*`
template <class T>
device_type<T>* device_cast(T* x)
{
    return reinterpret_cast<device_type<T>*>(x);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment