Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e42c7fcd
"vscode:/vscode.git/clone" did not exist on "f08d2a8a4c7ae03de982585c7dc8a47259e553fb"
Unverified
Commit
e42c7fcd
authored
Dec 22, 2023
by
czkkkkkk
Committed by
GitHub
Dec 22, 2023
Browse files
[Graphbolt] Implement Temporal Neighbor Sampling. (#6784)
parent
8a8f2b00
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
309 additions
and
11 deletions
+309
-11
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+37
-0
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+223
-5
graphbolt/src/utils.h
graphbolt/src/utils.h
+23
-0
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+26
-6
No files found.
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
e42c7fcd
...
@@ -508,12 +508,28 @@ int64_t NumPick(
...
@@ -508,12 +508,28 @@ int64_t NumPick(
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
int64_t
num_neighbors
);
int64_t
num_neighbors
);
int64_t
TemporalNumPick
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indics
,
int64_t
fanout
,
bool
replace
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
);
int64_t
NumPickByEtype
(
int64_t
NumPickByEtype
(
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
int64_t
num_neighbors
);
int64_t
num_neighbors
);
int64_t
TemporalNumPickByEtype
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
);
/**
/**
* @brief Picks a specified number of neighbors for a node, starting from the
* @brief Picks a specified number of neighbors for a node, starting from the
* given offset and having the specified number of neighbors.
* given offset and having the specified number of neighbors.
...
@@ -562,6 +578,16 @@ int64_t Pick(
...
@@ -562,6 +578,16 @@ int64_t Pick(
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
SamplerType
::
LABOR
>
args
,
PickedType
*
picked_data_ptr
);
SamplerArgs
<
SamplerType
::
LABOR
>
args
,
PickedType
*
picked_data_ptr
);
template
<
typename
PickedType
>
int64_t
TemporalPick
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
PickedType
*
picked_data_ptr
);
/**
/**
* @brief Picks a specified number of neighbors for a node per edge type,
* @brief Picks a specified number of neighbors for a node per edge type,
* starting from the given offset and having the specified number of neighbors.
* starting from the given offset and having the specified number of neighbors.
...
@@ -597,6 +623,17 @@ int64_t PickByEtype(
...
@@ -597,6 +623,17 @@ int64_t PickByEtype(
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
PickedType
*
picked_data_ptr
);
PickedType
*
picked_data_ptr
);
template
<
typename
PickedType
>
int64_t
TemporalPickByEtype
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
PickedType
*
picked_data_ptr
);
template
<
template
<
bool
NonUniform
,
bool
Replace
,
typename
ProbsType
,
typename
PickedType
,
bool
NonUniform
,
bool
Replace
,
typename
ProbsType
,
typename
PickedType
,
int
StackSize
=
1024
>
int
StackSize
=
1024
>
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
e42c7fcd
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "./random.h"
#include "./random.h"
#include "./shared_memory_helper.h"
#include "./shared_memory_helper.h"
#include "./utils.h"
namespace
{
namespace
{
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>>
TensorizeDict
(
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>>
TensorizeDict
(
...
@@ -349,6 +350,31 @@ auto GetNumPickFn(
...
@@ -349,6 +350,31 @@ auto GetNumPickFn(
};
};
}
}
auto
GetTemporalNumPickFn
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
)
{
// If fanouts.size() > 1, returns the total number of all edge types of the
// given node.
return
[
&
seed_timestamp
,
&
csc_indices
,
&
fanouts
,
replace
,
&
probs_or_mask
,
&
type_per_edge
,
&
node_timestamp
,
&
edge_timestamp
](
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
)
{
if
(
fanouts
.
size
()
>
1
)
{
return
TemporalNumPickByEtype
(
seed_timestamp
,
csc_indices
,
fanouts
,
replace
,
type_per_edge
.
value
(),
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
seed_offset
,
offset
,
num_neighbors
);
}
else
{
return
TemporalNumPick
(
seed_timestamp
,
csc_indices
,
fanouts
[
0
],
replace
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
seed_offset
,
offset
,
num_neighbors
);
}
};
}
/**
/**
* @brief Get a lambda function which contains the sampling process.
* @brief Get a lambda function which contains the sampling process.
*
*
...
@@ -400,6 +426,39 @@ auto GetPickFn(
...
@@ -400,6 +426,39 @@ auto GetPickFn(
};
};
}
}
auto
GetTemporalPickFn
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
)
{
return
[
&
seed_timestamp
,
&
csc_indices
,
&
fanouts
,
replace
,
&
options
,
&
type_per_edge
,
&
probs_or_mask
,
&
node_timestamp
,
&
edge_timestamp
](
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
auto
picked_data_ptr
)
{
// If fanouts.size() > 1, perform sampling for each edge type of each
// node; otherwise just sample once for each node with no regard of edge
// types.
if
(
fanouts
.
size
()
>
1
)
{
return
TemporalPickByEtype
(
seed_timestamp
,
csc_indices
,
seed_offset
,
offset
,
num_neighbors
,
fanouts
,
replace
,
options
,
type_per_edge
.
value
(),
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
picked_data_ptr
);
}
else
{
int64_t
num_sampled
=
TemporalPick
(
seed_timestamp
,
csc_indices
,
seed_offset
,
offset
,
num_neighbors
,
fanouts
[
0
],
replace
,
options
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
picked_data_ptr
);
if
(
type_per_edge
)
{
std
::
sort
(
picked_data_ptr
,
picked_data_ptr
+
num_sampled
);
}
return
num_sampled
;
}
};
}
template
<
typename
NumPickFn
,
typename
PickFn
>
template
<
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighborsImpl
(
FusedCSCSamplingGraph
::
SampleNeighborsImpl
(
...
@@ -579,14 +638,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
...
@@ -579,14 +638,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
torch
::
optional
<
std
::
string
>
probs_name
,
torch
::
optional
<
std
::
string
>
probs_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
{
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
{
// TODO(zhenkun):
// 1. Get probs_or_mask.
// 1. Get probs_or_mask.
auto
probs_or_mask
=
this
->
EdgeAttribute
(
probs_name
);
if
(
probs_name
.
has_value
())
{
// Note probs will be passed as input for 'torch.multinomial' in deeper
// stack, which doesn't support 'torch.half' and 'torch.bool' data types. To
// avoid crashes, convert 'probs_or_mask' to 'float32' data type.
if
(
probs_or_mask
.
value
().
dtype
()
==
torch
::
kBool
||
probs_or_mask
.
value
().
dtype
()
==
torch
::
kFloat16
)
{
probs_or_mask
=
probs_or_mask
.
value
().
to
(
torch
::
kFloat32
);
}
}
// 2. Get the timestamp attribute for nodes of the graph
// 2. Get the timestamp attribute for nodes of the graph
auto
node_timestamp
=
this
->
NodeAttribute
(
node_timestamp_attr_name
);
// 3. Get the timestamp attribute for edges of the graph
// 3. Get the timestamp attribute for edges of the graph
// 4. GetTemporalNumPickFn (New implementation)
auto
edge_timestamp
=
this
->
EdgeAttribute
(
edge_timestamp_attr_name
);
// 5. GetTemporalPickFn (New implementation)
// 4. Call SampleNeighborsImpl
// 6. Call SampleNeighborsImpl (Old implementation)
return
SampleNeighborsImpl
(
return
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
();
input_nodes
,
return_eids
,
GetTemporalNumPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
),
GetTemporalPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
));
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
...
@@ -669,6 +745,43 @@ int64_t NumPick(
...
@@ -669,6 +745,43 @@ int64_t NumPick(
return
replace
?
fanout
:
std
::
min
(
fanout
,
num_valid_neighbors
);
return
replace
?
fanout
:
std
::
min
(
fanout
,
num_valid_neighbors
);
}
}
torch
::
Tensor
TemporalMask
(
int64_t
seed_timestamp
,
torch
::
Tensor
csc_indices
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
std
::
pair
<
int64_t
,
int64_t
>
edge_range
)
{
auto
[
l
,
r
]
=
edge_range
;
torch
::
Tensor
mask
=
torch
::
ones
({
r
-
l
},
torch
::
kBool
);
if
(
node_timestamp
.
has_value
())
{
auto
neighbor_timestamp
=
node_timestamp
.
value
().
index_select
(
0
,
csc_indices
.
slice
(
0
,
l
,
r
));
mask
&=
neighbor_timestamp
<=
seed_timestamp
;
}
if
(
edge_timestamp
.
has_value
())
{
mask
&=
edge_timestamp
.
value
().
slice
(
0
,
l
,
r
)
<=
seed_timestamp
;
}
if
(
probs_or_mask
.
has_value
())
{
mask
&=
probs_or_mask
.
value
().
slice
(
0
,
l
,
r
)
!=
0
;
}
return
mask
;
}
int64_t
TemporalNumPick
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indics
,
int64_t
fanout
,
bool
replace
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
)
{
auto
mask
=
TemporalMask
(
utils
::
GetValueByIndex
<
int64_t
>
(
seed_timestamp
,
seed_offset
),
csc_indics
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
{
offset
,
offset
+
num_neighbors
});
int64_t
num_valid_neighbors
=
utils
::
GetValueByIndex
<
int64_t
>
(
mask
.
sum
(),
0
);
if
(
num_valid_neighbors
==
0
||
fanout
==
-
1
)
return
num_valid_neighbors
;
return
replace
?
fanout
:
std
::
min
(
fanout
,
num_valid_neighbors
);
}
int64_t
NumPickByEtype
(
int64_t
NumPickByEtype
(
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
Tensor
&
type_per_edge
,
...
@@ -699,6 +812,40 @@ int64_t NumPickByEtype(
...
@@ -699,6 +812,40 @@ int64_t NumPickByEtype(
return
total_count
;
return
total_count
;
}
}
int64_t
TemporalNumPickByEtype
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
)
{
int64_t
etype_begin
=
offset
;
const
int64_t
end
=
offset
+
num_neighbors
;
int64_t
total_count
=
0
;
AT_DISPATCH_INTEGRAL_TYPES
(
type_per_edge
.
scalar_type
(),
"TemporalNumPickFnByEtype"
,
([
&
]
{
const
scalar_t
*
type_per_edge_data
=
type_per_edge
.
data_ptr
<
scalar_t
>
();
while
(
etype_begin
<
end
)
{
scalar_t
etype
=
type_per_edge_data
[
etype_begin
];
TORCH_CHECK
(
etype
>=
0
&&
etype
<
(
int64_t
)
fanouts
.
size
(),
"Etype values exceed the number of fanouts."
);
auto
etype_end_it
=
std
::
upper_bound
(
type_per_edge_data
+
etype_begin
,
type_per_edge_data
+
end
,
etype
);
int64_t
etype_end
=
etype_end_it
-
type_per_edge_data
;
// Do sampling for one etype.
total_count
+=
TemporalNumPick
(
seed_timestamp
,
csc_indices
,
fanouts
[
etype
],
replace
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
seed_offset
,
etype_begin
,
etype_end
-
etype_begin
);
etype_begin
=
etype_end
;
}
}));
return
total_count
;
}
/**
/**
* @brief Perform uniform sampling of elements and return the sampled indices.
* @brief Perform uniform sampling of elements and return the sampled indices.
*
*
...
@@ -983,6 +1130,35 @@ int64_t Pick(
...
@@ -983,6 +1130,35 @@ int64_t Pick(
}
}
}
}
template
<
typename
PickedType
>
int64_t
TemporalPick
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
PickedType
*
picked_data_ptr
)
{
auto
mask
=
TemporalMask
(
utils
::
GetValueByIndex
<
int64_t
>
(
seed_timestamp
,
seed_offset
),
csc_indices
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
{
offset
,
offset
+
num_neighbors
});
torch
::
Tensor
masked_prob
;
if
(
probs_or_mask
.
has_value
())
{
masked_prob
=
probs_or_mask
.
value
().
slice
(
0
,
offset
,
offset
+
num_neighbors
)
*
mask
;
}
else
{
masked_prob
=
mask
.
to
(
torch
::
kFloat32
);
}
auto
picked_indices
=
NonUniformPickOp
(
masked_prob
,
fanout
,
replace
);
auto
picked_indices_ptr
=
picked_indices
.
data_ptr
<
int64_t
>
();
for
(
int
i
=
0
;
i
<
picked_indices
.
numel
();
++
i
)
{
picked_data_ptr
[
i
]
=
static_cast
<
PickedType
>
(
picked_indices_ptr
[
i
])
+
offset
;
}
return
picked_indices
.
numel
();
}
template
<
SamplerType
S
,
typename
PickedType
>
template
<
SamplerType
S
,
typename
PickedType
>
int64_t
PickByEtype
(
int64_t
PickByEtype
(
int64_t
offset
,
int64_t
num_neighbors
,
const
std
::
vector
<
int64_t
>&
fanouts
,
int64_t
offset
,
int64_t
num_neighbors
,
const
std
::
vector
<
int64_t
>&
fanouts
,
...
@@ -1020,6 +1196,48 @@ int64_t PickByEtype(
...
@@ -1020,6 +1196,48 @@ int64_t PickByEtype(
return
pick_offset
;
return
pick_offset
;
}
}
template
<
typename
PickedType
>
int64_t
TemporalPickByEtype
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_timestamp
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
PickedType
*
picked_data_ptr
)
{
int64_t
etype_begin
=
offset
;
int64_t
etype_end
=
offset
;
int64_t
pick_offset
=
0
;
AT_DISPATCH_INTEGRAL_TYPES
(
type_per_edge
.
scalar_type
(),
"TemporalPickByEtype"
,
([
&
]
{
const
scalar_t
*
type_per_edge_data
=
type_per_edge
.
data_ptr
<
scalar_t
>
();
const
auto
end
=
offset
+
num_neighbors
;
while
(
etype_begin
<
end
)
{
scalar_t
etype
=
type_per_edge_data
[
etype_begin
];
TORCH_CHECK
(
etype
>=
0
&&
etype
<
(
int64_t
)
fanouts
.
size
(),
"Etype values exceed the number of fanouts."
);
int64_t
fanout
=
fanouts
[
etype
];
auto
etype_end_it
=
std
::
upper_bound
(
type_per_edge_data
+
etype_begin
,
type_per_edge_data
+
end
,
etype
);
etype_end
=
etype_end_it
-
type_per_edge_data
;
// Do sampling for one etype.
if
(
fanout
!=
0
)
{
int64_t
picked_count
=
TemporalPick
(
seed_timestamp
,
csc_indices
,
seed_offset
,
etype_begin
,
etype_end
-
etype_begin
,
fanout
,
replace
,
options
,
probs_or_mask
,
node_timestamp
,
edge_timestamp
,
picked_data_ptr
+
pick_offset
);
pick_offset
+=
picked_count
;
}
etype_begin
=
etype_end
;
}
}));
return
pick_offset
;
}
template
<
typename
PickedType
>
template
<
typename
PickedType
>
int64_t
Pick
(
int64_t
Pick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
...
...
graphbolt/src/utils.h
View file @
e42c7fcd
...
@@ -19,6 +19,29 @@ inline bool is_accessible_from_gpu(torch::Tensor tensor) {
...
@@ -19,6 +19,29 @@ inline bool is_accessible_from_gpu(torch::Tensor tensor) {
return
tensor
.
is_pinned
()
||
tensor
.
device
().
type
()
==
c10
::
DeviceType
::
CUDA
;
return
tensor
.
is_pinned
()
||
tensor
.
device
().
type
()
==
c10
::
DeviceType
::
CUDA
;
}
}
/**
* @brief Retrieves the value of the tensor at the given index.
*
* @note If the tensor is not contiguous, it will be copied to a contiguous
* tensor.
*
* @tparam T The type of the tensor.
* @param tensor The tensor.
* @param index The index.
*
* @return T The value of the tensor at the given index.
*/
template
<
typename
T
>
T
GetValueByIndex
(
const
torch
::
Tensor
&
tensor
,
int64_t
index
)
{
TORCH_CHECK
(
index
>=
0
&&
index
<
tensor
.
numel
(),
"The index should be within the range of the tensor, but got index "
,
index
,
" and tensor size "
,
tensor
.
numel
());
auto
contiguous_tensor
=
tensor
.
contiguous
();
auto
data_ptr
=
contiguous_tensor
.
data_ptr
<
T
>
();
return
data_ptr
[
index
];
}
}
// namespace utils
}
// namespace utils
}
// namespace graphbolt
}
// namespace graphbolt
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
e42c7fcd
...
@@ -439,11 +439,18 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -439,11 +439,18 @@ class FusedCSCSamplingGraph(SamplingGraph):
node_pairs
=
node_pairs
,
original_edge_ids
=
original_edge_ids
node_pairs
=
node_pairs
,
original_edge_ids
=
original_edge_ids
)
)
def
_convert_to_homogeneous_nodes
(
self
,
nodes
):
def
_convert_to_homogeneous_nodes
(
self
,
nodes
,
timestamps
=
None
):
homogeneous_nodes
=
[]
homogeneous_nodes
=
[]
homogeneous_timestamps
=
[]
for
ntype
,
ids
in
nodes
.
items
():
for
ntype
,
ids
in
nodes
.
items
():
ntype_id
=
self
.
node_type_to_id
[
ntype
]
ntype_id
=
self
.
node_type_to_id
[
ntype
]
homogeneous_nodes
.
append
(
ids
+
self
.
node_type_offset
[
ntype_id
])
homogeneous_nodes
.
append
(
ids
+
self
.
node_type_offset
[
ntype_id
])
if
timestamps
is
not
None
:
homogeneous_timestamps
.
append
(
timestamps
[
ntype
])
if
timestamps
is
not
None
:
return
torch
.
cat
(
homogeneous_nodes
),
torch
.
cat
(
homogeneous_timestamps
)
return
torch
.
cat
(
homogeneous_nodes
)
return
torch
.
cat
(
homogeneous_nodes
)
def
_convert_to_sampled_subgraph
(
def
_convert_to_sampled_subgraph
(
...
@@ -830,7 +837,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -830,7 +837,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
else
:
else
:
return
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
)
return
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
)
def
_
temporal_sample_neighbors
(
def
temporal_sample_neighbors
(
self
,
self
,
nodes
:
torch
.
Tensor
,
nodes
:
torch
.
Tensor
,
input_nodes_timestamp
:
torch
.
Tensor
,
input_nodes_timestamp
:
torch
.
Tensor
,
...
@@ -887,26 +894,39 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -887,26 +894,39 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
Returns
-------
-------
torch.classes.graphbolt.
SampledSubgraph
Fused
SampledSubgraph
Impl
The sampled
C
subgraph.
The sampled subgraph.
"""
"""
if
isinstance
(
nodes
,
dict
):
nodes
,
input_nodes_timestamp
=
self
.
_convert_to_homogeneous_nodes
(
nodes
,
input_nodes_timestamp
)
# Ensure nodes is 1-D tensor.
# Ensure nodes is 1-D tensor.
self
.
_check_sampler_arguments
(
nodes
,
fanouts
,
probs_name
)
self
.
_check_sampler_arguments
(
nodes
,
fanouts
,
probs_name
)
has_original_eids
=
(
has_original_eids
=
(
self
.
edge_attributes
is
not
None
self
.
edge_attributes
is
not
None
and
ORIGINAL_EDGE_ID
in
self
.
edge_attributes
and
ORIGINAL_EDGE_ID
in
self
.
edge_attributes
)
)
return
self
.
_c_csc_graph
.
temporal_sample_neighbors
(
C_sampled_subgraph
=
self
.
_c_csc_graph
.
temporal_sample_neighbors
(
nodes
,
nodes
,
input_nodes_timestamp
,
input_nodes_timestamp
,
fanouts
.
tolist
(),
fanouts
.
tolist
(),
replace
,
replace
,
False
,
has_original_eids
,
has_original_eids
,
probs_name
,
probs_name
,
node_timestamp_attr_name
,
node_timestamp_attr_name
,
edge_timestamp_attr_name
,
edge_timestamp_attr_name
,
)
)
# Broadcast the input nodes' timestamp to the sampled neighbors.
sampled_count
=
torch
.
diff
(
C_sampled_subgraph
.
indptr
)
neighbors_timestamp
=
input_nodes_timestamp
.
repeat_interleave
(
sampled_count
)
return
(
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
),
neighbors_timestamp
,
)
def
sample_negative_edges_uniform
(
def
sample_negative_edges_uniform
(
self
,
edge_type
,
node_pairs
,
negative_ratio
self
,
edge_type
,
node_pairs
,
negative_ratio
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment