Unverified Commit 188bc2bf authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Fix][Unittest] Fix test stream (#4635)

* fix test stream

* init cusparse handle
parent d78a3a4b
...@@ -107,8 +107,12 @@ def test_record_stream_graph_positive(): ...@@ -107,8 +107,12 @@ def test_record_stream_graph_positive():
cycles_per_ms = _get_cycles_per_ms() cycles_per_ms = _get_cycles_per_ms()
g = rand_graph(10, 20, device=F.cpu()) g = rand_graph(10, 20, device=F.cpu())
x = torch.ones(g.num_nodes(), 10) g.create_formats_()
result = OPS.copy_u_sum(g, x).to(F.ctx()) x = torch.ones(g.num_nodes(), 10).to(F.ctx())
g1 = g.to(F.ctx())
# this is necessary to initialize the cusparse handle
result = OPS.copy_u_sum(g1, x)
torch.cuda.current_stream().synchronize()
stream = torch.cuda.Stream() stream = torch.cuda.Stream()
results2 = torch.zeros_like(result) results2 = torch.zeros_like(result)
...@@ -121,11 +125,11 @@ def test_record_stream_graph_positive(): ...@@ -121,11 +125,11 @@ def test_record_stream_graph_positive():
torch.cuda._sleep(int(50 * cycles_per_ms)) # delay the computing torch.cuda._sleep(int(50 * cycles_per_ms)) # delay the computing
results2.copy_(OPS.copy_u_sum(g2, x)) results2.copy_(OPS.copy_u_sum(g2, x))
x = x.to(F.ctx())
perform_computing() perform_computing()
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
# since we have called record stream for g2, g3 won't reuse its memory # since we have called record stream for g2, g3 won't reuse its memory
g3 = rand_graph(10, 20, device=F.ctx()) g3 = rand_graph(10, 20, device=F.ctx())
g3.create_formats_()
torch.cuda.current_stream().synchronize() torch.cuda.current_stream().synchronize()
assert torch.equal(result, results2) assert torch.equal(result, results2)
...@@ -134,8 +138,12 @@ def test_record_stream_graph_negative(): ...@@ -134,8 +138,12 @@ def test_record_stream_graph_negative():
cycles_per_ms = _get_cycles_per_ms() cycles_per_ms = _get_cycles_per_ms()
g = rand_graph(10, 20, device=F.cpu()) g = rand_graph(10, 20, device=F.cpu())
x = torch.ones(g.num_nodes(), 10) g.create_formats_()
result = OPS.copy_u_sum(g, x).to(F.ctx()) x = torch.ones(g.num_nodes(), 10).to(F.ctx())
g1 = g.to(F.ctx())
# this is necessary to initialize the cusparse handle
result = OPS.copy_u_sum(g1, x)
torch.cuda.current_stream().synchronize()
stream = torch.cuda.Stream() stream = torch.cuda.Stream()
results2 = torch.zeros_like(result) results2 = torch.zeros_like(result)
...@@ -149,11 +157,11 @@ def test_record_stream_graph_negative(): ...@@ -149,11 +157,11 @@ def test_record_stream_graph_negative():
torch.cuda._sleep(int(50 * cycles_per_ms)) # delay the computing torch.cuda._sleep(int(50 * cycles_per_ms)) # delay the computing
results2.copy_(OPS.copy_u_sum(g2, x)) results2.copy_(OPS.copy_u_sum(g2, x))
x = x.to(F.ctx())
perform_computing() perform_computing()
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
# g3 will reuse g2's memory block, resulting a wrong result # g3 will reuse g2's memory block, resulting a wrong result
g3 = rand_graph(10, 20, device=F.ctx()) g3 = rand_graph(10, 20, device=F.ctx())
g3.create_formats_()
torch.cuda.current_stream().synchronize() torch.cuda.current_stream().synchronize()
assert not torch.equal(result, results2) assert not torch.equal(result, results2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment