Unverified Commit 4b265390 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Fix link prediction early-stop. (#7083)

parent f3af2a9f
......@@ -332,6 +332,8 @@ def train(args, model, graph, features, train_set):
total_loss += loss.item()
if step + 1 == args.early_stop:
# Early stopping requires a new dataloader to reset its state.
dataloader = create_dataloader(args, graph, features, train_set)
break
end_epoch_time = time.time()
......
......@@ -174,10 +174,5 @@ class FeatureFetcher(MiniBatchTransformer):
with torch.cuda.stream(self.stream):
data = self._read_data(data, current_stream)
if self.stream is not None:
event = torch.cuda.current_stream().record_event()
def _wait():
event.wait()
data.wait = _wait
data.wait = torch.cuda.current_stream().record_event().wait
return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment