test_stream.py 1.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from dgl import rand_graph
import dgl._ffi.streams as FS
import dgl.ops as OPS
import unittest
import backend as F
import torch


@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
def test_basics():
    """Check that graph/feature copies issued on CUDA streams give correct results.

    A random CPU graph and an all-ones feature tensor are moved to the test
    device in three ways: under torch's default stream, under a newly created
    stream, and without selecting any stream. After each transfer
    ``copy_u_sum`` is run on the device copies and its output is compared with
    the output computed from the CPU originals, so a transfer that is not
    properly ordered with the computation shows up as a failed assertion
    rather than passing silently.
    """
    g = rand_graph(10, 20, device=F.cpu())
    x = torch.ones(g.num_nodes(), 10)

    # Reference result from the CPU originals, moved to the test device so it
    # can be compared directly. The features are all ones, so the sums are
    # exactly representable and an exact comparison is safe.
    expected = OPS.copy_u_sum(g, x).to(device=F.ctx())

    # launch on default stream fetched via torch.cuda
    s = torch.cuda.default_stream(device=F.ctx())
    with torch.cuda.stream(s):
        xx = x.to(device=F.ctx(), non_blocking=True)
    # NOTE(review): FS.stream is presumably DGL's counterpart of
    # torch.cuda.stream (selects `s` for DGL's own work) -- confirm.
    with FS.stream(s):
        gg = g.to(device=F.ctx())
    # Wait for the copies queued on `s` before using them from outside it.
    s.synchronize()
    assert torch.equal(OPS.copy_u_sum(gg, xx), expected)

    # launch on new stream created via torch.cuda
    s = torch.cuda.Stream(device=F.ctx())
    with torch.cuda.stream(s):
        xx = x.to(device=F.ctx(), non_blocking=True)
    with FS.stream(s):
        gg = g.to(device=F.ctx())
    # The computation below is issued outside `s`, so the copies must have
    # completed first.
    s.synchronize()
    assert torch.equal(OPS.copy_u_sum(gg, xx), expected)

    # launch on default stream used in DGL
    xx = x.to(device=F.ctx())
    gg = g.to(device=F.ctx())
    assert torch.equal(OPS.copy_u_sum(gg, xx), expected)

# Script entry point: run the stream test directly when this file is executed.
if __name__ == "__main__":
    test_basics()