Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2e6ded06
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bf16a97018fcb351b552043c89cb0152317ac3f9"
Unverified
Commit
2e6ded06
authored
Jan 18, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Jan 18, 2024
Browse files
[GraphBolt][CUDA] Handle edge case for fetch feature overlap (#6979)
parent
f7e065f7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
0 deletions
+10
-0
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+10
-0
No files found.
python/dgl/graphbolt/feature_fetcher.py
View file @
2e6ded06
...
@@ -92,6 +92,8 @@ class FeatureFetcher(MiniBatchTransformer):
...
@@ -92,6 +92,8 @@ class FeatureFetcher(MiniBatchTransformer):
nodes
=
input_nodes
[
type_name
]
nodes
=
input_nodes
[
type_name
]
if
nodes
is
None
:
if
nodes
is
None
:
continue
continue
if
nodes
.
is_cuda
:
nodes
.
record_stream
(
torch
.
cuda
.
current_stream
())
for
feature_name
in
feature_names
:
for
feature_name
in
feature_names
:
node_features
[
node_features
[
(
type_name
,
feature_name
)
(
type_name
,
feature_name
)
...
@@ -104,6 +106,8 @@ class FeatureFetcher(MiniBatchTransformer):
...
@@ -104,6 +106,8 @@ class FeatureFetcher(MiniBatchTransformer):
)
)
)
)
else
:
else
:
if
input_nodes
.
is_cuda
:
input_nodes
.
record_stream
(
torch
.
cuda
.
current_stream
())
for
feature_name
in
self
.
node_feature_keys
:
for
feature_name
in
self
.
node_feature_keys
:
node_features
[
feature_name
]
=
record_stream
(
node_features
[
feature_name
]
=
record_stream
(
self
.
feature_store
.
read
(
self
.
feature_store
.
read
(
...
@@ -134,6 +138,8 @@ class FeatureFetcher(MiniBatchTransformer):
...
@@ -134,6 +138,8 @@ class FeatureFetcher(MiniBatchTransformer):
edges
=
original_edge_ids
.
get
(
type_name
,
None
)
edges
=
original_edge_ids
.
get
(
type_name
,
None
)
if
edges
is
None
:
if
edges
is
None
:
continue
continue
if
edges
.
is_cuda
:
edges
.
record_stream
(
torch
.
cuda
.
current_stream
())
for
feature_name
in
feature_names
:
for
feature_name
in
feature_names
:
edge_features
[
i
][
edge_features
[
i
][
(
type_name
,
feature_name
)
(
type_name
,
feature_name
)
...
@@ -143,6 +149,10 @@ class FeatureFetcher(MiniBatchTransformer):
...
@@ -143,6 +149,10 @@ class FeatureFetcher(MiniBatchTransformer):
)
)
)
)
else
:
else
:
if
original_edge_ids
.
is_cuda
:
original_edge_ids
.
record_stream
(
torch
.
cuda
.
current_stream
()
)
for
feature_name
in
self
.
edge_feature_keys
:
for
feature_name
in
self
.
edge_feature_keys
:
edge_features
[
i
][
feature_name
]
=
record_stream
(
edge_features
[
i
][
feature_name
]
=
record_stream
(
self
.
feature_store
.
read
(
self
.
feature_store
.
read
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment