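"""Tests for DGL's CUDA stream support: launching kernels on non-default
streams, getting and setting the current stream, and record_stream
semantics for DGL NDArrays and graphs."""
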
from statistics import mean
import unittest
import numpy as np
import torch
import dgl
import dgl.ndarray as nd

from dgl import rand_graph
import dgl.ops as OPS

from dgl._ffi.streams import to_dgl_stream_handle, _dgl_get_stream
from dgl.utils import to_dgl_context

import backend as F


# borrowed from PyTorch, torch/testing/_internal/common_utils.py
def _get_cycles_per_ms() -> float:
    """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
    """

    def measure() -> float:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.cuda._sleep(1000000)
        end.record()
        end.synchronize()
        cycles_per_ms = 1000000 / start.elapsed_time(end)
        return cycles_per_ms

    # Take 10 measurements, drop the 2 largest and 2 smallest, and return
    # the mean. This avoids system disturbances that skew the results, e.g.
    # the very first CUDA call likely does a bunch of init, which takes
    # much longer than subsequent calls.
    num = 10
    vals = []
    for _ in range(num):
        vals.append(measure())
    vals = sorted(vals)
    return mean(vals[2 : num - 2])
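
# _get_cycles_per_ms is combined with torch.cuda._sleep below to keep the GPU
# busy for a roughly known wall-clock time, e.g. about 50 ms:
#
#     torch.cuda._sleep(int(50 * _get_cycles_per_ms()))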

@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
def test_basics():
    g = rand_graph(10, 20, device=F.cpu())
    x = torch.ones(g.num_nodes(), 10)
    # reference result, computed on CPU
    result = OPS.copy_u_sum(g, x).to(F.ctx())

    # launch on default stream used in DGL
    xx = x.to(device=F.ctx())
    gg = g.to(device=F.ctx())
    OPS.copy_u_sum(gg, xx)
    assert torch.equal(OPS.copy_u_sum(gg, xx), result)

    # launch on new stream created via torch.cuda
    s = torch.cuda.Stream(device=F.ctx())
    with torch.cuda.stream(s):
        xx = x.to(device=F.ctx(), non_blocking=True)
        gg = g.to(device=F.ctx())
        OPS.copy_u_sum(gg, xx)
    s.synchronize()
    assert torch.equal(OPS.copy_u_sum(gg, xx), result)
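    # Alternative (a sketch, not what this test exercises): rather than
    # blocking the host with s.synchronize(), a cross-stream dependency
    # would also order the final op correctly:
    #
    #     torch.cuda.current_stream().wait_stream(s)
    #     assert torch.equal(OPS.copy_u_sum(gg, xx), result)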

@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
def test_set_get_stream():
    current_stream = torch.cuda.current_stream()
    # test setting another stream
    s = torch.cuda.Stream(device=F.ctx())
    torch.cuda.set_stream(s)
    assert to_dgl_stream_handle(s).value == _dgl_get_stream(to_dgl_context(F.ctx())).value
    # revert to default stream
    torch.cuda.set_stream(current_stream)
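
# DGL follows PyTorch's notion of the current stream, so work can be scoped
# with torch.cuda.stream (as in test_basics above) or redirected globally via
# torch.cuda.set_stream as done here; e.g. (sketch):
#
#     s = torch.cuda.Stream(device=F.ctx())
#     with torch.cuda.stream(s):
#         OPS.copy_u_sum(gg, xx)  # launched on s rather than the default stream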

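# record_stream marks a buffer as in use on the given stream, so the caching
# allocator will not hand its memory to a new allocation until the work queued
# on that stream (at record time) has finished. The tests below exercise this
# for DGL NDArrays and graphs; omitting the call lets a later allocation reuse
# the buffer too early (see the negative test at the end).
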
# borrowed from PyTorch, test/test_cuda.py: test_record_stream()
@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
def test_record_stream_ndarray():
    cycles_per_ms = _get_cycles_per_ms()

    t = nd.array(np.array([1., 2., 3., 4.], dtype=np.float32), ctx=nd.cpu())
    t.pin_memory_()
    result = nd.empty([4], ctx=nd.gpu(0))
    stream = torch.cuda.Stream()
    ptr = [None]  # mutable cell so the nested function can report the data pointer

    # Performs the CPU->GPU copy in a background stream
    def perform_copy():
        with torch.cuda.stream(stream):
            tmp = t.copyto(nd.gpu(0))
            ptr[0] = F.from_dgl_nd(tmp).data_ptr()
        torch.cuda.current_stream().wait_stream(stream)
        tmp.record_stream(
            to_dgl_stream_handle(torch.cuda.current_stream()))
        torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the copy
        result.copyfrom(tmp)

    perform_copy()
    with torch.cuda.stream(stream):
        tmp2 = nd.empty([4], ctx=nd.gpu(0))
        assert F.from_dgl_nd(tmp2).data_ptr() != ptr[0], 'allocation re-used too soon'

    assert torch.equal(F.from_dgl_nd(result).cpu(), torch.tensor([1., 2., 3., 4.]))

    # Check that the block will be re-used after the main stream finishes
    torch.cuda.current_stream().synchronize()
    with torch.cuda.stream(stream):
        tmp3 = nd.empty([4], ctx=nd.gpu(0))
        assert F.from_dgl_nd(tmp3).data_ptr() == ptr[0], 'allocation not re-used'

@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
def test_record_stream_graph_positive():
    cycles_per_ms = _get_cycles_per_ms()

    g = rand_graph(10, 20, device=F.cpu())
    x = torch.ones(g.num_nodes(), 10)
    result = OPS.copy_u_sum(g, x).to(F.ctx())

    stream = torch.cuda.Stream()
    results2 = torch.zeros_like(result)
    # Performs the computation in a background stream
    def perform_computing():
        with torch.cuda.stream(stream):
            g2 = g.to(F.ctx())
        torch.cuda.current_stream().wait_stream(stream)
        g2.record_stream(torch.cuda.current_stream())
        torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the computation
        results2.copy_(OPS.copy_u_sum(g2, x))

    x = x.to(F.ctx())
    perform_computing()
    with torch.cuda.stream(stream):
        # since we have called record_stream on g2, g3 won't reuse its memory
        g3 = rand_graph(10, 20, device=F.ctx())
    torch.cuda.current_stream().synchronize()
    assert torch.equal(result, results2)

@unittest.skipIf(F._default_context_str == 'cpu', reason="stream only runs on GPU.")
def test_record_stream_graph_negative():
    cycles_per_ms = _get_cycles_per_ms()

    g = rand_graph(10, 20, device=F.cpu())
    x = torch.ones(g.num_nodes(), 10)
    result = OPS.copy_u_sum(g, x).to(F.ctx())

    stream = torch.cuda.Stream()
    results2 = torch.zeros_like(result)
    # Performs the computation in a background stream
    def perform_computing():
        with torch.cuda.stream(stream):
            g2 = g.to(F.ctx())
        torch.cuda.current_stream().wait_stream(stream)
        # omitting record_stream here produces a wrong result
        # g2.record_stream(torch.cuda.current_stream())
        torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the computation
        results2.copy_(OPS.copy_u_sum(g2, x))

    x = x.to(F.ctx())
    perform_computing()
    with torch.cuda.stream(stream):
        # g3 will reuse g2's memory block, resulting in a wrong result
        g3 = rand_graph(10, 20, device=F.ctx())
    torch.cuda.current_stream().synchronize()
    assert not torch.equal(result, results2)

if __name__ == '__main__':
    test_basics()
    test_set_get_stream()
    test_record_stream_ndarray()
    test_record_stream_graph_positive()
    test_record_stream_graph_negative()