Unverified Commit 2e6ded06 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Handle edge case for fetch feature overlap (#6979)

parent f7e065f7
......@@ -92,6 +92,8 @@ class FeatureFetcher(MiniBatchTransformer):
nodes = input_nodes[type_name]
if nodes is None:
continue
if nodes.is_cuda:
nodes.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
node_features[
(type_name, feature_name)
......@@ -104,6 +106,8 @@ class FeatureFetcher(MiniBatchTransformer):
)
)
else:
if input_nodes.is_cuda:
input_nodes.record_stream(torch.cuda.current_stream())
for feature_name in self.node_feature_keys:
node_features[feature_name] = record_stream(
self.feature_store.read(
......@@ -134,6 +138,8 @@ class FeatureFetcher(MiniBatchTransformer):
edges = original_edge_ids.get(type_name, None)
if edges is None:
continue
if edges.is_cuda:
edges.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
edge_features[i][
(type_name, feature_name)
......@@ -143,6 +149,10 @@ class FeatureFetcher(MiniBatchTransformer):
)
)
else:
if original_edge_ids.is_cuda:
original_edge_ids.record_stream(
torch.cuda.current_stream()
)
for feature_name in self.edge_feature_keys:
edge_features[i][feature_name] = record_stream(
self.feature_store.read(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment