Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
0f3f8181
Unverified
Commit
0f3f8181
authored
Aug 29, 2023
by
Muhammed Fatih BALIN
Committed by
GitHub
Aug 29, 2023
Browse files
[Graphbolt] Improve Labor performance (#6203)
parent
27f6561a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
19 deletions
+32
-19
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+2
-2
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+30
-17
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
0f3f8181
...
...
@@ -406,8 +406,8 @@ int64_t PickByEtype(
PickedType
*
picked_data_ptr
);
template
<
bool
NonUniform
,
bool
Replace
,
typename
ProbsType
=
float
,
typename
PickedType
>
bool
NonUniform
,
bool
Replace
,
typename
ProbsType
,
typename
PickedType
,
int
StackSize
=
1024
>
int64_t
LaborPick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
const
torch
::
TensorOptions
&
options
,
...
...
graphbolt/src/csc_sampling_graph.cc
View file @
0f3f8181
...
...
@@ -8,6 +8,8 @@
#include <graphbolt/serialize.h>
#include <torch/torch.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <limits>
#include <numeric>
...
...
@@ -730,11 +732,11 @@ int64_t Pick(
return
UniformPick
(
offset
,
num_neighbors
,
fanout
,
replace
,
options
,
picked_data_ptr
);
}
else
if
(
replace
)
{
return
LaborPick
<
false
,
true
>
(
return
LaborPick
<
false
,
true
,
float
>
(
offset
,
num_neighbors
,
fanout
,
options
,
/* probs_or_mask= */
torch
::
nullopt
,
args
,
picked_data_ptr
);
}
else
{
// replace = false
return
LaborPick
<
false
,
false
>
(
return
LaborPick
<
false
,
false
,
float
>
(
offset
,
num_neighbors
,
fanout
,
options
,
/* probs_or_mask= */
torch
::
nullopt
,
args
,
picked_data_ptr
);
}
...
...
@@ -770,7 +772,8 @@ inline void safe_divide(T& a, U b) {
* should be put. Enough memory space should be allocated in advance.
*/
template
<
bool
NonUniform
,
bool
Replace
,
typename
ProbsType
,
typename
PickedType
>
bool
NonUniform
,
bool
Replace
,
typename
ProbsType
,
typename
PickedType
,
int
StackSize
>
inline
int64_t
LaborPick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
const
torch
::
TensorOptions
&
options
,
...
...
@@ -781,10 +784,16 @@ inline int64_t LaborPick(
std
::
iota
(
picked_data_ptr
,
picked_data_ptr
+
num_neighbors
,
offset
);
return
num_neighbors
;
}
torch
::
Tensor
heap_tensor
=
torch
::
empty
({
fanout
*
2
},
torch
::
kInt32
);
// Assuming max_degree of a vertex is <= 4 billion.
auto
heap_data
=
reinterpret_cast
<
std
::
pair
<
float
,
uint32_t
>*>
(
std
::
array
<
std
::
pair
<
float
,
uint32_t
>
,
StackSize
>
heap
;
auto
heap_data
=
heap
.
data
();
torch
::
Tensor
heap_tensor
;
if
(
fanout
>
StackSize
)
{
constexpr
int
factor
=
sizeof
(
heap_data
[
0
])
/
sizeof
(
int32_t
);
heap_tensor
=
torch
::
empty
({
fanout
*
factor
},
torch
::
kInt32
);
heap_data
=
reinterpret_cast
<
std
::
pair
<
float
,
uint32_t
>*>
(
heap_tensor
.
data_ptr
<
int32_t
>
());
}
const
ProbsType
*
local_probs_data
=
NonUniform
?
probs_or_mask
.
value
().
data_ptr
<
ProbsType
>
()
+
offset
:
nullptr
;
...
...
@@ -814,22 +823,29 @@ inline int64_t LaborPick(
// is O((fanout + num_neighbors) log(fanout)). It is possible to
// decrease the logarithmic factor down to
// O(log(min(fanout, num_neighbors))).
torch
::
Tensor
remaining
=
torch
::
ones
({
num_neighbors
},
torch
::
kFloat32
);
float
*
rem_data
=
remaining
.
data_ptr
<
float
>
();
std
::
array
<
float
,
StackSize
>
remaining
;
auto
remaining_data
=
remaining
.
data
();
torch
::
Tensor
remaining_tensor
;
if
(
num_neighbors
>
StackSize
)
{
remaining_tensor
=
torch
::
empty
({
num_neighbors
},
torch
::
kFloat32
);
remaining_data
=
remaining_tensor
.
data_ptr
<
float
>
();
}
std
::
fill_n
(
remaining_data
,
num_neighbors
,
1.
f
);
auto
heap_end
=
heap_data
;
const
auto
init_count
=
(
num_neighbors
+
fanout
-
1
)
/
num_neighbors
;
auto
sample_neighbor_i_with_index_t_jth_time
=
[
&
](
scalar_t
t
,
int64_t
j
,
uint32_t
i
)
{
auto
rnd
=
labor
::
jth_sorted_uniform_random
(
args
.
random_seed
,
t
,
args
.
num_nodes
,
j
,
rem_data
[
i
],
args
.
random_seed
,
t
,
args
.
num_nodes
,
j
,
rem
aining
_data
[
i
],
fanout
-
j
);
// r_t
if
constexpr
(
NonUniform
)
{
safe_divide
(
rnd
,
local_probs_data
[
i
]);
}
// r_t / \pi_t
if
(
heap_end
<
heap_data
+
fanout
)
{
heap_end
[
0
]
=
std
::
make_pair
(
rnd
,
i
);
std
::
push_heap
(
heap_data
,
++
heap_end
);
if
(
++
heap_end
>=
heap_data
+
fanout
)
{
std
::
make_heap
(
heap_data
,
heap_data
+
fanout
);
}
return
false
;
}
else
if
(
rnd
<
heap_data
[
0
].
first
)
{
std
::
pop_heap
(
heap_data
,
heap_data
+
fanout
);
...
...
@@ -837,18 +853,18 @@ inline int64_t LaborPick(
std
::
push_heap
(
heap_data
,
heap_data
+
fanout
);
return
false
;
}
else
{
rem_data
[
i
]
=
-
1
;
rem
aining
_data
[
i
]
=
-
1
;
return
true
;
}
};
for
(
uint32_t
i
=
0
;
i
<
num_neighbors
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
init_count
;
j
++
)
{
const
auto
t
=
local_indices_data
[
i
];
for
(
int64_t
j
=
0
;
j
<
init_count
;
j
++
)
{
sample_neighbor_i_with_index_t_jth_time
(
t
,
j
,
i
);
}
}
for
(
uint32_t
i
=
0
;
i
<
num_neighbors
;
++
i
)
{
if
(
rem_data
[
i
]
==
-
1
)
continue
;
if
(
rem
aining
_data
[
i
]
==
-
1
)
continue
;
const
auto
t
=
local_indices_data
[
i
];
for
(
int64_t
j
=
init_count
;
j
<
fanout
;
++
j
)
{
if
(
sample_neighbor_i_with_index_t_jth_time
(
t
,
j
,
i
))
break
;
...
...
@@ -906,9 +922,6 @@ inline int64_t LaborPick(
picked_data_ptr
[
num_sampled
++
]
=
offset
+
j
;
}
}
TORCH_CHECK
(
!
Replace
||
num_sampled
==
fanout
||
num_sampled
==
0
,
"Sampling with replacement should sample exactly fanout neighbors or 0!"
);
return
num_sampled
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment