"docs/source/en/api/schedulers/lms_discrete.mdx" did not exist on "98f346835ab43e642f5d7d66253b1e06065af21f"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {

template <AddressSpaceEnum AddressSpace,
          bool InvalidElementUseNumericalZeroValue,
          typename T,
Chao Liu's avatar
Chao Liu committed
11
          typename TensorDescTmp>
Chao Liu's avatar
Chao Liu committed
12
13
struct Tensor
{
Chao Liu's avatar
Chao Liu committed
14
15
16
    using TensorDescriptor = remove_cvref_t<TensorDescTmp>;
    using DataType         = remove_reference_t<T>;

Chao Liu's avatar
Chao Liu committed
17
    static constexpr AddressSpaceEnum kAddressSpace_ = AddressSpace;
Chao Liu's avatar
Chao Liu committed
18
19
20
    static constexpr bool kInvalidElementUseNumericalZeroValue_ =
        InvalidElementUseNumericalZeroValue;

Chao Liu's avatar
Chao Liu committed
21
    __host__ __device__ constexpr Tensor() : buf_{}, desc_{} {}
Chao Liu's avatar
Chao Liu committed
22

Chao Liu's avatar
Chao Liu committed
23
    __host__ __device__ constexpr Tensor(DataType* p_data, TensorDescriptor desc)
Chao Liu's avatar
Chao Liu committed
24
25
26
27
        : buf_{p_data, desc.GetElementSpaceSize()}, desc_{desc}
    {
    }

Chao Liu's avatar
Chao Liu committed
28
29
30
    __host__ __device__ constexpr Tensor(DataType* p_data,
                                         TensorDescriptor desc,
                                         DataType invalid_element_value)
Chao Liu's avatar
Chao Liu committed
31
32
33
34
        : buf_{p_data, desc.GetElementSpaceSize(), invalid_element_value}, desc_{desc}
    {
    }

Chao Liu's avatar
Chao Liu committed
35
    // member
Chao Liu's avatar
Chao Liu committed
36
    DynamicBuffer<AddressSpace,
Chao Liu's avatar
Chao Liu committed
37
38
                  DataType,
                  typename TensorDescriptor::ElementSpaceSizeType,
Chao Liu's avatar
Chao Liu committed
39
40
41
                  InvalidElementUseNumericalZeroValue>
        buf_;

Chao Liu's avatar
Chao Liu committed
42
    TensorDescriptor desc_;
Chao Liu's avatar
Chao Liu committed
43
44
45
46
47
48
49
50
51
52
53
54
};

// Wraps an already-built descriptor and a raw data pointer into a Tensor.
//
// AddressSpace and InvalidElementUseNumericalZeroValue do not appear in the
// function parameters, so callers must spell them out; T and TensorDesc are
// deduced from the arguments. Note the argument order: descriptor first,
// pointer second. The Tensor is created from (p_data, desc) only, i.e. without
// an explicit invalid-element value.
template <AddressSpaceEnum AddressSpace,
          bool InvalidElementUseNumericalZeroValue,
          typename T,
          typename TensorDesc>
__host__ __device__ constexpr auto make_tensor(const TensorDesc& desc, T* p_data)
{
    using ResultTensor =
        Tensor<AddressSpace, InvalidElementUseNumericalZeroValue, T, remove_cvref_t<TensorDesc>>;

    return ResultTensor{p_data, desc};
}

// Creates a Tensor over p_data whose descriptor is obtained from
// make_naive_tensor_descriptor(lengths, strides).
//
// The overload only participates when the number of lengths equals the number
// of strides (enforced through the enable_if default template argument).
// invalid_element_value defaults to 0 and is always forwarded to the Tensor
// constructor.
template <AddressSpaceEnum AddressSpace,
          bool InvalidElementUseNumericalZeroValue,
          typename T,
          typename... Lengths,
          typename... Strides,
          typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor(const Tuple<Lengths...>& lengths,
                                                     const Tuple<Strides...>& strides,
                                                     T* p_data,
                                                     T invalid_element_value = 0)
{
    // Kept as a non-const local so decltype yields the unqualified descriptor type.
    auto naive_desc = make_naive_tensor_descriptor(lengths, strides);

    using ResultTensor =
        Tensor<AddressSpace, InvalidElementUseNumericalZeroValue, T, decltype(naive_desc)>;

    return ResultTensor{p_data, naive_desc, invalid_element_value};
}

// Creates a Tensor over p_data whose descriptor is obtained from
// make_naive_tensor_descriptor_packed(lengths); no strides are supplied by the
// caller. invalid_element_value defaults to 0 and is always forwarded to the
// Tensor constructor.
template <AddressSpaceEnum AddressSpace,
          bool InvalidElementUseNumericalZeroValue,
          typename T,
          typename... Lengths>
__host__ __device__ constexpr auto
make_naive_tensor_packed(const Tuple<Lengths...>& lengths, T* p_data, T invalid_element_value = 0)
{
    // Kept as a non-const local so decltype yields the unqualified descriptor type.
    auto packed_desc = make_naive_tensor_descriptor_packed(lengths);

    using ResultTensor =
        Tensor<AddressSpace, InvalidElementUseNumericalZeroValue, T, decltype(packed_desc)>;

    return ResultTensor{p_data, packed_desc, invalid_element_value};
}

// Returns a new Tensor that views the same data pointer as old_tensor through
// a descriptor derived from old_tensor's descriptor.
//
// The new descriptor is produced by transform_tensor_descriptor() from
// old_tensor.desc_, new_transforms and default-constructed instances of the
// two index-mapping types. The last two parameters are unnamed: only their
// types are used, not their values.
//
// The address space, the invalid-element policy flag and the element type are
// all taken from OldTensor.
//
// NOTE(review): the result is built with the two-argument Tensor constructor,
// so an invalid-element value that was passed when old_tensor was created is
// not forwarded to the new tensor - confirm this is intended.
template <typename OldTensor,
          typename NewTransforms,
          typename NewLowerDimensionOldVisibleIdss,
          typename NewUpperDimensionNewVisibleIdss>
__host__ __device__ constexpr auto transform_tensor(const OldTensor& old_tensor,
                                                    const NewTransforms& new_transforms,
                                                    NewLowerDimensionOldVisibleIdss,
                                                    NewUpperDimensionNewVisibleIdss)
{
    const auto new_desc = transform_tensor_descriptor(old_tensor.desc_,
                                                      new_transforms,
                                                      NewLowerDimensionOldVisibleIdss{},
                                                      NewUpperDimensionNewVisibleIdss{});

    // new_desc is const, so strip cv-qualifiers before using its type as the
    // descriptor template argument.
    return Tensor<OldTensor::kAddressSpace_,
                  OldTensor::kInvalidElementUseNumericalZeroValue_,
                  typename OldTensor::DataType,
                  remove_cvref_t<decltype(new_desc)>>{old_tensor.buf_.p_data_, new_desc};
}

} // namespace ck