Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a2c5472a
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "75d53cc83966b4046e5a329ddf7baa6aa24f52e2"
Unverified
Commit
a2c5472a
authored
Mar 18, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Mar 18, 2024
Browse files
[GraphBolt] Labor dependent template specialization. (#7220)
parent
74c5e31d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
92 additions
and
44 deletions
+92
-44
graphbolt/include/graphbolt/continuous_seed.h
graphbolt/include/graphbolt/continuous_seed.h
+25
-0
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+21
-10
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+46
-34
No files found.
graphbolt/include/graphbolt/continuous_seed.h
View file @
a2c5472a
...
@@ -92,6 +92,31 @@ class continuous_seed {
...
@@ -92,6 +92,31 @@ class continuous_seed {
#endif // __CUDA_ARCH__
#endif // __CUDA_ARCH__
};
};
class
single_seed
{
uint64_t
seed_
;
public:
/* implicit */
single_seed
(
const
int64_t
seed
)
:
seed_
(
seed
)
{}
// NOLINT
single_seed
(
torch
::
Tensor
seed_arr
)
:
seed_
(
seed_arr
.
data_ptr
<
int64_t
>
()[
0
])
{}
#ifdef __CUDACC__
__device__
inline
float
uniform
(
const
uint64_t
id
)
const
{
const
uint64_t
kCurandSeed
=
999961
;
// Could be any random number.
curandStatePhilox4_32_10_t
rng
;
curand_init
(
kCurandSeed
,
seed_
,
id
,
&
rng
);
return
curand_uniform
(
&
rng
);
}
#else
inline
float
uniform
(
const
uint64_t
id
)
const
{
pcg32
ng0
(
seed_
,
id
);
std
::
uniform_real_distribution
<
float
>
uni
;
return
uni
(
ng0
);
}
#endif // __CUDA_ARCH__
};
}
// namespace graphbolt
}
// namespace graphbolt
#endif // GRAPHBOLT_CONTINUOUS_SEED_H_
#endif // GRAPHBOLT_CONTINUOUS_SEED_H_
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
a2c5472a
...
@@ -17,7 +17,11 @@
...
@@ -17,7 +17,11 @@
namespace
graphbolt
{
namespace
graphbolt
{
namespace
sampling
{
namespace
sampling
{
enum
SamplerType
{
NEIGHBOR
,
LABOR
};
enum
SamplerType
{
NEIGHBOR
,
LABOR
,
LABOR_DEPENDENT
};
constexpr
bool
is_labor
(
SamplerType
S
)
{
return
S
==
SamplerType
::
LABOR
||
S
==
SamplerType
::
LABOR_DEPENDENT
;
}
template
<
SamplerType
S
>
template
<
SamplerType
S
>
struct
SamplerArgs
;
struct
SamplerArgs
;
...
@@ -27,6 +31,13 @@ struct SamplerArgs<SamplerType::NEIGHBOR> {};
...
@@ -27,6 +31,13 @@ struct SamplerArgs<SamplerType::NEIGHBOR> {};
template
<
>
template
<
>
struct
SamplerArgs
<
SamplerType
::
LABOR
>
{
struct
SamplerArgs
<
SamplerType
::
LABOR
>
{
const
torch
::
Tensor
&
indices
;
single_seed
random_seed
;
int64_t
num_nodes
;
};
template
<
>
struct
SamplerArgs
<
SamplerType
::
LABOR_DEPENDENT
>
{
const
torch
::
Tensor
&
indices
;
const
torch
::
Tensor
&
indices
;
continuous_seed
random_seed
;
continuous_seed
random_seed
;
int64_t
num_nodes
;
int64_t
num_nodes
;
...
@@ -555,12 +566,12 @@ int64_t Pick(
...
@@ -555,12 +566,12 @@ int64_t Pick(
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
,
PickedType
*
picked_data_ptr
);
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
,
PickedType
*
picked_data_ptr
);
template
<
typename
PickedType
>
template
<
SamplerType
S
,
typename
PickedType
>
int64_t
Pick
(
std
::
enable_if_t
<
is_labor
(
S
),
int64_t
>
Pick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
SamplerArgs
<
SamplerType
::
LABOR
>
args
,
PickedType
*
picked_data_ptr
);
PickedType
*
picked_data_ptr
);
template
<
typename
PickedType
>
template
<
typename
PickedType
>
int64_t
TemporalPick
(
int64_t
TemporalPick
(
...
@@ -619,13 +630,13 @@ int64_t TemporalPickByEtype(
...
@@ -619,13 +630,13 @@ int64_t TemporalPickByEtype(
PickedType
*
picked_data_ptr
);
PickedType
*
picked_data_ptr
);
template
<
template
<
bool
NonUniform
,
bool
Replace
,
typename
ProbsType
,
typename
Picked
Type
,
bool
NonUniform
,
bool
Replace
,
typename
ProbsType
,
Sampler
Type
S
,
int
StackSize
=
1024
>
typename
PickedType
,
int
StackSize
=
1024
>
int64_t
LaborPick
(
std
::
enable_if_t
<
is_labor
(
S
),
int64_t
>
LaborPick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
SamplerArgs
<
SamplerType
::
LABOR
>
args
,
PickedType
*
picked_data_ptr
);
PickedType
*
picked_data_ptr
);
}
// namespace sampling
}
// namespace sampling
}
// namespace graphbolt
}
// namespace graphbolt
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
a2c5472a
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include <limits>
#include <limits>
#include <numeric>
#include <numeric>
#include <tuple>
#include <tuple>
#include <type_traits>
#include <vector>
#include <vector>
#include "./macro.h"
#include "./macro.h"
...
@@ -660,26 +661,37 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -660,26 +661,37 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}
if
(
layer
)
{
if
(
layer
)
{
SamplerArgs
<
SamplerType
::
LABOR
>
args
=
[
&
]
{
if
(
random_seed
.
has_value
()
&&
random_seed
->
numel
()
>=
2
)
{
if
(
random_seed
.
has_value
())
{
SamplerArgs
<
SamplerType
::
LABOR_DEPENDENT
>
args
{
return
SamplerArgs
<
SamplerType
::
LABOR
>
{
indices_
,
indices_
,
{
random_seed
.
value
(),
static_cast
<
float
>
(
seed2_contribution
)},
{
random_seed
.
value
(),
static_cast
<
float
>
(
seed2_contribution
)},
NumNodes
()};
NumNodes
()};
return
SampleNeighborsImpl
(
}
else
{
nodes
.
value
(),
return_eids
,
return
SamplerArgs
<
SamplerType
::
LABOR
>
{
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
),
indices_
,
GetPickFn
(
RandomEngine
::
ThreadLocal
()
->
RandInt
(
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
()),
probs_or_mask
,
args
));
NumNodes
()};
}
else
{
}
auto
args
=
[
&
]
{
}();
if
(
random_seed
.
has_value
()
&&
random_seed
->
numel
()
==
1
)
{
return
SampleNeighborsImpl
(
return
SamplerArgs
<
SamplerType
::
LABOR
>
{
nodes
.
value
(),
return_eids
,
indices_
,
random_seed
.
value
(),
NumNodes
()};
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
),
}
else
{
GetPickFn
(
return
SamplerArgs
<
SamplerType
::
LABOR
>
{
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
probs_or_mask
,
indices_
,
args
));
RandomEngine
::
ThreadLocal
()
->
RandInt
(
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
()),
NumNodes
()};
}
}();
return
SampleNeighborsImpl
(
nodes
.
value
(),
return_eids
,
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
),
GetPickFn
(
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
probs_or_mask
,
args
));
}
}
else
{
}
else
{
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
(
...
@@ -1297,7 +1309,7 @@ int64_t TemporalPick(
...
@@ -1297,7 +1309,7 @@ int64_t TemporalPick(
}
}
return
picked_indices
.
numel
();
return
picked_indices
.
numel
();
}
}
if
constexpr
(
S
==
SamplerType
::
LABOR
)
{
if
constexpr
(
is_labor
(
S
)
)
{
return
Pick
(
return
Pick
(
offset
,
num_neighbors
,
fanout
,
replace
,
options
,
masked_prob
,
args
,
offset
,
num_neighbors
,
fanout
,
replace
,
options
,
masked_prob
,
args
,
picked_data_ptr
);
picked_data_ptr
);
...
@@ -1383,12 +1395,12 @@ int64_t TemporalPickByEtype(
...
@@ -1383,12 +1395,12 @@ int64_t TemporalPickByEtype(
return
pick_offset
;
return
pick_offset
;
}
}
template
<
typename
PickedType
>
template
<
SamplerType
S
,
typename
PickedType
>
int64_t
Pick
(
std
::
enable_if_t
<
is_labor
(
S
),
int64_t
>
Pick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
SamplerArgs
<
SamplerType
::
LABOR
>
args
,
PickedType
*
picked_data_ptr
)
{
PickedType
*
picked_data_ptr
)
{
if
(
fanout
==
0
)
return
0
;
if
(
fanout
==
0
)
return
0
;
if
(
probs_or_mask
.
has_value
())
{
if
(
probs_or_mask
.
has_value
())
{
if
(
fanout
<
0
)
{
if
(
fanout
<
0
)
{
...
@@ -1438,9 +1450,9 @@ inline T invcdf(T u, int64_t n, T rem) {
...
@@ -1438,9 +1450,9 @@ inline T invcdf(T u, int64_t n, T rem) {
return
rem
*
(
one
-
std
::
pow
(
one
-
u
,
one
/
n
));
return
rem
*
(
one
-
std
::
pow
(
one
-
u
,
one
/
n
));
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
seed_t
>
inline
T
jth_sorted_uniform_random
(
inline
T
jth_sorted_uniform_random
(
continuous_
seed
seed
,
int64_t
t
,
int64_t
c
,
int64_t
j
,
T
&
rem
,
int64_t
n
)
{
seed
_t
seed
,
int64_t
t
,
int64_t
c
,
int64_t
j
,
T
&
rem
,
int64_t
n
)
{
const
T
u
=
seed
.
uniform
(
t
+
j
*
c
);
const
T
u
=
seed
.
uniform
(
t
+
j
*
c
);
// https://mathematica.stackexchange.com/a/256707
// https://mathematica.stackexchange.com/a/256707
rem
-=
invcdf
(
u
,
n
,
rem
);
rem
-=
invcdf
(
u
,
n
,
rem
);
...
@@ -1474,13 +1486,13 @@ inline T jth_sorted_uniform_random(
...
@@ -1474,13 +1486,13 @@ inline T jth_sorted_uniform_random(
* should be put. Enough memory space should be allocated in advance.
* should be put. Enough memory space should be allocated in advance.
*/
*/
template
<
template
<
bool
NonUniform
,
bool
Replace
,
typename
ProbsType
,
typename
Picked
Type
,
bool
NonUniform
,
bool
Replace
,
typename
ProbsType
,
Sampler
Type
S
,
int
StackSize
>
typename
PickedType
,
int
StackSize
>
inline
int64_t
LaborPick
(
inline
std
::
enable_if_t
<
is_labor
(
S
),
int64_t
>
LaborPick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
SamplerArgs
<
SamplerType
::
LABOR
>
args
,
PickedType
*
picked_data_ptr
)
{
PickedType
*
picked_data_ptr
)
{
fanout
=
Replace
?
fanout
:
std
::
min
(
fanout
,
num_neighbors
);
fanout
=
Replace
?
fanout
:
std
::
min
(
fanout
,
num_neighbors
);
if
(
!
NonUniform
&&
!
Replace
&&
fanout
>=
num_neighbors
)
{
if
(
!
NonUniform
&&
!
Replace
&&
fanout
>=
num_neighbors
)
{
std
::
iota
(
picked_data_ptr
,
picked_data_ptr
+
num_neighbors
,
offset
);
std
::
iota
(
picked_data_ptr
,
picked_data_ptr
+
num_neighbors
,
offset
);
...
@@ -1504,8 +1516,8 @@ inline int64_t LaborPick(
...
@@ -1504,8 +1516,8 @@ inline int64_t LaborPick(
}
}
AT_DISPATCH_INDEX_TYPES
(
AT_DISPATCH_INDEX_TYPES
(
args
.
indices
.
scalar_type
(),
"LaborPickMain"
,
([
&
]
{
args
.
indices
.
scalar_type
(),
"LaborPickMain"
,
([
&
]
{
const
index_t
*
local_indices_data
=
const
auto
local_indices_data
=
args
.
indices
.
data_ptr
<
index_t
>
()
+
offset
;
reinterpret_cast
<
index_t
*>
(
args
.
indices
.
data_ptr
(
)
)
+
offset
;
if
constexpr
(
Replace
)
{
if
constexpr
(
Replace
)
{
// [Algorithm] @mfbalin
// [Algorithm] @mfbalin
// Use a max-heap to get rid of the big random numbers and filter the
// Use a max-heap to get rid of the big random numbers and filter the
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment