tensor_view.hpp 6.66 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_description/tensor_coordinate.hpp"

namespace ck {

// Couples a buffer view with a tensor descriptor so that elements held by the buffer can be
// read and written through tensor coordinates.
//   BufferView_ : buffer view type (a reference is stripped, cv-qualifiers are kept)
//   TensorDesc_ : tensor descriptor type (cv-qualifiers and references are stripped)
template <typename BufferView_, typename TensorDesc_>
struct TensorView
{
    using BufferView = remove_reference_t<BufferView_>;
    using TensorDesc = remove_cvref_t<TensorDesc_>;

    // element type reported by the buffer view
    using DataType = typename BufferView::type;

    // index with one entry per top dimension of the descriptor, and the coordinate type that
    // make_tensor_coordinate() produces for such an index
    using TensorIndex = Array<index_t, TensorDesc::GetNumOfTopDimension()>;
    using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));

    __host__ __device__ constexpr TensorView() = default;

    // Keeps its own copy of "buffer_view" and of "desc"
    __host__ __device__ constexpr TensorView(const BufferView& buffer_view, const TensorDesc& desc)
        : buf_{buffer_view}, desc_{desc}
    {
    }

    // Read-only access to the stored descriptor
    __host__ __device__ constexpr auto& GetTensorDescriptor() const { return desc_; }

    // Number of top dimensions of the descriptor
    __host__ __device__ static constexpr index_t GetNumOfDimension()
    {
        return TensorDesc::GetNumOfTopDimension();
    }

    // Access to the stored buffer view (const and mutable flavours)
    __host__ __device__ constexpr const auto& GetBufferView() const { return buf_; }

    __host__ __device__ constexpr auto& GetBufferView() { return buf_; }

    // Reads one DataType at the offset of "coord".
    // The validity flag computed from the descriptor is handed to the buffer view together
    // with the offset; the top index of "coord" itself is assumed to be valid.
    __host__ __device__ constexpr DataType GetElement(const TensorCoord& coord) const
    {
        const auto offset = coord.GetOffset();
        const auto is_valid =
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord);

        return buf_.template Get<DataType>(offset, is_valid);
    }

    // Writes one DataType "x" at the offset of "coord", forwarding the same validity flag
    // as GetElement() does.
    __host__ __device__ constexpr void SetElement(const TensorCoord& coord, const DataType& x)
    {
        const auto offset = coord.GetOffset();
        const auto is_valid =
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord);

        buf_.template Set<DataType>(offset, is_valid, x);
    }

    // Vectorized read.
    // X is a vector of DataType (only enabled when both share the same scalar type).
    // "coord" is a coordinate of DataType, not of X, and should be aligned to X.
    template <typename X,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    __host__ __device__ constexpr remove_cvref_t<X>
    GetVectorizedElements(const TensorCoord& coord) const
    {
        const auto offset = coord.GetOffset();
        const auto is_valid =
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord);

        return buf_.template Get<X>(offset, is_valid);
    }

    // Vectorized write.
    // X is a vector of DataType (only enabled when both share the same scalar type).
    // "coord" is a coordinate of DataType, not of X, and should be aligned to X.
    template <typename X,
              typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
                                           typename scalar_type<remove_cvref_t<DataType>>::type>,
                                 bool>::type = false>
    __host__ __device__ constexpr void SetVectorizedElements(const TensorCoord& coord, const X& x)
    {
        const auto offset = coord.GetOffset();
        const auto is_valid =
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord);

        buf_.template Set<X>(offset, is_valid, x);
    }

    // Debug dump in the form "TensorView{buf_: <buffer>, desc_: <descriptor>}"
    __host__ __device__ void Print() const
    {
        printf("TensorView{");

        // buffer view first ...
        printf("buf_: ");
        print(buf_);
        printf(", ");

        // ... then the descriptor
        printf("desc_: ");
        print(desc_);

        printf("}");
    }

    // data members (public; declaration order is also the initialization order)
    BufferView buf_;
    TensorDesc desc_;
};

// Creates a TensorView over the memory at "p", laid out according to an existing descriptor.
//   BufferAddressSpace : address space forwarded to make_buffer_view()
//   p                  : pointer to the first element
//   desc               : tensor descriptor; its element space size is used as the buffer size
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
          typename DataType,
          typename... Ts>
