Commit e4e99a49 authored by Po-Yen, Chen

Use new utilities to shorten code

parent 7acbf104
@@ -11,17 +11,19 @@
 #include <type_traits>

 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

+#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/utility/convolution_parameter.hpp"
-#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+#include "ck/library/utility/numeric.hpp"

 using In0DataType  = ck::int4_t;
 using Wei0DataType = ck::int4_t;
......
@@ -7,17 +7,19 @@
 #include <type_traits>

 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

+#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/utility/convolution_parameter.hpp"
-#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+#include "ck/library/utility/numeric.hpp"

 using In0DataType  = int8_t;
 using Wei0DataType = int8_t;
......
@@ -37,10 +37,10 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
     Tensor<Out1DataType> out1_host(out1_g_n_k_wos_desc);
     Tensor<Out1DataType> out1_device(out1_g_n_k_wos_desc);

-    std::cout << "in0: " << in0.mDesc << std::endl;
-    std::cout << "wei0: " << wei0.mDesc << std::endl;
-    std::cout << "wei1: " << wei1.mDesc << std::endl;
-    std::cout << "out1: " << out1_host.mDesc << std::endl;
+    std::cout << "in0: " << in0.GetDesc() << std::endl;
+    std::cout << "wei0: " << wei0.GetDesc() << std::endl;
+    std::cout << "wei1: " << wei1.GetDesc() << std::endl;
+    std::cout << "out1: " << out1_host.GetDesc() << std::endl;

     switch(init_method)
     {
@@ -57,27 +57,27 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
     }

 #ifdef BUILD_INT4_EXAMPLE
-    DeviceMem in0_device_buf(sizeof(KernelIn0DataType) * in0.mDesc.GetElementSpaceSize());
-    DeviceMem wei0_device_buf(sizeof(KernelWei0DataType) * wei0.mDesc.GetElementSpaceSize());
-    DeviceMem wei1_device_buf(sizeof(KernelWei1DataType) * wei1.mDesc.GetElementSpaceSize());
-    DeviceMem out1_device_buf(sizeof(KernelOut1DataType) * out1_device.mDesc.GetElementSpaceSize());
+    DeviceMem in0_device_buf(in0.GetMemorySize());
+    DeviceMem wei0_device_buf(wei0.GetMemorySize());
+    DeviceMem wei1_device_buf(wei1.GetMemorySize());
+    DeviceMem out1_device_buf(out1_device.GetMemorySize());

     const Tensor<KernelIn0DataType> in0_converted(in0);
     const Tensor<KernelWei0DataType> wei0_converted(wei0);
     const Tensor<KernelWei1DataType> wei1_converted(wei1);

-    in0_device_buf.ToDevice(in0_converted.mData.data());
-    wei0_device_buf.ToDevice(wei0_converted.mData.data());
-    wei1_device_buf.ToDevice(wei1_converted.mData.data());
+    in0_device_buf.ToDevice(in0_converted.data());
+    wei0_device_buf.ToDevice(wei0_converted.data());
+    wei1_device_buf.ToDevice(wei1_converted.data());
 #else
-    DeviceMem in0_device_buf(sizeof(In0DataType) * in0.mDesc.GetElementSpaceSize());
-    DeviceMem wei0_device_buf(sizeof(Wei0DataType) * wei0.mDesc.GetElementSpaceSize());
-    DeviceMem wei1_device_buf(sizeof(Wei1DataType) * wei1.mDesc.GetElementSpaceSize());
-    DeviceMem out1_device_buf(sizeof(Out1DataType) * out1_device.mDesc.GetElementSpaceSize());
+    DeviceMem in0_device_buf(in0.GetMemorySize());
+    DeviceMem wei0_device_buf(wei0.GetMemorySize());
+    DeviceMem wei1_device_buf(wei1.GetMemorySize());
+    DeviceMem out1_device_buf(out1_device.GetMemorySize());

-    in0_device_buf.ToDevice(in0.mData.data());
-    wei0_device_buf.ToDevice(wei0.mData.data());
-    wei1_device_buf.ToDevice(wei1.mData.data());
+    in0_device_buf.ToDevice(in0.data());
+    wei0_device_buf.ToDevice(wei0.data());
+    wei1_device_buf.ToDevice(wei1.data());
 #endif

     std::array<ck::index_t, NDimSpatial + 3> a0_g_n_c_wis_lengths{};
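Note: `GetMemorySize()` folds the repeated `sizeof(DataType) * mDesc.GetElementSpaceSize()` expression into `Tensor` itself. A minimal standalone sketch of what such a member plausibly computes; the descriptor stand-in below is illustrative, not CK's real `HostTensorDescriptor` (whose element space size also accounts for strides/padding):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative stand-in for CK's host tensor descriptor.
struct DescSketch
{
    std::vector<std::size_t> lengths;

    std::size_t GetElementSpaceSize() const
    {
        return std::accumulate(
            lengths.begin(), lengths.end(), std::size_t{1}, std::multiplies<std::size_t>{});
    }
};

// Sketch of the convenience member the diff switches to: the byte count that
// the DeviceMem constructor needs, derived from element type and descriptor.
template <typename T>
struct TensorSketch
{
    DescSketch mDesc;

    std::size_t GetMemorySize() const { return sizeof(T) * mDesc.GetElementSpaceSize(); }
};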
@@ -97,7 +97,7 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
     std::array<ck::index_t, NDimSpatial> input1_left_pads{};
     std::array<ck::index_t, NDimSpatial> input1_right_pads{};

-    auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
+    auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };

     copy(in0_g_n_c_wis_desc.GetLengths(), a0_g_n_c_wis_lengths);
     copy(in0_g_n_c_wis_desc.GetStrides(), a0_g_n_c_wis_strides);
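The rewritten lambda delegates to `ck::ranges::copy`, defined later in this commit, which forwards a whole range to `std::copy`. A standalone usage sketch mirroring the `GetLengths()`-to-`std::array` copies above; `range_copy` is a local stand-in for the CK helper:

#include <algorithm>
#include <array>
#include <iostream>
#include <iterator>
#include <vector>

// Same shape as ck::ranges::copy: forward a whole range to std::copy
// instead of spelling out begin()/end() at every call site.
template <typename InputRange, typename OutputIterator>
OutputIterator range_copy(InputRange&& range, OutputIterator iter)
{
    return std::copy(std::begin(range), std::end(range), iter);
}

int main()
{
    std::vector<int> lengths{1, 2, 3, 4, 5}; // e.g. G/N/C plus spatial lengths
    std::array<int, 5> dst{};

    range_copy(lengths, dst.begin());

    for(int v : dst)
        std::cout << v << ' '; // 1 2 3 4 5
}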
@@ -120,18 +120,17 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
     const ck::index_t gemm_batch = a0_g_n_c_wis_lengths[0];

     const ck::index_t gemm0_m_length =
-        e1_g_n_k_wos_lengths[1] * std::accumulate(e1_g_n_k_wos_lengths.begin() + 3,
-                                                  e1_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                  ck::index_t{1},
-                                                  std::multiplies<ck::index_t>{});
+        e1_g_n_k_wos_lengths[1] * ck::accumulate_n(e1_g_n_k_wos_lengths.begin() + 3,
+                                                   NDimSpatial,
+                                                   ck::index_t{1},
+                                                   std::multiplies<ck::index_t>{});

     const ck::index_t gemm0_n_length = b0_g_k_c_xs_lengths[1];

-    const ck::index_t gemm0_k_length =
-        std::accumulate(b0_g_k_c_xs_lengths.begin() + 2,
-                        b0_g_k_c_xs_lengths.begin() + 2 + NDimSpatial + 1,
-                        ck::index_t{1},
-                        std::multiplies<ck::index_t>{});
+    const ck::index_t gemm0_k_length = ck::accumulate_n(b0_g_k_c_xs_lengths.begin() + 2,
+                                                        NDimSpatial + 1,
+                                                        ck::index_t{1},
+                                                        std::multiplies<ck::index_t>{});

     const ck::index_t gemm1_n_length = b1_g_k_c_xs_lengths[1];
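`ck::accumulate_n`, presumably supplied by the newly included `numeric.hpp`, takes an element count instead of the `first + n` end iterator spelled out at the call site. A plausible standalone sketch, assuming it mirrors `std::accumulate`:

#include <functional>
#include <iterator>
#include <numeric>
#include <vector>

// Plausible shape of ck::accumulate_n: accumulate exactly n elements starting
// at first, rather than writing the end iterator as first + n twice.
template <typename InputIterator, typename Size, typename T, typename BinaryOperation>
T accumulate_n(InputIterator first, Size n, T init, BinaryOperation op)
{
    return std::accumulate(first, std::next(first, n), init, op);
}

int main()
{
    // product of the NDimSpatial output lengths, as in gemm0_m_length above
    std::vector<int> e1_lengths{1, 64, 32, 28, 28}; // G, N, K, Ho, Wo
    const int spatial = accumulate_n(e1_lengths.begin() + 3, 2, 1, std::multiplies<int>{});

    return spatial == 28 * 28 ? 0 : 1;
}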
@@ -149,36 +148,28 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
     auto device_op = DeviceOpInstance{};
     auto invoker   = device_op.MakeInvoker();

-    auto argument = device_op.MakeArgument(
-#ifdef BUILD_INT4_EXAMPLE
-        static_cast<KernelIn0DataType*>(in0_device_buf.GetDeviceBuffer()),
-        static_cast<KernelWei0DataType*>(wei0_device_buf.GetDeviceBuffer()),
-        static_cast<KernelWei1DataType*>(wei1_device_buf.GetDeviceBuffer()),
-        static_cast<KernelOut1DataType*>(out1_device_buf.GetDeviceBuffer()),
-#else
-        static_cast<In0DataType*>(in0_device_buf.GetDeviceBuffer()),
-        static_cast<Wei0DataType*>(wei0_device_buf.GetDeviceBuffer()),
-        static_cast<Wei1DataType*>(wei1_device_buf.GetDeviceBuffer()),
-        static_cast<Out1DataType*>(out1_device_buf.GetDeviceBuffer()),
-#endif
-        gemm0_m_length,
-        gemm0_n_length,
-        gemm0_k_length,
-        gemm1_n_length,
-        gemm_batch,
-        a0_stride,
-        b0_stride,
-        b1_stride,
-        e1_stride,
-        a0_batch_stride,
-        b0_batch_stride,
-        b1_batch_stride,
-        e1_batch_stride,
-        in0_element_op,
-        wei0_element_op,
-        out0_element_op,
-        wei1_element_op,
-        out1_element_op);
+    auto argument = device_op.MakeArgument(in0_device_buf.GetDeviceBuffer(),
+                                           wei0_device_buf.GetDeviceBuffer(),
+                                           wei1_device_buf.GetDeviceBuffer(),
+                                           out1_device_buf.GetDeviceBuffer(),
+                                           gemm0_m_length,
+                                           gemm0_n_length,
+                                           gemm0_k_length,
+                                           gemm1_n_length,
+                                           gemm_batch,
+                                           a0_stride,
+                                           b0_stride,
+                                           b1_stride,
+                                           e1_stride,
+                                           a0_batch_stride,
+                                           b0_batch_stride,
+                                           b1_batch_stride,
+                                           e1_batch_stride,
+                                           in0_element_op,
+                                           wei0_element_op,
+                                           out0_element_op,
+                                           wei1_element_op,
+                                           out1_element_op);

     if(!device_op.IsSupportedArgument(argument))
     {
@@ -251,17 +242,17 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
         ref_conv1_invoker.Run(ref_conv1_argument);

 #ifdef BUILD_INT4_EXAMPLE
-        Tensor<KernelOut1DataType> out1_device_converted(out1_host.mDesc);
+        Tensor<KernelOut1DataType> out1_device_converted(out1_host.GetDesc());

-        out1_device_buf.FromDevice(out1_device_converted.mData.data());
+        out1_device_buf.FromDevice(out1_device_converted.data());

         out1_device = out1_device_converted.CopyAsType<Out1DataType>();
 #else
-        out1_device_buf.FromDevice(out1_device.mData.data());
+        out1_device_buf.FromDevice(out1_device.data());
 #endif

         return ck::utils::check_err(
-            out1_device.mData, out1_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
+            out1_device, out1_host, "Error: incorrect results!", 1e-5f, 1e-4f);
     }

     return true;
......
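`check_err` is now passed the `Tensor` objects directly rather than their `mData` vectors, which suggests it gained a range-based overload (or that `Tensor` is itself iterable). A hedged sketch of such an overload; the real CK signature and tolerance semantics may differ:

#include <cmath>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical range-based check_err: element-wise comparison within
// relative/absolute tolerances.
template <typename Range>
bool check_err_sketch(
    const Range& result, const Range& reference, const std::string& msg, double rtol, double atol)
{
    auto r = std::begin(result);
    auto e = std::begin(reference);

    for(; r != std::end(result) && e != std::end(reference); ++r, ++e)
    {
        const double err = std::abs(static_cast<double>(*r) - static_cast<double>(*e));
        if(err > atol + rtol * std::abs(static_cast<double>(*e)))
        {
            std::cerr << msg << std::endl;
            return false;
        }
    }
    return true;
}

int main()
{
    std::vector<float> out{1.f, 2.f}, ref{1.f, 2.f};
    return check_err_sketch(out, ref, "Error: incorrect results!", 1e-5, 1e-4) ? 0 : 1;
}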
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

+#include <cstdlib>
+#include <initializer_list>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>

 #include <getopt.h>

 #include "ck/ck.hpp"
@@ -12,13 +13,14 @@
 #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
 #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"

-#include "ck/library/utility/fill.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/fill.hpp"
 #include "ck/library/utility/host_common_util.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
+#include "ck/library/utility/ranges.hpp"

 constexpr int Rank = 5;
 constexpr int NumReduceDim = 3;
@@ -100,35 +102,37 @@ int main(int argc, char* argv[])
     Tensor<GammaDataType> gamma({G, C});
     Tensor<BetaDataType> beta({G, C});

-    ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x.begin(), x.end());
-    ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma.begin(), gamma.end());
-    ck::utils::FillUniformDistribution<BetaDataType>{0.f, 1.f}(beta.begin(), beta.end());
+    ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x);
+    ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma);
+    ck::utils::FillUniformDistribution<BetaDataType>{0.f, 1.f}(beta);

-    DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
-    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
-    DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
-    DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
+    DeviceMem x_dev(x.GetMemorySize());
+    DeviceMem gamma_dev(gamma.GetMemorySize());
+    DeviceMem beta_dev(beta.GetMemorySize());
+    DeviceMem y_dev(y.GetMemorySize());

-    x_dev.ToDevice(x.mData.data());
-    gamma_dev.ToDevice(gamma.mData.data());
-    beta_dev.ToDevice(beta.mData.data());
+    x_dev.ToDevice(x.data());
+    gamma_dev.ToDevice(gamma.data());
+    beta_dev.ToDevice(beta.data());

     const auto y_element_op = YElementOp{};

+    using Indices = std::vector<ck::index_t>;
+
     auto device_instance = DeviceInstance{};
-    auto argument_ptr = device_instance.MakeArgumentPointer(
-        {N, H, W, G, C},
-        std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
-        {0, 0, 0, C, 1},
-        {0, 0, 0, C, 1},
-        std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
-        {1, 2, 4}, // reduction dimension: [H, W, C]
-        1e-6,
-        x_dev.GetDeviceBuffer(),
-        gamma_dev.GetDeviceBuffer(),
-        beta_dev.GetDeviceBuffer(),
-        y_dev.GetDeviceBuffer(),
-        y_element_op);
+    auto argument_ptr =
+        device_instance.MakeArgumentPointer({N, H, W, G, C},
+                                            ck::ranges::to<Indices>(x.GetStrides()),
+                                            {0, 0, 0, C, 1},
+                                            {0, 0, 0, C, 1},
+                                            ck::ranges::to<Indices>(y.GetStrides()),
+                                            {1, 2, 4}, // reduction dimension: [H, W, C]
+                                            1e-6,
+                                            x_dev.GetDeviceBuffer(),
+                                            gamma_dev.GetDeviceBuffer(),
+                                            beta_dev.GetDeviceBuffer(),
+                                            y_dev.GetDeviceBuffer(),
+                                            y_element_op);

     if(!device_instance.IsSupportedArgument(argument_ptr.get()))
     {
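`ck::ranges::to<Indices>(...)` replaces the iterator-pair `std::vector` construction. The `to` helper itself is not shown in this commit; a minimal sketch under the assumption that it behaves like the simple container case of C++23's `std::ranges::to`:

#include <cstddef>
#include <iterator>
#include <vector>

// Assumed shape of ck::ranges::to (not shown in this commit): build a
// container from any range, when the container takes an iterator pair.
template <typename Container, typename InputRange>
Container to(InputRange&& range)
{
    return Container(std::begin(range), std::end(range));
}

int main()
{
    std::size_t strides[] = {100, 10, 1};
    auto indices = to<std::vector<long>>(strides); // cf. Indices above

    return indices.size() == 3 ? 0 : 1;
}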
@@ -164,8 +168,8 @@ int main(int argc, char* argv[])
         auto ref_invoker = ref.MakeInvoker();
         ref_invoker.Run(ref_argument);

-        y_dev.FromDevice(y.mData.data());
-        pass &= ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
+        y_dev.FromDevice(y.data());
+        pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
     }

     return (pass ? 0 : 1);
......
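`FillUniformDistribution` is likewise now invoked with the tensor itself instead of a `begin()`/`end()` pair, implying a range-taking `operator()`. A standalone sketch, assuming the functor keeps an iterator-pair overload underneath; the members and RNG choice below are illustrative:

#include <algorithm>
#include <iterator>
#include <random>
#include <vector>

// Sketch of a fill functor growing a range overload on top of the existing
// iterator-pair one (the CK functor's internals may differ).
template <typename T>
struct FillUniformDistributionSketch
{
    T lo_, hi_;

    template <typename ForwardIterator>
    void operator()(ForwardIterator first, ForwardIterator last) const
    {
        std::mt19937 gen(11939);
        std::uniform_real_distribution<float> dis(lo_, hi_);
        std::generate(first, last, [&] { return static_cast<T>(dis(gen)); });
    }

    // New convenience: accept the whole range (e.g. a Tensor) directly.
    template <typename ForwardRange>
    void operator()(ForwardRange&& range) const
    {
        (*this)(std::begin(range), std::end(range));
    }
};

int main()
{
    std::vector<float> x(16);
    FillUniformDistributionSketch<float>{0.f, 1.f}(x);
}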
@@ -5,6 +5,7 @@
 #include <cstddef>
 #include <array>
+#include <iterator>
 #include <type_traits>

 namespace ck {
@@ -13,16 +14,18 @@ template <typename T>
 class span
 {
     public:
     using element_type           = T;
     using value_type             = std::remove_cv_t<element_type>;
     using size_type              = std::size_t;
     using difference_type        = std::ptrdiff_t;
     using pointer                = element_type*;
     using const_pointer          = const element_type*;
     using reference              = element_type&;
     using const_reference        = const element_type&;
     using iterator               = pointer;
     using const_iterator         = pointer;
+    using reverse_iterator       = std::reverse_iterator<iterator>;
+    using const_reverse_iterator = std::reverse_iterator<const_iterator>;

     constexpr span() : span(nullptr, size_type{0}) {}
@@ -51,6 +54,12 @@ class span
     constexpr iterator end() const noexcept { return begin() + size(); }
     constexpr const_iterator cend() const noexcept { return end(); }

+    constexpr reverse_iterator rbegin() const noexcept { return std::make_reverse_iterator(end()); }
+    constexpr const_reverse_iterator crbegin() const noexcept { return rbegin(); }
+
+    constexpr reverse_iterator rend() const noexcept { return std::make_reverse_iterator(begin()); }
+    constexpr const_reverse_iterator crend() const noexcept { return rend(); }
+
     constexpr reference front() const { return *begin(); }
     constexpr reference back() const { return *(--end()); }
@@ -59,6 +68,29 @@ class span
     constexpr size_type size() const noexcept { return size_; }
+    constexpr bool empty() const noexcept { return size() == 0; }
+
+    friend constexpr iterator begin(const span& s) noexcept { return s.begin(); }
+    friend constexpr const_iterator cbegin(const span& s) noexcept { return s.begin(); }
+
+    friend constexpr iterator end(const span& s) noexcept { return s.end(); }
+    friend constexpr const_iterator cend(const span& s) noexcept { return s.end(); }
+
+    friend constexpr reverse_iterator rbegin(const span& s) noexcept { return s.rbegin(); }
+    friend constexpr const_reverse_iterator crbegin(const span& s) noexcept { return s.crbegin(); }
+
+    friend constexpr reverse_iterator rend(const span& s) noexcept { return s.rend(); }
+    friend constexpr const_reverse_iterator crend(const span& s) noexcept { return s.crend(); }
+
+    friend constexpr reference front(const span& s) { return s.front(); }
+    friend constexpr reference back(const span& s) { return s.back(); }
+
+    friend constexpr pointer data(const span& s) noexcept { return s.data(); }
+    friend constexpr size_type size(const span& s) noexcept { return s.size(); }
+    friend constexpr bool empty(const span& s) noexcept { return s.empty(); }

     private:
     pointer ptr_;
     size_type size_;
......
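With the reverse iterators, `empty()`, and the ADL-visible friend functions added above, `ck::span` now supports the usual non-member access patterns. A short usage sketch; the `(pointer, size)` constructor is implied by the delegating default constructor shown above, and the include path is assumed:

#include <iostream>

#include "ck/utility/span.hpp" // assumed path of the header above

void demo()
{
    int values[] = {1, 2, 3, 4};
    ck::span<const int> s{values, 4};

    // rbegin()/rend() traverse back to front: prints 4 3 2 1
    for(auto it = s.rbegin(); it != s.rend(); ++it)
        std::cout << *it << ' ';

    // the friend functions are found by ADL, so unqualified calls work
    if(!empty(s))
        std::cout << "front=" << front(s) << " back=" << back(s) << '\n';
}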
@@ -57,7 +57,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
     float Run(const Argument& arg)
     {
         auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
-            const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
+            const int K = arg.a_g_m_k_.GetLengths()[2];

             AccDataType v_acc = 0;
@@ -81,9 +81,9 @@ struct ReferenceBatchedGemm : public device::BaseOperator
         };

         make_ParallelTensorFunctor(f_gmk_gkn_gmn,
-                                   arg.c_g_m_n_.mDesc.GetLengths()[0],
-                                   arg.c_g_m_n_.mDesc.GetLengths()[1],
-                                   arg.c_g_m_n_.mDesc.GetLengths()[2])(
+                                   arg.c_g_m_n_.GetLengths()[0],
+                                   arg.c_g_m_n_.GetLengths()[1],
+                                   arg.c_g_m_n_.GetLengths()[2])(
             std::thread::hardware_concurrency());

         return 0;
     }
......
@@ -3,13 +3,15 @@
 #pragma once

-#include <iostream>
-#include <vector>
-#include <array>
 #include <algorithm>
+#include <array>
+#include <iostream>
 #include <thread>
+#include <vector>

+#include "ck/library/utility/host_tensor.hpp"
 #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
+#include "ck/utility/data_type.hpp"

 namespace ck {
 namespace tensor_operation {
......
@@ -3,12 +3,14 @@
 #pragma once

+#include <algorithm>
+#include <array>
 #include <iostream>
 #include <vector>
-#include <array>
-#include <algorithm>

+#include "ck/library/utility/host_tensor.hpp"
 #include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp"
+#include "ck/utility/data_type.hpp"

 namespace ck {
 namespace tensor_operation {
......
@@ -72,9 +72,9 @@ struct ReferenceCGemm : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1];
+        const std::size_t K = arg.a_m_k_real_.GetLengths()[1];

-        if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1])
+        if(K != arg.a_m_k_imag_.GetLengths()[1])
         {
             throw std::runtime_error("wrong! Incompatible real and imag sizes in CGEMM");
         }
@@ -111,13 +111,11 @@ struct ReferenceCGemm : public device::BaseOperator
             arg.c_m_n_imag_(m, n) = ck::type_convert<CDataType>(v_c_imag);
         };

-        make_ParallelTensorFunctor(f_mk_kn_mn_real,
-                                   arg.c_m_n_real_.mDesc.GetLengths()[0],
-                                   arg.c_m_n_real_.mDesc.GetLengths()[1])(
+        make_ParallelTensorFunctor(
+            f_mk_kn_mn_real, arg.c_m_n_real_.GetLengths()[0], arg.c_m_n_real_.GetLengths()[1])(
             std::thread::hardware_concurrency());

-        make_ParallelTensorFunctor(f_mk_kn_mn_imag,
-                                   arg.c_m_n_imag_.mDesc.GetLengths()[0],
-                                   arg.c_m_n_imag_.mDesc.GetLengths()[1])(
+        make_ParallelTensorFunctor(
+            f_mk_kn_mn_imag, arg.c_m_n_imag_.GetLengths()[0], arg.c_m_n_imag_.GetLengths()[1])(
             std::thread::hardware_concurrency());

         return 0;
......
@@ -76,14 +76,14 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
         auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
             float v_acc = 0;

-            for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
+            for(std::size_t c = 0; c < arg.wei_k_c_y_x_.GetLengths()[1]; ++c)
             {
-                for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
+                for(std::size_t y = 0; y < arg.wei_k_c_y_x_.GetLengths()[2]; ++y)
                 {
                     auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
                               ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
                               ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
-                    for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
+                    for(std::size_t x = 0; x < arg.wei_k_c_y_x_.GetLengths()[3]; ++x)
                     {
                         auto wi =
                             ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
@@ -91,10 +91,10 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
                             ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
                         if(hi >= 0 &&
                            ck::type_convert<std::size_t>(hi) <
-                               arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
+                               arg.in_n_c_hi_wi_.GetLengths()[2] &&
                            wi >= 0 &&
                            ck::type_convert<std::size_t>(wi) <
-                               arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
+                               arg.in_n_c_hi_wi_.GetLengths()[3])
                         {
                             float v_in;
                             float v_wei;
@@ -119,10 +119,10 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
         };

         make_ParallelTensorFunctor(f_nchw,
-                                   arg.out_n_k_ho_wo_.mDesc.GetLengths()[0],
-                                   arg.out_n_k_ho_wo_.mDesc.GetLengths()[1],
-                                   arg.out_n_k_ho_wo_.mDesc.GetLengths()[2],
-                                   arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])(
+                                   arg.out_n_k_ho_wo_.GetLengths()[0],
+                                   arg.out_n_k_ho_wo_.GetLengths()[1],
+                                   arg.out_n_k_ho_wo_.GetLengths()[2],
+                                   arg.out_n_k_ho_wo_.GetLengths()[3])(
             std::thread::hardware_concurrency());

         return 0;
     }
......
@@ -79,14 +79,14 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
         auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
             float v_acc = 0;

-            for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
+            for(std::size_t c = 0; c < arg.wei_k_c_y_x_.GetLengths()[1]; ++c)
             {
-                for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
+                for(std::size_t y = 0; y < arg.wei_k_c_y_x_.GetLengths()[2]; ++y)
                 {
                     auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
                               ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
                               ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
-                    for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
+                    for(std::size_t x = 0; x < arg.wei_k_c_y_x_.GetLengths()[3]; ++x)
                     {
                         auto wi =
                             ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
@@ -94,10 +94,10 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
                             ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
                         if(hi >= 0 &&
                            ck::type_convert<std::size_t>(hi) <
-                               arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
+                               arg.in_n_c_hi_wi_.GetLengths()[2] &&
                            wi >= 0 &&
                            ck::type_convert<std::size_t>(wi) <
-                               arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
+                               arg.in_n_c_hi_wi_.GetLengths()[3])
                         {
                             float v_in;
                             float v_wei;
@@ -125,10 +125,10 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
         };

         make_ParallelTensorFunctor(f_nchw,
-                                   arg.out_n_k_ho_wo_.mDesc.GetLengths()[0],
-                                   arg.out_n_k_ho_wo_.mDesc.GetLengths()[1],
-                                   arg.out_n_k_ho_wo_.mDesc.GetLengths()[2],
-                                   arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])(
+                                   arg.out_n_k_ho_wo_.GetLengths()[0],
+                                   arg.out_n_k_ho_wo_.GetLengths()[1],
+                                   arg.out_n_k_ho_wo_.GetLengths()[2],
+                                   arg.out_n_k_ho_wo_.GetLengths()[3])(
             std::thread::hardware_concurrency());

         return 0;
     }
......
@@ -57,7 +57,7 @@ struct ReferenceGemm : public device::BaseOperator
     float Run(const Argument& arg)
     {
         auto f_mk_kn_mn = [&](auto m, auto n) {
-            const int K = arg.a_m_k_.mDesc.GetLengths()[1];
+            const int K = arg.a_m_k_.GetLengths()[1];

             AccDataType v_acc = 0;
@@ -81,7 +81,7 @@ struct ReferenceGemm : public device::BaseOperator
         };

         make_ParallelTensorFunctor(
-            f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
+            f_mk_kn_mn, arg.c_m_n_.GetLengths()[0], arg.c_m_n_.GetLengths()[1])(
             std::thread::hardware_concurrency());

         return 0;
......
@@ -61,7 +61,7 @@ struct ReferenceGemmBias2D : public device::BaseOperator
     float Run(const Argument& arg)
     {
         auto f_mk_kn_mn = [&](auto m, auto n) {
-            const int K = arg.a_m_k_.mDesc.GetLengths()[1];
+            const int K = arg.a_m_k_.GetLengths()[1];

             AccDataType a = 0;
             AccDataType b = 0;
@@ -79,7 +79,7 @@ struct ReferenceGemmBias2D : public device::BaseOperator
         };

         make_ParallelTensorFunctor(
-            f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
+            f_mk_kn_mn, arg.c_m_n_.GetLengths()[0], arg.c_m_n_.GetLengths()[1])(
             std::thread::hardware_concurrency());

         return 0;
......
@@ -60,7 +60,7 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator
     float Run(const Argument& arg)
     {
         auto f_mk_kn_mn = [&](auto m, auto n) {
-            const int K = arg.a_m_k_.mDesc.GetLengths()[1];
+            const int K = arg.a_m_k_.GetLengths()[1];

             float v_acc = 0;
@@ -83,7 +83,7 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator
         };

         make_ParallelTensorFunctor(
-            f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
+            f_mk_kn_mn, arg.c_m_n_.GetLengths()[0], arg.c_m_n_.GetLengths()[1])(
             std::thread::hardware_concurrency());

         return 0;
......
@@ -63,7 +63,7 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
     float Run(const Argument& arg)
     {
         auto f_mk_kn_mn = [&](auto m, auto n) {
-            const int K = arg.a_m_k_.mDesc.GetLengths()[1];
+            const int K = arg.a_m_k_.GetLengths()[1];

             float v_acc = 0;
@@ -89,7 +89,7 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
         };

         make_ParallelTensorFunctor(
-            f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
+            f_mk_kn_mn, arg.c_m_n_.GetLengths()[0], arg.c_m_n_.GetLengths()[1])(
             std::thread::hardware_concurrency());

         return 0;
......
@@ -5,7 +5,9 @@
 #include <iostream>
 #include <sstream>

 #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -38,11 +40,11 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
                               const Tensor<InDataType>& beta, // 1xN
                               const InDataType epsilon = 1e-5)
     {
-        assert(acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
-               acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]);
+        assert(acc.GetLengths()[1] == gamma.GetLengths()[0] &&
+               acc.GetLengths()[1] == beta.GetLengths()[0]);

-        size_t M = acc.mDesc.GetLengths()[0];
-        size_t N = acc.mDesc.GetLengths()[1];
+        size_t M = acc.GetLengths()[0];
+        size_t N = acc.GetLengths()[1];

         Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
         Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
@@ -131,7 +133,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        Tensor<AccDataType> acc_m_n(arg.c_m_n_.mDesc);
+        Tensor<AccDataType> acc_m_n(arg.c_m_n_.GetDesc());
         acc_m_n.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});

         auto ref_gemm = ReferenceGemmInstance{};
......
@@ -30,7 +30,7 @@ struct ReferenceSoftmax : public device::BaseOperator
         : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
     {
         // std::cout << "debug: scalar dims: ";
-        for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
+        for(size_t i = 0; i < in.GetNumOfDimension(); i++)
         {
             if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) ==
                sm_reduce_dims.end())
@@ -58,7 +58,7 @@ struct ReferenceSoftmax : public device::BaseOperator
         std::vector<size_t> scalar_lengths;
         for(index_t dim : arg.sm_scalar_dims_)
         {
-            scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
+            scalar_lengths.push_back(arg.in_.GetLengths()[dim]);
         }

         Tensor<AccDataType> reduce_max(scalar_lengths);
@@ -84,7 +84,7 @@ struct ReferenceSoftmax : public device::BaseOperator
         // LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
         // std::endl;

-        Tensor<AccDataType> in_stable(arg.in_.mDesc);
+        Tensor<AccDataType> in_stable(arg.in_.GetDesc());
         in_stable.ForEach([&](auto& self, auto idx) {
             // numerator = exp(x - max(x))
             self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) -
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <iterator>
#include <type_traits>
#include <utility>
namespace ck {
namespace ranges {
template <typename InputRange, typename OutputIterator>
auto copy(InputRange&& range, OutputIterator iter)
-> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
std::end(std::forward<InputRange>(range)),
iter))
{
return std::copy(std::begin(std::forward<InputRange>(range)),
std::end(std::forward<InputRange>(range)),
iter);
}
template <typename T, typename OutputRange>
auto fill(OutputRange&& range, const T& init)
-> std::void_t<decltype(std::fill(std::begin(std::forward<OutputRange>(range)),
std::end(std::forward<OutputRange>(range)),
init))>
{
std::fill(std::begin(std::forward<OutputRange>(range)),
std::end(std::forward<OutputRange>(range)),
init);
}
template <typename InputRange, typename OutputIterator, typename UnaryOperation>
auto transform(InputRange&& range, OutputIterator iter, UnaryOperation unary_op)
-> decltype(std::transform(std::begin(range), std::end(range), iter, unary_op))
{
return std::transform(std::begin(range), std::end(range), iter, unary_op);
}
} // namespace ranges
} // namespace ck
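These three helpers simply lift the iterator-pair algorithms to whole ranges. A usage sketch with standard containers (assumes the header above is included):

#include <array>
#include <iostream>
#include <vector>

#include "ck/library/utility/ranges.hpp" // the header defined above

int main()
{
    std::vector<int> src{1, 2, 3, 4};
    std::array<int, 4> dst{};

    ck::ranges::copy(src, dst.begin());                                   // dst = {1, 2, 3, 4}
    ck::ranges::transform(src, src.begin(), [](int v) { return v * v; }); // square in place
    ck::ranges::fill(dst, 0);                                             // reset

    for(int v : src)
        std::cout << v << ' '; // 1 4 9 16
}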
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <array>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <type_traits>
#include "ck/library/utility/ranges.hpp"
namespace ck {
namespace utils {
namespace detail {
template <typename Range, std::size_t Size>
class to_array_result;
template <typename FromT, std::size_t Size>
class to_array_result<FromT (&)[Size], Size> final
{
public:
explicit constexpr to_array_result(FromT (&source)[Size]) noexcept : source_(source) {}
template <typename T>
operator std::array<T, Size>() const
{
static_assert(std::is_convertible_v<FromT, T>);
return copy_as_array<T>(source_, std::make_index_sequence<Size>{});
}
private:
template <typename T, std::size_t... Indices>
static std::array<T, Size> copy_as_array(FromT (&array)[Size], std::index_sequence<Indices...>)
{
return std::array<T, Size>{array[Indices]...};
}
private:
FromT (&source_)[Size];
};
template <typename FromT, std::size_t Size>
class to_array_result<FromT(&&)[Size], Size> final
{
public:
explicit constexpr to_array_result(FromT(&&source)[Size]) noexcept : source_(std::move(source))
{
}
template <typename T>
operator std::array<T, Size>() &&
{
static_assert(std::is_convertible_v<FromT, T>);
return move_as_array<T>(std::move(source_), std::make_index_sequence<Size>{});
}
private:
template <typename T, std::size_t... Indices>
static std::array<T, Size> move_as_array(FromT(&&array)[Size], std::index_sequence<Indices...>)
{
return std::array<T, Size>{std::move(array[Indices])...};
}
private:
FromT(&&source_)[Size];
};
template <typename Range>
class to_array_result<Range, std::numeric_limits<std::size_t>::max()> final
{
public:
explicit constexpr to_array_result(const Range& source) noexcept : source_(source) {}
template <typename T, std::size_t Size>
operator std::array<T, Size>() const
{
static_assert(std::is_convertible_v<ranges::range_value_t<Range>, T>);
std::array<T, Size> destination;
std::copy_n(std::begin(source_),
std::min<std::size_t>(Size, std::size(source_)),
std::begin(destination));
return destination;
}
private:
const Range& source_;
};
struct empty_array_result final
{
template <typename T>
operator std::array<T, 0>() const
{
return std::array<T, 0>{};
}
};
} // namespace detail
template <typename T, std::size_t N>
inline constexpr auto to_array(T (&array)[N]) -> detail::to_array_result<decltype(array), N>
{
return detail::to_array_result<decltype(array), N>{array};
}
template <typename T, std::size_t N>
inline constexpr auto to_array(T(&&array)[N]) -> detail::to_array_result<decltype(array), N>
{
return detail::to_array_result<decltype(array), N>{std::move(array)};
}
template <typename Range>
inline constexpr auto to_array(const Range& range) noexcept
-> detail::to_array_result<ck::remove_cvref_t<Range>, std::numeric_limits<std::size_t>::max()>
{
return detail::to_array_result<ck::remove_cvref_t<Range>,
std::numeric_limits<std::size_t>::max()>{range};
}
inline constexpr auto empty_array() noexcept { return detail::empty_array_result{}; }
} // namespace utils
} // namespace ck
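`to_array` defers the element type: conversion happens only when the intermediate result is bound to a concrete `std::array<T, N>`, so one call site can feed destinations of different element types. A usage sketch; the include path is assumed, the rest uses the utilities exactly as defined above:

#include <array>
#include <vector>

#include "ck/library/utility/to_array.hpp" // assumed path of the header above

void demo()
{
    long raw[] = {1, 2, 3};

    // the destination picks the element type; the source is converted lazily
    std::array<int, 3> a    = ck::utils::to_array(raw);
    std::array<double, 3> b = ck::utils::to_array(raw);

    // range overload: copies min(Size, size(range)) leading elements
    std::vector<int> v{4, 5, 6, 7};
    std::array<int, 4> c = ck::utils::to_array(v);

    // empty_array() converts to any zero-length std::array
    std::array<float, 0> e = ck::utils::empty_array();

    (void)a;
    (void)b;
    (void)c;
    (void)e;
}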
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cstddef>
#include <type_traits>
namespace ck {
namespace utils {
struct mutable_buffer
{
using pointer = std::byte*;
using size_type = std::size_t;
friend mutable_buffer operator+(const mutable_buffer& buffer, size_type n) noexcept
{
mutable_buffer advanced(buffer);
advanced += n;
return advanced;
}
constexpr mutable_buffer() noexcept : mutable_buffer(nullptr, 0) {}
constexpr mutable_buffer(void* data, size_type size = 0) noexcept
: data_(static_cast<pointer>(data)), size_(size)
{
}
constexpr mutable_buffer& operator+=(size_type n) noexcept
{
const size_type advance = std::min(size(), n);
data_ = data() + advance;
size_ = size() - advance;
return *this;
}
constexpr pointer data() const noexcept { return data_; }
constexpr size_type size() const noexcept { return size_; }
constexpr operator void*() const noexcept { return data(); }
constexpr operator const void*() const noexcept { return data(); }
// this method only exists while T is complete type
template <typename T, typename = std::void_t<decltype(sizeof(T))>>
constexpr operator T*() const noexcept
{
if(size() % sizeof(T) != 0)
{
return nullptr;
}
return reinterpret_cast<T*>(data());
}
private:
pointer data_;
size_type size_;
};
} // namespace utils
} // namespace ck
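`mutable_buffer` is a bounds-aware typed view over raw bytes: `operator+=` advances but never past the end, and the templated conversion hands out a `T*` only when the remaining size is a whole number of `T` objects. A usage sketch (the include path is assumed; behavior follows the struct above):

#include <cstddef>
#include <iostream>

#include "ck/library/utility/mutable_buffer.hpp" // assumed path of the header above

void demo()
{
    std::byte storage[24];
    ck::utils::mutable_buffer buf{storage, sizeof(storage)};

    // conversion succeeds: 24 bytes is a whole number of 8-byte doubles
    double* d = buf;
    std::cout << (d != nullptr) << '\n'; // 1

    // advance past 4 bytes; 20 % 8 != 0, so the typed view yields nullptr
    buf += 4;
    double* d2 = buf;
    std::cout << (d2 == nullptr) << '\n'; // 1

    // operator+ clamps: advancing past the end leaves an empty buffer
    auto end = buf + 1000;
    std::cout << end.size() << '\n'; // 0
}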