Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
707f2ae9
Unverified
Commit
707f2ae9
authored
Aug 09, 2023
by
peizhou001
Committed by
GitHub
Aug 09, 2023
Browse files
[Graphbolt]Add concurrent id hash map (#6082)
parent
144a491b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
409 additions
and
0 deletions
+409
-0
graphbolt/src/concurrent_id_hash_map.cc
graphbolt/src/concurrent_id_hash_map.cc
+222
-0
graphbolt/src/concurrent_id_hash_map.h
graphbolt/src/concurrent_id_hash_map.h
+187
-0
No files found.
graphbolt/src/concurrent_id_hash_map.cc
0 → 100644
View file @
707f2ae9
/**
* Copyright (c) 2023 by Contributors
* @file concurrent_id_hash_map.cc
* @brief Class about id hash map.
*/
#include "concurrent_id_hash_map.h"
#ifdef _MSC_VER
#include <intrin.h>
#endif // _MSC_VER
#include <cmath>
#include <numeric>
namespace {
// Sentinel stored in the key slots of the hash map to mark an empty slot.
static constexpr int64_t kEmptyKey = -1;
// Number of elements each torch::parallel_for task processes at a time.
static constexpr int kGrainSize = 256;

// The formula is established from experience which is used to get the hashmap
// size from the input array size: the smallest power of two larger than
// 3 * num. A power-of-two capacity lets `pos & mask_` replace a modulo.
inline size_t GetMapSize(size_t num) {
  // Guard the empty-input case: std::log2(0) is -infinity and casting that
  // to size_t is undefined behavior. Fall back to a minimal one-slot map.
  if (num == 0) return 1;
  size_t capacity = 1;
  return capacity << static_cast<size_t>(1 + std::log2(num * 3));
}
}  // namespace
namespace
graphbolt
{
namespace
sampling
{
template <typename IdType>
IdType ConcurrentIdHashMap<IdType>::CompareAndSwap(
    IdType* ptr, IdType old_val, IdType new_val) {
#ifdef _MSC_VER
  // MSVC interlocked intrinsics take (destination, exchange, comparand),
  // hence `new_val` precedes `old_val` below. `sizeof(IdType)` is a
  // compile-time constant, so the untaken branches are trivially dead per
  // instantiation. The `long` return value converts back to IdType.
  if (sizeof(IdType) == 4) {
    return _InterlockedCompareExchange(
        reinterpret_cast<long*>(ptr), new_val, old_val);
  } else if (sizeof(IdType) == 8) {
    return _InterlockedCompareExchange64(
        reinterpret_cast<long long*>(ptr), new_val, old_val);
  } else {
    // NOTE(review): this branch returns no value (MSVC C4715), and this file
    // also instantiates 1- and 2-byte id types — on MSVC those would hit this
    // fatal path at runtime. Confirm whether they need dedicated intrinsics.
    LOG(FATAL) << "ID can only be int32 or int64";
  }
#elif __GNUC__  // _MSC_VER
  // GCC/Clang builtin: returns the value that was at *ptr before the
  // operation, which matches the contract documented in the header.
  return __sync_val_compare_and_swap(ptr, old_val, new_val);
#else  // _MSC_VER
#error "CompareAndSwap is not supported on this platform."
#endif  // _MSC_VER
}
template
<
typename
IdType
>
ConcurrentIdHashMap
<
IdType
>::
ConcurrentIdHashMap
()
:
mask_
(
0
)
{}
template <typename IdType>
torch::Tensor ConcurrentIdHashMap<IdType>::Init(
    const torch::Tensor& ids, size_t num_seeds) {
  const IdType* ids_data = ids.data_ptr<IdType>();
  const size_t num_ids = static_cast<size_t>(ids.size(0));
  // Capacity is a power of two so `pos & mask_` serves as the modulo.
  size_t capacity = GetMapSize(num_ids);
  mask_ = static_cast<IdType>(capacity - 1);
  // Layout: two entries per slot, [key, value]; -1 (kEmptyKey) marks empty.
  hash_map_ = torch::full(
      {static_cast<int64_t>(capacity * 2)}, -1, ids.options());

  // This code block is to fill the ids into hash_map_.
  auto unique_ids = torch::empty_like(ids);
  IdType* unique_ids_data = unique_ids.data_ptr<IdType>();

  // Fill in the first `num_seeds` ids. Seed ids are documented (in the
  // header) as unique, so each one is mapped to its own position i.
  torch::parallel_for(
      0, num_seeds, kGrainSize, [&](int64_t s, int64_t e) {
        for (int64_t i = s; i < e; i++) {
          InsertAndSet(ids_data[i], static_cast<IdType>(i));
        }
      });
  // Place the first `num_seeds` ids.
  unique_ids.slice(0, 0, num_seeds) = ids.slice(0, 0, num_seeds);

  // An auxiliary array indicates whether the corresponding elements
  // are inserted into hash map or not. Use `int16_t` instead of `bool` as
  // vector<bool> is unsafe when updating different elements from different
  // threads. See https://en.cppreference.com/w/cpp/container#Thread_safety.
  std::vector<int16_t> valid(num_ids);
  const int64_t num_threads = torch::get_num_threads();
  // block_offset[t + 1] counts the new (first-inserted) ids claimed by
  // thread t; after the prefix sum below it becomes each thread's write
  // offset into `unique_ids`.
  // NOTE(review): this assumes each parallel_for worker reports a thread id
  // in [0, num_threads) and processes a single contiguous range — confirm
  // against the torch::parallel_for contract.
  std::vector<size_t> block_offset(num_threads + 1, 0);

  // Insert all elements in this loop. Insert() returns true only for the
  // thread that wins the CAS for a given id, so duplicates count once.
  torch::parallel_for(
      num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
        size_t count = 0;
        for (int64_t i = s; i < e; i++) {
          valid[i] = Insert(ids_data[i]);
          count += valid[i];
        }
        auto thread_id = torch::get_thread_num();
        block_offset[thread_id + 1] = count;
      });

  // Get ExclusiveSum of each block.
  std::partial_sum(
      block_offset.begin() + 1, block_offset.end(),
      block_offset.begin() + 1);
  // Shrink to the actual number of unique ids found.
  unique_ids = unique_ids.slice(0, 0, num_seeds + block_offset.back());

  // Get unique array from ids and set value for hash map. Each thread writes
  // the ids it inserted into its own disjoint range of `unique_ids`, so no
  // synchronization is needed here.
  torch::parallel_for(
      num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
        auto thread_id = torch::get_thread_num();
        auto pos = block_offset[thread_id] + num_seeds;
        for (int64_t i = s; i < e; i++) {
          if (valid[i]) {
            unique_ids_data[pos] = ids_data[i];
            Set(ids_data[i], pos);
            pos = pos + 1;
          }
        }
      });
  return unique_ids;
}
// Maps every id in `ids` to its value in the hash map, in parallel.
// Returns a tensor of the same shape/dtype holding the mapped values.
template <typename IdType>
torch::Tensor ConcurrentIdHashMap<IdType>::MapIds(
    const torch::Tensor& ids) const {
  torch::Tensor new_ids = torch::empty_like(ids);
  const IdType* ids_data = ids.data_ptr<IdType>();
  IdType* values_data = new_ids.data_ptr<IdType>();
  auto num_ids = new_ids.size(0);
  torch::parallel_for(
      0, num_ids, kGrainSize, [&](int64_t begin, int64_t end) {
        for (int64_t idx = begin; idx < end; ++idx) {
          values_data[idx] = MapId(ids_data[idx]);
        }
      });
  return new_ids;
}
// Keys live at the even offsets of the flat map tensor: slot i -> index 2*i.
template <typename IdType>
constexpr IdType getKeyIndex(IdType pos) {
  return pos + pos;
}
// Each value sits right after its key: slot i -> index 2*i + 1.
template <typename IdType>
constexpr IdType getValueIndex(IdType pos) {
  return pos + pos + 1;
}
// Advances the probe position using quadratic probing: the slot index moves
// by delta^2 (wrapped by `mask_`), and delta grows by one for the next step.
template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::Next(
    IdType* pos, IdType* delta) const {
  const IdType step = *delta;
  *pos = (*pos + step * step) & mask_;
  *delta = step + 1;
}
// Looks up the value stored for `id`. Probes until the slot holds either
// `id` or the empty sentinel; note the loop assumes `id` fits in the table
// (a missing id returns the value of the first empty slot reached).
template <typename IdType>
inline IdType ConcurrentIdHashMap<IdType>::MapId(IdType id) const {
  const IdType empty_key = static_cast<IdType>(kEmptyKey);
  IdType* hash_map_data = hash_map_.data_ptr<IdType>();
  IdType pos = (id & mask_);
  IdType delta = 1;
  for (;;) {
    const IdType key = hash_map_data[getKeyIndex(pos)];
    if (key == empty_key || key == id) break;
    Next(&pos, &delta);
  }
  return hash_map_data[getValueIndex(pos)];
}
// Attempts to insert `id`, probing past occupied slots. Returns true only
// when this call claimed a fresh slot (i.e. `id` was not present before).
template <typename IdType>
bool ConcurrentIdHashMap<IdType>::Insert(IdType id) {
  IdType pos = (id & mask_);
  IdType delta = 1;
  InsertState state;
  for (;;) {
    state = AttemptInsertAt(pos, id);
    if (state != InsertState::OCCUPIED) break;
    Next(&pos, &delta);
  }
  return state == InsertState::INSERTED;
}
// Stores `value` next to an already-inserted `key`. The probe loop has no
// empty-slot check, so the key must exist (as documented in the header).
template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::Set(IdType key, IdType value) {
  IdType* hash_map_data = hash_map_.data_ptr<IdType>();
  IdType pos = (key & mask_);
  for (IdType delta = 1; hash_map_data[getKeyIndex(pos)] != key;) {
    Next(&pos, &delta);
  }
  hash_map_data[getValueIndex(pos)] = value;
}
// Inserts `id` (or finds its existing slot) and writes `value` beside it.
// Unlike Insert(), the caller does not learn whether the id was new.
template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::InsertAndSet(
    IdType id, IdType value) {
  IdType pos = (id & mask_);
  IdType delta = 1;
  for (;;) {
    if (AttemptInsertAt(pos, id) != InsertState::OCCUPIED) break;
    Next(&pos, &delta);
  }
  hash_map_.data_ptr<IdType>()[getValueIndex(pos)] = value;
}
// Tries to claim slot `pos` for `key` with a single CAS on the key cell.
// INSERTED: this call claimed the empty slot; EXISTED: the slot already
// holds `key`; OCCUPIED: a different key lives here, keep probing.
template <typename IdType>
inline typename ConcurrentIdHashMap<IdType>::InsertState
ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {
  const IdType empty_key = static_cast<IdType>(kEmptyKey);
  IdType* hash_map_data = hash_map_.data_ptr<IdType>();
  const IdType old_val =
      CompareAndSwap(&hash_map_data[getKeyIndex(pos)], empty_key, key);
  if (old_val == empty_key) return InsertState::INSERTED;
  return (old_val == key) ? InsertState::EXISTED : InsertState::OCCUPIED;
}
// Explicit instantiations for all supported id types.
// NOTE(review): CompareAndSwap's MSVC path only handles 4- and 8-byte ids
// and hits LOG(FATAL) otherwise, so the int16/int8/uint8 instantiations
// appear usable only on non-MSVC builds — confirm this is intended.
template class ConcurrentIdHashMap<int32_t>;
template class ConcurrentIdHashMap<int64_t>;
template class ConcurrentIdHashMap<int16_t>;
template class ConcurrentIdHashMap<int8_t>;
template class ConcurrentIdHashMap<uint8_t>;
}
// namespace sampling
}
// namespace graphbolt
graphbolt/src/concurrent_id_hash_map.h
0 → 100644
View file @
707f2ae9
/**
* Copyright (c) 2023 by Contributors
* @file concurrent_id_hash_map.h
* @brief Class about concurrent id hash map.
*/
#ifndef GRAPHBOLT_CONCURRENT_ID_HASH_MAP_H_
#define GRAPHBOLT_CONCURRENT_ID_HASH_MAP_H_
#include <torch/torch.h>
#include <functional>
#include <memory>
#include <vector>
namespace
graphbolt
{
namespace
sampling
{
/**
* @brief A CPU targeted hashmap for mapping duplicate and non-consecutive ids
* in the provided array to unique and consecutive ones. It utilizes
* multi-threading to accelerate the insert and search speed. Currently it is
 * only designed to be used in `ToBlockCpu` for optimization, so it only
 * supports key insertion once, via the Init function, and it does not support
 * key deletion.
*
* The hash map should be prepared in two phases before using. With the first
* being creating the hashmap, and then initialize it with an id array which is
* divided into 2 parts: [`seed ids`, `sampled ids`]. `Seed ids` refer to
 * a set of ids chosen as the input for the sampling process and `sampled ids`
 * are the ids newly sampled during the process (note that the `seed ids` might
 * also be sampled in the process and included in the `sampled ids`). As a
 * result `seed
* ids` are mapped to [0, num_seed_ids) and `sampled ids` to [num_seed_ids,
* num_unique_ids). Notice that mapping order is stable for `seed ids` while not
* for the `sampled ids`.
*
* For example, for an array `A` having 4 seed ids with following entries:
* [99, 98, 100, 97, 97, 101, 101, 102, 101]
* Create the hashmap `H` with:
* `H = ConcurrentIdHashMap()` (1)
* And Init it with:
 * `U = H.Init(A)` (2) (U is an id array used to store the unique
* ids in A).
 * Then `U` could be as follows (its exact order is not fixed, as the overall
 * mapping is not stable):
* [99, 98, 100, 97, 102, 101]
* And the hashmap should generate following mappings:
* * [
* {key: 99, value: 0},
* {key: 98, value: 1},
* {key: 100, value: 2},
* {key: 97, value: 3},
* {key: 102, value: 4},
* {key: 101, value: 5}
* ]
* Search the hashmap with array `I`=[98, 99, 102]:
 * R = H.MapIds(I) (3)
* R should be:
* [1, 0, 4]
**/
template <typename IdType>
class ConcurrentIdHashMap {
 private:
  /**
   * @brief The result state of an attempt to insert.
   */
  enum class InsertState {
    OCCUPIED,  // Indicates that the space where an insertion is being
               // attempted is already occupied by another element.
    EXISTED,   // Indicates that the element being inserted already exists
               // in the map, and thus no insertion is performed.
    INSERTED   // Indicates that the insertion was successful and a new
               // element was added to the map.
  };

 public:
  /**
   * @brief Cross platform CAS operation.
   * It is an atomic operation that compares the contents of a memory
   * location with a given value and, only if they are the same, modifies
   * the contents of that memory location to a new given value.
   *
   * @param ptr The pointer to the object to test and modify.
   * @param old_val The value expected to be found in `ptr`.
   * @param new_val The value to store in `ptr` if it is as expected.
   *
   * @return Old value pointed by the `ptr`.
   */
  static IdType CompareAndSwap(IdType* ptr, IdType old_val, IdType new_val);

  // Creates an empty map; Init() must be called before any lookup.
  ConcurrentIdHashMap();

  // Non-copyable: the map owns a tensor and duplicating it is not intended.
  ConcurrentIdHashMap(const ConcurrentIdHashMap& other) = delete;
  ConcurrentIdHashMap& operator=(const ConcurrentIdHashMap& other) = delete;

  /**
   * @brief Initialize the hashmap with an array of ids. The first `num_seeds`
   * ids are unique and must be mapped to a contiguous array starting
   * from 0. The rest can contain duplicates and their mapping order is not
   * stable.
   *
   * @param ids The array of the ids to be inserted.
   * @param num_seeds The number of seed ids.
   *
   * @return Unique ids from the input `ids`.
   */
  torch::Tensor Init(const torch::Tensor& ids, size_t num_seeds);

  /**
   * @brief Find mappings of given keys.
   *
   * @param ids The keys to map for.
   *
   * @return Mapping results corresponding to `ids`.
   */
  torch::Tensor MapIds(const torch::Tensor& ids) const;

 private:
  /**
   * @brief Get the next position and delta for probing.
   *
   * @param[in,out] pos Calculate the next position with quadric probing.
   * @param[in,out] delta Calculate the next delta by adding 1.
   */
  inline void Next(IdType* pos, IdType* delta) const;

  /**
   * @brief Find the mapping of a given key.
   *
   * @param id The key to map for. The key must exist in the map.
   *
   * @return Mapping result corresponding to `id`.
   */
  inline IdType MapId(const IdType id) const;

  /**
   * @brief Insert an id into the hash map.
   *
   * @param id The id to be inserted.
   *
   * @return Whether the `id` is inserted or not.
   */
  inline bool Insert(IdType id);

  /**
   * @brief Set the value for the key in the hash map.
   *
   * @param key The key to set for.
   * @param value The value to be set for the `key`.
   *
   * @warning Key must exist.
   */
  inline void Set(IdType key, IdType value);

  /**
   * @brief Insert a key into the hash map and set its value.
   *
   * @param key The key to be inserted.
   * @param value The value to be set for the `key`.
   */
  inline void InsertAndSet(IdType key, IdType value);

  /**
   * @brief Attempt to insert the key into the hash map at the given position.
   *
   * @param pos The position in the hash map to be inserted at.
   * @param key The key to be inserted.
   *
   * @return The state of the insertion.
   */
  inline InsertState AttemptInsertAt(int64_t pos, IdType key);

 private:
  /**
   * @brief Flat tensor which is used to store all elements; each slot i
   * stores its key at index 2*i and its value at index 2*i + 1.
   */
  torch::Tensor hash_map_;

  /**
   * @brief Mask which is assisted to get the position in the table
   * for a key by performing `&` operation with it.
   */
  IdType mask_;
};
}
// namespace sampling
}
// namespace graphbolt
#endif // GRAPHBOLT_CONCURRENT_ID_HASH_MAP_H_
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment