device_tensor.hpp 969 Bytes
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#pragma once
#include "tensor.hpp"
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "tensor_descriptor.hpp"

// Build a runtime TensorDescriptor from a compile-time descriptor type.
//
// ConstTensorDesc must provide static member functions GetLengths() and
// GetStrides() whose results support operator[]. For every index in the pack
// Dims... (in order), the corresponding length and stride are read and
// collected into two lists of std::size_t.
// The descriptor argument itself is unused; only its type matters.
//
// Returns TensorDescriptor(lens, strides_list) built from those values.
template <typename ConstTensorDesc, std::size_t... Dims>
auto make_TensorDescriptor_impl(ConstTensorDesc, std::index_sequence<Dims...>)
{
    // Elements are list-initialized as std::size_t, so narrowing is
    // ill-formed: if GetLengths()/GetStrides() yield a signed index type,
    // each value must be a constant expression that fits in std::size_t.
    //
    // Initializing a named std::initializer_list from a braced list extends
    // the backing array's lifetime to that of the variable, so both lists
    // remain valid for the constructor call below.
    std::initializer_list<std::size_t> lens{ConstTensorDesc::GetLengths()[Dims]...};
    std::initializer_list<std::size_t> strides_list{ConstTensorDesc::GetStrides()[Dims]...};

    return TensorDescriptor(lens, strides_list);
}

// Convert a compile-time tensor descriptor type into a runtime TensorDescriptor.
//
// Only the type ConstTensorDesc is used: the argument is ignored, and a
// value-initialized ConstTensorDesc{} is forwarded instead. The number of
// dimensions comes from ConstTensorDesc::GetNumOfDimension(), which must be
// usable in a constant expression because it becomes a template argument.
template <typename ConstTensorDesc>
auto make_TensorDescriptor(ConstTensorDesc)
{
    constexpr std::size_t num_dims = ConstTensorDesc::GetNumOfDimension();

    // One index per dimension: 0, 1, ..., num_dims - 1.
    using DimIndices = std::make_index_sequence<num_dims>;

    return make_TensorDescriptor_impl(ConstTensorDesc{}, DimIndices{});
}

// Print a compile-time tensor descriptor type to `os` (std::cout by default).
//
// The type is first converted with make_TensorDescriptor. The resulting
// runtime descriptor is then passed, together with `os`, to
// ostream_TensorDescriptor. The argument is used only for its type.
template <typename ConstTensorDesc>
void ostream_ConstantTensorDescriptor(ConstTensorDesc, std::ostream& os = std::cout)
{
    auto runtime_desc = make_TensorDescriptor(ConstTensorDesc{});
    ostream_TensorDescriptor(runtime_desc, os);
}