map.hpp 4.01 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/tuple.hpp"

namespace ck {

// Naive fixed-capacity map.
//
// Entries are stored as Tuple<Key, Data> pairs in an Array of MaxSize slots, in insertion
// order. Lookup is a linear scan comparing keys with operator==. There is no erase, and
// Clear() only resets the entry count. Capacity and key-presence preconditions are checked
// with assert() only.
template <typename Key, typename Data, index_t MaxSize = 128>
struct Map
{
    using Pair = Tuple<Key, Data>;
    using Impl = Array<Pair, MaxSize>;

    // Backing storage; only the first size_ slots hold live entries.
    Impl impl_;
    // Number of live entries, kept in [0, MaxSize].
    index_t size_;

    // Position-based iterator over the entries of a mutable Map.
    struct Iterator
    {
        Impl& impl_;
        index_t pos_;

        __host__ __device__ constexpr Iterator(Impl& impl, index_t pos) : impl_{impl}, pos_{pos} {}

        __host__ __device__ constexpr Iterator& operator++()
        {
            pos_++;

            return *this;
        }

        // Compares positions only; both iterators are assumed to come from the same Map.
        __host__ __device__ constexpr bool operator!=(const Iterator& other) const
        {
            return other.pos_ != pos_;
        }

        __host__ __device__ constexpr Pair& operator*() { return impl_.At(pos_); }
    };

    // Position-based iterator over the entries of a const Map.
    struct ConstIterator
    {
        const Impl& impl_;
        index_t pos_;

        __host__ __device__ constexpr ConstIterator(const Impl& impl, index_t pos)
            : impl_{impl}, pos_{pos}
        {
        }

        __host__ __device__ constexpr ConstIterator& operator++()
        {
            pos_++;

            return *this;
        }

        // Compares positions only; both iterators are assumed to come from the same Map.
        __host__ __device__ constexpr bool operator!=(const ConstIterator& other) const
        {
            return other.pos_ != pos_;
        }

        __host__ __device__ constexpr const Pair& operator*() const { return impl_.At(pos_); }
    };

    // Creates an empty map. Every slot of impl_ is value-initialized.
    __host__ __device__ constexpr Map() : impl_{}, size_{0} {}

    // Number of live entries.
    __host__ __device__ constexpr index_t Size() const { return size_; }

    // Marks the map as empty. Stored pairs are not destroyed or reset.
    __host__ __device__ void Clear() { size_ = 0; }

    // Returns the index of the entry whose key equals `key`, or Size() if it is absent.
    __host__ __device__ constexpr index_t FindPosition(const Key& key) const
    {
        for(index_t i = 0; i < Size(); i++)
        {
            if(impl_[i].template At<0>() == key)
            {
                return i;
            }
        }

        return size_;
    }

    // Iterator to the entry for `key`, or end() if the key is absent.
    __host__ __device__ constexpr ConstIterator Find(const Key& key) const
    {
        return ConstIterator{impl_, FindPosition(key)};
    }

    // Iterator to the entry for `key`, or end() if the key is absent.
    __host__ __device__ constexpr Iterator Find(const Key& key)
    {
        return Iterator{impl_, FindPosition(key)};
    }

    // Read-only access to the value mapped to `key`.
    // Precondition: `key` is present. This is checked by assert only; otherwise the slot at
    // index Size() would be read.
    __host__ __device__ constexpr const Data& operator[](const Key& key) const
    {
        const index_t pos = FindPosition(key);

        // FIXME: only enforced when assertions are enabled
        assert(pos < Size());

        return impl_[pos].template At<1>();
    }

    // Returns a reference to the value mapped to `key`. If the key is absent, a new entry is
    // appended and its value is value-initialized (Data{}). This matches the state of a slot
    // in a freshly constructed Map, so an insertion after Clear() never exposes a stale value.
    // Precondition when inserting: Size() < MaxSize (checked by assert only).
    __host__ __device__ constexpr Data& operator()(const Key& key)
    {
        const index_t pos = FindPosition(key);

        // if entry not found, append it at the end
        if(pos == Size())
        {
            // Capacity must be checked before writing impl_(pos). When the map is full,
            // pos == MaxSize and the write would be out of bounds.
            // FIXME: only enforced when assertions are enabled
            assert(size_ < MaxSize);

            impl_(pos).template At<0>() = key;
            impl_(pos).template At<1>() = Data{};
            size_++;
        }

        return impl_(pos).template At<1>();
    }

    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    __host__ __device__ constexpr ConstIterator begin() const { return ConstIterator{impl_, 0}; }

    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    __host__ __device__ constexpr ConstIterator end() const { return ConstIterator{impl_, size_}; }

    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    __host__ __device__ constexpr Iterator begin() { return Iterator{impl_, 0}; }

    // WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
    __host__ __device__ constexpr Iterator end() { return Iterator{impl_, size_}; }

    // Prints the entry count and every live {key, data} pair, in insertion order.
    __host__ __device__ void Print() const
    {
        printf("Map{size_: %d, ", size_);
        //
        printf("impl_: [");
        //
        for(const auto& [key, data] : *this)
        {
            printf("{key: ");
            print(key);
            printf(", data: ");
            print(data);
            printf("}, ");
        }
        //
        printf("]");
        //
        printf("}");
    }
};

} // namespace ck