"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bf16a97018fcb351b552043c89cb0152317ac3f9"
Unverified Commit 2e6ded06 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Handle edge case for fetch feature overlap (#6979)

parent f7e065f7
...@@ -92,6 +92,8 @@ class FeatureFetcher(MiniBatchTransformer): ...@@ -92,6 +92,8 @@ class FeatureFetcher(MiniBatchTransformer):
nodes = input_nodes[type_name] nodes = input_nodes[type_name]
if nodes is None: if nodes is None:
continue continue
if nodes.is_cuda:
nodes.record_stream(torch.cuda.current_stream())
for feature_name in feature_names: for feature_name in feature_names:
node_features[ node_features[
(type_name, feature_name) (type_name, feature_name)
...@@ -104,6 +106,8 @@ class FeatureFetcher(MiniBatchTransformer): ...@@ -104,6 +106,8 @@ class FeatureFetcher(MiniBatchTransformer):
) )
) )
else: else:
if input_nodes.is_cuda:
input_nodes.record_stream(torch.cuda.current_stream())
for feature_name in self.node_feature_keys: for feature_name in self.node_feature_keys:
node_features[feature_name] = record_stream( node_features[feature_name] = record_stream(
self.feature_store.read( self.feature_store.read(
...@@ -134,6 +138,8 @@ class FeatureFetcher(MiniBatchTransformer): ...@@ -134,6 +138,8 @@ class FeatureFetcher(MiniBatchTransformer):
edges = original_edge_ids.get(type_name, None) edges = original_edge_ids.get(type_name, None)
if edges is None: if edges is None:
continue continue
if edges.is_cuda:
edges.record_stream(torch.cuda.current_stream())
for feature_name in feature_names: for feature_name in feature_names:
edge_features[i][ edge_features[i][
(type_name, feature_name) (type_name, feature_name)
...@@ -143,6 +149,10 @@ class FeatureFetcher(MiniBatchTransformer): ...@@ -143,6 +149,10 @@ class FeatureFetcher(MiniBatchTransformer):
) )
) )
else: else:
if original_edge_ids.is_cuda:
original_edge_ids.record_stream(
torch.cuda.current_stream()
)
for feature_name in self.edge_feature_keys: for feature_name in self.edge_feature_keys:
edge_features[i][feature_name] = record_stream( edge_features[i][feature_name] = record_stream(
self.feature_store.read( self.feature_store.read(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment