Unverified commit 4a2a56c2, authored by Po Yen Chen, committed by GitHub

Rangify constructor of HostTensorDescriptor & Tensor<> (#445)

* Rangify STL algorithms

This commit adds rangified versions of std::copy(), std::fill() & std::transform()

* Rangify check_err()

By rangifying check_err(), we can compare not only std::vector<>s but
also any ranges that have the same value type.

* Allow constructing Tensor<> like a HostTensorDescriptor (see the sketch after this message)

* Simplify Tensor<> object construction logics

* Remove more unnecessary 'HostTensorDescriptor' objects

* Re-format example code

* Re-write more HostTensorDescriptor ctor calls
parent 37f2e918
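In call-site terms, the last two bullets replace the old descriptor plumbing with direct brace-initialized construction. A minimal before/after sketch (shapes invented for illustration; only constructors shown in this diff are used):

    // before: lengths packaged as a std::vector<std::size_t> inside a descriptor
    Tensor<float> t_old(HostTensorDescriptor(std::vector<std::size_t>({4, 8})));

    // after: Tensor<> forwards to the HostTensorDescriptor constructors directly
    Tensor<float> t_new({4, 8});             // strides computed from lengths
    Tensor<float> t_strided({4, 8}, {8, 1}); // explicit strides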
@@ -61,7 +61,7 @@ bool run_conv_bwd_data_bias_relu(const ExecutionConfig& config,
     std::array<ck::index_t, NDimSpatial> input_left_pads{};
     std::array<ck::index_t, NDimSpatial> input_right_pads{};
 
-    auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
+    auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
 
     copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths);
     copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides);
@@ -157,7 +157,7 @@ bool run_conv_bwd_data_bias_relu(const ExecutionConfig& config,
         in_device_buf.FromDevice(in_device.mData.data());
 
-        return ck::utils::check_err(in_device.mData, in_host.mData);
+        return ck::utils::check_err(in_device, in_host);
     }
 
     return true;
@@ -19,6 +19,7 @@
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
 #include "ck/utility/type.hpp"
 
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/fill.hpp"
@@ -247,19 +248,6 @@ inline auto to_array(Range& range) noexcept
     return detail::to_array_proxy<ck::remove_cvref_t<Range>>{range};
 }
 
-namespace ranges {
-
-template <typename InputRange, typename OutputIterator>
-inline auto copy(InputRange&& range, OutputIterator iter)
-    -> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
-                          std::end(std::forward<InputRange>(range)),
-                          iter))
-{
-    return std::copy(std::begin(std::forward<InputRange>(range)),
-                     std::end(std::forward<InputRange>(range)),
-                     iter);
-}
-
-} // namespace ranges
-
 template <typename Axes>
 inline auto is_valid_axes(const Axes& axes)
     -> std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
@@ -350,7 +338,7 @@ auto extend_shape(const Problem::Shape& shape, std::size_t new_dim)
     using std::begin, std::end;
 
-    std::copy(begin(shape), end(shape), begin(extended_shape));
+    ck::ranges::copy(shape, begin(extended_shape));
     extended_shape.back() = new_dim;
 
     return extended_shape;
@@ -362,7 +350,7 @@ auto extend_axes(const Problem::Axes& axes)
     using std::begin, std::end;
 
-    std::copy(begin(axes), end(axes), begin(extended_axes));
+    ck::ranges::copy(axes, begin(extended_axes));
     extended_axes.back() = detail::get_array_size_v<Problem::Axes>;
 
     return extended_axes;
@@ -57,7 +57,7 @@ bool run_permute_bundle(const Problem& problem)
     using std::begin;
 
     Tensor<DataType> input_tensor(input_shape);
-    ranges::copy(input_bundle_tensor.AsSpan<const DataType>(), begin(input_tensor));
+    ck::ranges::copy(input_bundle_tensor.AsSpan<const DataType>(), begin(input_tensor));
 
     Tensor<DataType> output_tensor(transpose(input_shape, input_axes));
     if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor))
@@ -11,6 +11,7 @@
 #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"

@@ -11,6 +11,7 @@
 #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"

@@ -11,6 +11,7 @@
 #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"

@@ -15,6 +15,7 @@
 #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"

@@ -11,6 +11,7 @@
 #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
@@ -97,7 +97,7 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
     std::array<ck::index_t, NDimSpatial> input1_left_pads{};
     std::array<ck::index_t, NDimSpatial> input1_right_pads{};
 
-    auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
+    auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
 
     copy(in0_g_n_c_wis_desc.GetLengths(), a0_g_n_c_wis_lengths);
     copy(in0_g_n_c_wis_desc.GetStrides(), a0_g_n_c_wis_strides);
@@ -261,7 +261,7 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
 #endif
 
         return ck::utils::check_err(
-            out1_device.mData, out1_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
+            out1_device, out1_host, "Error: incorrect results!", 1e-5f, 1e-4f);
     }
 
     return true;
@@ -167,7 +167,7 @@ int main(int argc, char* argv[])
         ref_invoker.Run(ref_argument);
 
         y_dev.FromDevice(y.mData.data());
-        pass &= ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
     }
 
     return (pass ? 0 : 1);
@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
         size_t M = acc.mDesc.GetLengths()[0];
         size_t N = acc.mDesc.GetLengths()[1];
 
-        Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
-        Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
+        Tensor<ComputeDataType> avg_acc_sq({M});
+        Tensor<ComputeDataType> avg_acc({M});
         Tensor<ComputeDataType> acc_layernorm(acc);
 
         // reduce N dim
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <iterator>
#include <type_traits>
#include <utility>

namespace ck {
namespace ranges {

template <typename InputRange, typename OutputIterator>
auto copy(InputRange&& range, OutputIterator iter)
    -> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
                          std::end(std::forward<InputRange>(range)),
                          iter))
{
    return std::copy(std::begin(std::forward<InputRange>(range)),
                     std::end(std::forward<InputRange>(range)),
                     iter);
}

template <typename T, typename OutputRange>
auto fill(OutputRange&& range, const T& init)
    -> std::void_t<decltype(std::fill(std::begin(std::forward<OutputRange>(range)),
                                      std::end(std::forward<OutputRange>(range)),
                                      init))>
{
    std::fill(std::begin(std::forward<OutputRange>(range)),
              std::end(std::forward<OutputRange>(range)),
              init);
}

template <typename InputRange, typename OutputIterator, typename UnaryOperation>
auto transform(InputRange&& range, OutputIterator iter, UnaryOperation unary_op)
    -> decltype(std::transform(std::begin(range), std::end(range), iter, unary_op))
{
    return std::transform(std::begin(range), std::end(range), iter, unary_op);
}

} // namespace ranges
} // namespace ck
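These helpers simply forward to the iterator-pair algorithms, so anything with std::begin()/std::end() works. A small usage sketch (the container choices here are arbitrary, not from the commit):

    #include <array>
    #include <vector>
    #include "ck/library/utility/algorithm.hpp"

    void sketch()
    {
        std::vector<int> src{1, 2, 3, 4};
        std::array<int, 4> dst{};

        ck::ranges::copy(src, dst.begin()); // dst = {1, 2, 3, 4}
        ck::ranges::fill(dst, 0);           // dst = {0, 0, 0, 0}

        // write 2*x for each element of src into dst
        ck::ranges::transform(src, dst.begin(), [](int x) { return 2 * x; });
    }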
@@ -15,18 +15,22 @@
 #include "ck/ck.hpp"
 #include "ck/utility/data_type.hpp"
+#include "ck/utility/span.hpp"
 #include "ck/utility/type.hpp"
 
 #include "ck/host_utility/io.hpp"
 
+#include "ck/library/utility/ranges.hpp"
 
 namespace ck {
 namespace utils {
 
-template <typename T>
-typename std::enable_if<std::is_floating_point<T>::value && !std::is_same<T, half_t>::value,
-                        bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_floating_point_v<ranges::range_value_t<Range>> &&
+        !std::is_same_v<ranges::range_value_t<Range>, half_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
          double rtol = 1e-5,
          double atol = 3e-6)
@@ -44,15 +48,17 @@ check_err(const std::vector<T>& out,
     double max_err = std::numeric_limits<double>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        err = std::abs(out[i] - ref[i]);
-        if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i]))
+        const double o = *std::next(std::begin(out), i);
+        const double r = *std::next(std::begin(ref), i);
+        err = std::abs(o - r);
+        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
             max_err = err > max_err ? err : max_err;
             err_count++;
             if(err_count < 5)
             {
                 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
-                          << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
+                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
             }
             res = false;
        }
@@ -64,10 +70,13 @@ check_err(const std::vector<T>& out,
     return res;
 }
 
-template <typename T>
-typename std::enable_if<std::is_same<T, bhalf_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
          double rtol = 1e-3,
          double atol = 1e-3)
@@ -86,9 +95,9 @@ check_err(const std::vector<T>& out,
     double max_err = std::numeric_limits<float>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        double o = type_convert<float>(out[i]);
-        double r = type_convert<float>(ref[i]);
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
         err = std::abs(o - r);
         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
         {
             max_err = err > max_err ? err : max_err;
@@ -108,10 +117,13 @@ check_err(const std::vector<T>& out,
     return res;
 }
 
-template <typename T>
-typename std::enable_if<std::is_same_v<T, half_t>, bool>::type
-check_err(span<const T> out,
-          span<const T> ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, half_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
          double rtol = 1e-3,
          double atol = 1e-3)
@@ -126,12 +138,12 @@ check_err(span<const T> out,
     bool res{true};
     int err_count = 0;
     double err = 0;
-    double max_err = std::numeric_limits<T>::min();
+    double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        double o = type_convert<float>(out[i]);
-        double r = type_convert<float>(ref[i]);
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
         err = std::abs(o - r);
         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
         {
             max_err = err > max_err ? err : max_err;
@@ -151,26 +163,17 @@ check_err(span<const T> out,
     return res;
 }
 
-template <typename T>
-typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
-          const std::string& msg = "Error: Incorrect results!",
-          double rtol = 1e-3,
-          double atol = 1e-3)
-{
-    return check_err(span<const T>{out}, span<const T>{ref}, msg, rtol, atol);
-}
-
-template <typename T>
-std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
+template <typename Range, typename RefRange>
+std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+                  std::is_integral_v<ranges::range_value_t<Range>> &&
+                  !std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-                     || std::is_same_v<T, int4_t>
+                     || std::is_same_v<ranges::range_value_t<Range>, int4_t>
 #endif
                      ,
                  bool>
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
          double = 0,
          double atol = 0)
@@ -188,9 +191,9 @@ check_err(const std::vector<T>& out,
     int64_t max_err = std::numeric_limits<int64_t>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        int64_t o = out[i];
-        int64_t r = ref[i];
+        const int64_t o = *std::next(std::begin(out), i);
+        const int64_t r = *std::next(std::begin(ref), i);
         err = std::abs(o - r);
         if(err > atol)
         {
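Since dispatch is now driven by ranges::range_value_t rather than by std::vector<T>, the two arguments only have to share a value type — a Tensor<> can be compared against a std::vector, as in the call sites above, or even mixed containers as in this sketch (data invented for illustration):

    #include <array>
    #include <vector>
    #include "ck/library/utility/check_err.hpp"

    bool sketch()
    {
        std::vector<float> out{1.0f, 2.0f, 3.0f};
        std::array<float, 3> ref{1.0f, 2.0f, 3.0f};

        // Both value types are float, so the floating-point overload is selected.
        return ck::utils::check_err(out, ref, "Error: Incorrect results!", 1e-5, 3e-6);
    }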
@@ -14,6 +14,9 @@
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/span.hpp"
 
+#include "ck/library/utility/algorithm.hpp"
+#include "ck/library/utility/ranges.hpp"
+
 template <typename Range>
 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
 {
@@ -84,10 +87,10 @@ struct HostTensorDescriptor
         this->CalculateStrides();
     }
 
-    template <typename Range,
+    template <typename Lengths,
               typename = std::enable_if_t<
-                  std::is_convertible_v<decltype(*std::begin(std::declval<Range>())), std::size_t>>>
-    HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
+                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t>>>
+    HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
     {
         this->CalculateStrides();
     }
@@ -102,13 +105,12 @@ struct HostTensorDescriptor
     {
     }
 
-    template <
-        typename Range1,
-        typename Range2,
-        typename = std::enable_if_t<
-            std::is_convertible_v<decltype(*std::begin(std::declval<Range1>())), std::size_t> &&
-            std::is_convertible_v<decltype(*std::begin(std::declval<Range2>())), std::size_t>>>
-    HostTensorDescriptor(const Range1& lens, const Range2& strides)
+    template <typename Lengths,
+              typename Strides,
+              typename = std::enable_if_t<
+                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
+                  std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>>>
+    HostTensorDescriptor(const Lengths& lens, const Strides& strides)
         : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
     {
     }
@@ -244,14 +246,20 @@ struct Tensor
     {
     }
 
-    template <typename X>
-    Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    template <typename X, typename Y>
+    Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
+        : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
     {
     }
 
-    template <typename X, typename Y>
-    Tensor(std::vector<X> lens, std::vector<Y> strides)
-        : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
+    template <typename Lengths>
+    Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    {
+    }
+
+    template <typename Lengths, typename Strides>
+    Tensor(const Lengths& lens, const Strides& strides)
+        : mDesc(lens, strides), mData(GetElementSpaceSize())
     {
     }
@@ -261,10 +269,10 @@ struct Tensor
     Tensor<OutT> CopyAsType() const
     {
         Tensor<OutT> ret(mDesc);
-        for(size_t i = 0; i < mData.size(); i++)
-        {
-            ret.mData[i] = ck::type_convert<OutT>(mData[i]);
-        }
+
+        ck::ranges::transform(
+            mData, ret.mData.begin(), [](auto value) { return ck::type_convert<OutT>(value); });
+
         return ret;
     }
@@ -294,13 +302,7 @@ struct Tensor
     std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
 
-    void SetZero()
-    {
-        for(auto& v : mData)
-        {
-            v = T{0};
-        }
-    }
+    void SetZero() { ck::ranges::fill<T>(mData, 0); }
 
     template <typename F>
     void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
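The reworked Tensor<> overload set can be exercised as below (shapes hypothetical; the strides follow the usual row-major packing, mirroring usage elsewhere in this diff):

    Tensor<float> a({2, 3, 4});             // lengths only; strides are computed
    Tensor<float> b({2, 3, 4}, {12, 4, 1}); // initializer lists for lengths and strides

    std::array<std::size_t, 3> lens{2, 3, 4};
    Tensor<float> c(lens);                  // any range of length values also works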
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iterator>
#include <utility>

#include "ck/utility/type.hpp"

namespace ck {

template <typename T>
using iter_value_t = typename std::iterator_traits<remove_cvref_t<T>>::value_type;

template <typename T>
using iter_reference_t = decltype(*std::declval<T&>());

template <typename T>
using iter_difference_t = typename std::iterator_traits<remove_cvref_t<T>>::difference_type;

} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iterator>
#include <type_traits>
#include <utility>

#include "ck/library/utility/iterator.hpp"

namespace ck {
namespace ranges {

template <typename R>
using iterator_t = decltype(std::begin(std::declval<R&>()));

template <typename R>
using sentinel_t = decltype(std::end(std::declval<R&>()));

template <typename R>
using range_size_t = decltype(std::size(std::declval<R&>()));

template <typename R>
using range_difference_t = ck::iter_difference_t<ranges::iterator_t<R>>;

template <typename R>
using range_value_t = iter_value_t<ranges::iterator_t<R>>;

template <typename R>
using range_reference_t = iter_reference_t<ranges::iterator_t<R>>;

template <typename T, typename = void>
struct is_range : std::false_type
{
};

template <typename T>
struct is_range<
    T,
    std::void_t<decltype(std::begin(std::declval<T&>())), decltype(std::end(std::declval<T&>()))>>
    : std::true_type
{
};

template <typename T>
inline constexpr bool is_range_v = is_range<T>::value;

template <typename T, typename = void>
struct is_sized_range : std::false_type
{
};

template <typename T>
struct is_sized_range<T, std::void_t<decltype(std::size(std::declval<T&>()))>>
    : std::bool_constant<is_range_v<T>>
{
};

} // namespace ranges
} // namespace ck
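These traits back the enable_if constraints used in check_err.hpp and host_tensor.hpp above. A few illustrative static_asserts (not part of the commit):

    #include <vector>
    #include "ck/library/utility/ranges.hpp"

    static_assert(ck::ranges::is_range_v<std::vector<int>>);
    static_assert(ck::ranges::is_sized_range<std::vector<int>>::value);
    static_assert(std::is_same_v<ck::ranges::range_value_t<std::vector<int>>, int>);
    static_assert(std::is_same_v<ck::iter_value_t<std::vector<int>::iterator>, int>);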
@@ -14,6 +14,7 @@
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 
 namespace ck {
@@ -111,15 +112,15 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
                                        std::size_t stride,
                                        std::size_t batch_stride,
                                        auto layout) {
+        using namespace ck::literals;
+
         if(std::is_same<decltype(layout), Row>::value)
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
         }
         else
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
         }
     };
@@ -330,8 +331,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
     {
         e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
 
-        pass = pass & ck::utils::check_err(e1_g_m_o_device_result.mData,
-                                           e1_g_m_o_host_result.mData);
+        pass = pass & ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);
 
         if(do_log)
         {
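The 1_uz literal comes from the newly included ck/library/utility/literals.hpp (its body is not shown in this diff; presumably a user-defined literal yielding std::size_t). It keeps every element of a braced stride list the same type, which matters because std::initializer_list deduction fails on mixed element types — a minimal sketch with invented sizes:

    using namespace ck::literals;

    std::size_t batch_count = 4, row = 8, col = 16;
    std::size_t batch_stride = 128, stride = 16;

    // {batch_stride, stride, 1} would mix std::size_t and int and fail to
    // deduce a single initializer_list element type; 1_uz keeps it homogeneous.
    HostTensorDescriptor desc({batch_count, row, col}, {batch_stride, stride, 1_uz});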
@@ -16,6 +16,7 @@
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 
 namespace ck {
@@ -105,15 +106,15 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
                                        std::size_t stride,
                                        std::size_t batch_stride,
                                        auto layout) {
+        using namespace ck::literals;
+
         if(std::is_same<decltype(layout), Row>::value)
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
         }
         else
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
        }
    };
@@ -283,8 +284,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
     {
         c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
 
-        pass = pass &
-               ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData);
+        pass = pass & ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result);
 
         if(do_log)
         {
@@ -16,6 +16,7 @@
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 
 namespace ck {
@@ -50,15 +51,15 @@ bool profile_batched_gemm_impl(int do_verification,
                                        std::size_t stride,
                                        std::size_t batch_stride,
                                        auto layout) {
+        using namespace ck::literals;
+
         if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
        {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
        }
        else
        {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
        }
    };
@@ -202,8 +203,7 @@ bool profile_batched_gemm_impl(int do_verification,
     {
         c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
 
-        pass = pass &
-               ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData);
+        pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
 
         if(do_log)
         {
@@ -14,6 +14,7 @@
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 
 namespace ck {
@@ -78,15 +79,15 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
                                        std::size_t col,
                                        std::size_t stride,
                                        auto layout) {
+        using namespace ck::literals;
+
         if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({row * stride, stride, 1}));
+            return HostTensorDescriptor({batch_count, row, col}, {row * stride, stride, 1_uz});
        }
        else
        {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({col * stride, 1, stride}));
+            return HostTensorDescriptor({batch_count, row, col}, {col * stride, 1_uz, stride});
        }
    };
@@ -95,17 +96,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
     Tensor<CDataType> c_g_m_n_host_result(
         f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
-    Tensor<ReduceDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
-        {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
-    Tensor<ReduceDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
-        {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
+    Tensor<ReduceDataType> d0_g_m_host_result({BatchCount, M});
+    Tensor<ReduceDataType> d1_g_m_host_result({BatchCount, M});
 
     Tensor<CDataType> c_g_m_n_device_result(
         f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
-    Tensor<ReduceDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
-        {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
-    Tensor<ReduceDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
-        {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
+    Tensor<ReduceDataType> d0_g_m_device_result({BatchCount, M});
+    Tensor<ReduceDataType> d1_g_m_device_result({BatchCount, M});
 
     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
     std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
@@ -319,12 +316,9 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
         reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
         reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
 
-        bool c_error =
-            ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData);
-        bool d0_error =
-            ck::utils::check_err(d0_g_m_device_result.mData, d0_g_m_host_result.mData);
-        bool d1_error =
-            ck::utils::check_err(d1_g_m_device_result.mData, d1_g_m_host_result.mData);
+        bool c_error  = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
+        bool d0_error = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result);
+        bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result);
 
         pass = pass && (c_error == true);
         pass = pass && (d0_error == true);