tensor.hpp 2.72 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {

template <AddressSpaceEnum AddressSpace,
          bool InvalidElementUseNumericalZeroValue,
          typename T,
Chao Liu's avatar
Chao Liu committed
11
          typename TensorDescTmp>
Chao Liu's avatar
Chao Liu committed
12
13
struct Tensor
{
Chao Liu's avatar
Chao Liu committed
14
15
16
    using TensorDescriptor = remove_cvref_t<TensorDescTmp>;
    using DataType         = remove_reference_t<T>;

Chao Liu's avatar
Chao Liu committed
17
18
19
20
21
22
    static constexpr AddressSpaceEnum kAdressSpace_ = AddressSpace;
    static constexpr bool kInvalidElementUseNumericalZeroValue_ =
        InvalidElementUseNumericalZeroValue;

    __host__ __device__ constexpr Tensor() : buf_{nullptr, 0}, desc_{} {}

Chao Liu's avatar
Chao Liu committed
23
    __host__ __device__ constexpr Tensor(DataType* p_data, TensorDescriptor desc)
Chao Liu's avatar
Chao Liu committed
24
25
26
27
        : buf_{p_data, desc.GetElementSpaceSize()}, desc_{desc}
    {
    }

Chao Liu's avatar
Chao Liu committed
28
29
30
    __host__ __device__ constexpr Tensor(DataType* p_data,
                                         TensorDescriptor desc,
                                         DataType invalid_element_value)
Chao Liu's avatar
Chao Liu committed
31
32
33
34
        : buf_{p_data, desc.GetElementSpaceSize(), invalid_element_value}, desc_{desc}
    {
    }

Chao Liu's avatar
Chao Liu committed
35
    // member
Chao Liu's avatar
Chao Liu committed
36
    DynamicBuffer<AddressSpace,
Chao Liu's avatar
Chao Liu committed
37
38
                  DataType,
                  typename TensorDescriptor::ElementSpaceSizeType,
Chao Liu's avatar
Chao Liu committed
39
40
41
                  InvalidElementUseNumericalZeroValue>
        buf_;

Chao Liu's avatar
Chao Liu committed
42
    TensorDescriptor desc_;
Chao Liu's avatar
Chao Liu committed
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
};

// Convenience factory: builds a Tensor viewing p_data through desc.
//
// AddressSpace and InvalidElementUseNumericalZeroValue must be given explicitly;
// T and TensorDesc are deduced from the arguments. The stored descriptor type is
// TensorDesc with cv/ref qualifiers removed, and the two-argument Tensor
// constructor is used (no custom invalid element value).
template <AddressSpaceEnum AddressSpace,
          bool InvalidElementUseNumericalZeroValue,
          typename T,
          typename TensorDesc>
__host__ __device__ constexpr auto make_tensor(const TensorDesc& desc, T* p_data)
{
    using ResultTensor =
        Tensor<AddressSpace, InvalidElementUseNumericalZeroValue, T, remove_cvref_t<TensorDesc>>;

    return ResultTensor{p_data, desc};
}

// Returns a new Tensor viewing the same data pointer as old_tensor through a
// descriptor obtained by applying new_transforms to old_tensor.desc_.
//
// The dimension-id arguments are passed only for their types (they are
// re-materialized as NewLowerDimensionOldVisibleIdss{} / NewUpperDimensionNewVisibleIdss{}).
// The result keeps old_tensor's address space, invalid-element policy and data type;
// its buffer size is recomputed from the new descriptor by the Tensor constructor.
template <typename OldTensor,
          typename NewTransforms,
          typename NewLowerDimensionOldVisibleIdss,
          typename NewUpperDimensionNewVisibleIdss>
__host__ __device__ constexpr auto transform_tensor(const OldTensor& old_tensor,
                                                    const NewTransforms& new_transforms,
                                                    NewLowerDimensionOldVisibleIdss,
                                                    NewUpperDimensionNewVisibleIdss)
{
    // Transform the descriptor, not the tensor. Calling transform_tensor here would
    // re-enter this template with the descriptor as OldTensor, whose body then reads
    // a desc_ member the descriptor does not have.
    const auto new_desc = transform_tensor_descriptor(old_tensor.desc_,
                                                      new_transforms,
                                                      NewLowerDimensionOldVisibleIdss{},
                                                      NewUpperDimensionNewVisibleIdss{});

    // Use the member names Tensor actually declares: kAdressSpace_ (historical
    // spelling) and kInvalidElementUseNumericalZeroValue_ (trailing underscore).
    // NOTE(review): the two-argument constructor is used, so a custom invalid element
    // value held by old_tensor's buffer is not carried over -- confirm this is intended.
    return Tensor<OldTensor::kAdressSpace_,
                  OldTensor::kInvalidElementUseNumericalZeroValue_,
                  typename OldTensor::DataType,
                  remove_cvref_t<decltype(new_desc)>>{old_tensor.buf_.p_data_, new_desc};
}

} // namespace ck