"git@developer.sourcefind.cn:change/sglang.git" did not exist on "532f998b0f894268b69b7310bf06349e26b8543c"
ConstantMergedTensorDescriptor.hip.hpp 4.82 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"

5
6
7
8
9
// OriginalTensorDesc : ConstantTensorDescriptor<...>
//     it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
10
11
struct ConstantMergedTensorDescriptor
{
12
13
    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

Chao Liu's avatar
Chao Liu committed
14
15
    static constexpr index_t nDim         = sizeof...(OriginalDimMergeSeqs);
    static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
Chao Liu's avatar
Chao Liu committed
16
17
18

    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
19
20
21
22
23
24
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check OriginalDimMergeSeqs contains all original dimensions
Chao Liu's avatar
Chao Liu committed
25
26

        // TODO: check there is no duplication in OriginalDimMergeSeqs
Chao Liu's avatar
Chao Liu committed
27
28
    }

Chao Liu's avatar
Chao Liu committed
29
30
31
32
33
    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
    {
        return OriginalTensorDesc{};
    }

34
35
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

Chao Liu's avatar
Chao Liu committed
36
37
38
39
    __host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
    {
        return nOriginalDim;
    }
40
41
42

    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
43
    {
Chao Liu's avatar
Chao Liu committed
44
        return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
45
    }
Chao Liu's avatar
Chao Liu committed
46

47
48
49
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
Chao Liu's avatar
Chao Liu committed
50
        constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
51

52
53
54
55
56
57
58
59
        return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");
Chao Liu's avatar
Chao Liu committed
60

Chao Liu's avatar
Chao Liu committed
61
        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
Chao Liu's avatar
Chao Liu committed
62

63
        return OriginalTensorDesc::GetStride(Number<idim_original>{});
Chao Liu's avatar
Chao Liu committed
64
65
    }

66
    __host__ __device__ static constexpr auto GetLengths()
Chao Liu's avatar
Chao Liu committed
67
    {
Chao Liu's avatar
Chao Liu committed
68
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
Chao Liu's avatar
Chao Liu committed
69
70
    }

71
    __host__ __device__ static constexpr index_t GetElementSize()
Chao Liu's avatar
Chao Liu committed
72
    {
73
        return OriginalTensorDesc::GetElementSize();
Chao Liu's avatar
Chao Liu committed
74
75
    }

76
77
    __host__ __device__ static auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
78
    {
79
80
81
82
83
84
85
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr index_t idim               = IDim.Get();
            constexpr auto original_dims_partial = std::get<idim>(mOriginalDimMergeSeqs);

            // get partial original-multi-id corresponding to this merged dimension
Chao Liu's avatar
Chao Liu committed
86
            const auto original_multi_id_partial =
87
88
89
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id[idim]);

Chao Liu's avatar
Chao Liu committed
90
91
92
            static_for<0, original_dims_partial.GetSize(), 1>{}([&](auto I_) {
                constexpr auto I                = decltype(I_){};
                constexpr index_t idim_original = original_dims_partial.Get(I);
93

Chao Liu's avatar
Chao Liu committed
94
95
                original_multi_id[idim_original] = original_multi_id_partial[I.Get()];
            });
96
97
98
        });

        return original_multi_id;
Chao Liu's avatar
Chao Liu committed
99
100
    }

101
    __host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
102
    {
103
104
        const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

Chao Liu's avatar
Chao Liu committed
105
        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
Chao Liu's avatar
Chao Liu committed
106
107
    }

Chao Liu's avatar
Chao Liu committed
108
    template <class... Is>
109
    __host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
Chao Liu's avatar
Chao Liu committed
110
    {
111
        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
Chao Liu's avatar
Chao Liu committed
112
113
    }

114
    __host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
Chao Liu's avatar
Chao Liu committed
115
    {
Chao Liu's avatar
Chao Liu committed
116
        constexpr auto dummy_desc = make_ConstantTensorDescriptor_default_rank_packed(GetLengths());
117
118

        return dummy_desc.GetMultiIndexFrom1dIndex(id);
Chao Liu's avatar
Chao Liu committed
119
120
121
    }
};

122
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
123
124
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
                                                                       OriginalDimMergeSeqs...)
Chao Liu's avatar
Chao Liu committed
125
{
126
    return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
Chao Liu's avatar
Chao Liu committed
127
}
Chao Liu's avatar
Chao Liu committed
128
129
130
131
132
133

// Prints a merged descriptor by forwarding its original (un-merged) descriptor, together with
// the caller-supplied string s, to print_ConstantTensorDescriptor. The descriptor argument is
// used only for its type.
template <class TDesc>
__host__ __device__ void print_ConstantMergedTensorDescriptor(TDesc, const char* s)
{
    const auto original_desc = TDesc::GetOriginalTensorDescriptor();

    print_ConstantTensorDescriptor(original_desc, s);
}