__host__ __device__ constexpr auto make_tensor_view(DataType* p,
                                                    const TensorDescriptor<Ts...>& desc)
{
    auto buf = make_buffer_view<BufferAddressSpace>(p, desc.GetElementSpaceSize());

    // decltype(desc) is a const reference type here; it is passed on unchanged as the second
    // template argument of TensorView
    using View = TensorView<decltype(buf), decltype(desc)>;

    return View{buf, desc};
}

// Creates a TensorView over the memory at "p" using a naive descriptor built from
// "lengths" and "strides" (both tuples must have the same number of entries).
// The two trailing Number<> arguments are unnamed: only their compile-time values are used,
// and they are forwarded to make_naive_tensor_descriptor().
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
          typename DataType,
          typename... Lengths,
          typename... Strides,
          index_t GuaranteedLastDimensionVectorLength                              = -1,
          index_t GuaranteedLastDimensionVectorStride                              = -1,
          typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto
make_naive_tensor_view(DataType* p,
                       const Tuple<Lengths...>& lengths,
                       const Tuple<Strides...>& strides,
                       Number<GuaranteedLastDimensionVectorLength> = Number<-1>{},
                       Number<GuaranteedLastDimensionVectorStride> = Number<-1>{})
{
    using VectorLength = Number<GuaranteedLastDimensionVectorLength>;
    using VectorStride = Number<GuaranteedLastDimensionVectorStride>;

    // descriptor first: the buffer size is taken from its element space size
    auto naive_desc =
        make_naive_tensor_descriptor(lengths, strides, VectorLength{}, VectorStride{});

    auto buf = make_buffer_view<BufferAddressSpace>(p, naive_desc.GetElementSpaceSize());

    using View = TensorView<decltype(buf), decltype(naive_desc)>;

    return View{buf, naive_desc};
}

// Creates a TensorView over the memory at "p" using the descriptor that
// make_naive_tensor_descriptor_packed() builds from "lengths".
// The trailing Number<> argument is unnamed: only its compile-time value is used, and it is
// forwarded to the descriptor factory.
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
          typename DataType,
          typename... Lengths,
          index_t GuaranteedLastDimensionVectorLength = -1>
__host__ __device__ constexpr auto
make_naive_tensor_view_packed(DataType* p,
                              const Tuple<Lengths...>& lengths,
                              Number<GuaranteedLastDimensionVectorLength> = Number<-1>{})
{
    using VectorLength = Number<GuaranteedLastDimensionVectorLength>;

    // descriptor first: the buffer size is taken from its element space size
    auto packed_desc = make_naive_tensor_descriptor_packed(lengths, VectorLength{});

    auto buf = make_buffer_view<BufferAddressSpace>(p, packed_desc.GetElementSpaceSize());

    using View = TensorView<decltype(buf), decltype(packed_desc)>;

    return View{buf, packed_desc};
}

// Derives a new TensorView from "old_tensor_view": the buffer view is copied as-is, while the
// descriptor is replaced by the result of transform_tensor_descriptor() applied to the old one.
// The two dimension-id arguments are unnamed: they are passed on as default-constructed
// objects of their types.
template <typename OldTensorView,
          typename NewTransforms,
          typename NewLowerDimensionOldVisibleIdss,
          typename NewUpperDimensionNewVisibleIdss>
__host__ __device__ constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view,
                                                         const NewTransforms& new_transforms,
                                                         NewLowerDimensionOldVisibleIdss,
                                                         NewUpperDimensionNewVisibleIdss)
{
    auto transformed_desc = transform_tensor_descriptor(old_tensor_view.desc_,
                                                        new_transforms,
                                                        NewLowerDimensionOldVisibleIdss{},
                                                        NewUpperDimensionNewVisibleIdss{});

    // same buffer view type as the source view, paired with the new descriptor type
    using NewDesc = remove_cvref_t<decltype(transformed_desc)>;
    using NewView = TensorView<typename OldTensorView::BufferView, NewDesc>;

    return NewView{old_tensor_view.buf_, transformed_desc};
}

} // namespace ck