# test_time_range.py 3.44 KB
# Newer Older
# root's avatar
# root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import unittest
from unittest import mock

from cupy import cuda
from cupyx import profiler


@unittest.skipUnless(cuda.nvtx.available, 'nvtx is required for time_range')
class TestTimeRange(unittest.TestCase):
    """Check the NVTX calls issued by ``profiler.time_range``.

    The NVTX entry points are replaced with mocks, so every test only
    inspects the arguments handed to ``RangePush``/``RangePushC`` and
    ``RangePop``.
    """

    class _BodyError(Exception):
        """Raised from inside a profiled body by the error-path tests."""

    def _patch_nvtx(self, push_name):
        """Mock ``cupy.cuda.nvtx.<push_name>`` and ``cupy.cuda.nvtx.RangePop``.

        Both patches are undone automatically when the test finishes.

        Args:
            push_name (str): Name of the push function to mock, either
                ``'RangePush'`` or ``'RangePushC'``.

        Returns:
            tuple: ``(push_mock, pop_mock)``.
        """
        mocks = []
        for name in (push_name, 'RangePop'):
            patcher = mock.patch('cupy.cuda.nvtx.' + name)
            mocks.append(patcher.start())
            self.addCleanup(patcher.stop)
        return tuple(mocks)

    def test_time_range(self):
        """Context manager with an explicit ``color_id``."""
        push, pop = self._patch_nvtx('RangePush')
        with profiler.time_range('test:time_range', color_id=-1):
            pass
        push.assert_called_once_with('test:time_range', -1)
        pop.assert_called_once_with()

    def test_time_range_with_ARGB(self):
        """Context manager with ``argb_color`` goes through ``RangePushC``."""
        push, pop = self._patch_nvtx('RangePushC')
        with profiler.time_range('test:time_range_with_ARGB',
                                 argb_color=0xFF00FF00):
            pass
        push.assert_called_once_with(
            'test:time_range_with_ARGB', 0xFF00FF00)
        pop.assert_called_once_with()

    def test_time_range_err(self):
        """The range is still popped when the ``with`` body raises."""
        push, pop = self._patch_nvtx('RangePush')
        # assertRaises (rather than a blanket ``except Exception: pass``)
        # also verifies that the error propagates out of the context manager.
        with self.assertRaises(self._BodyError):
            with profiler.time_range('test:time_range_error', -1):
                raise self._BodyError()
        push.assert_called_once_with('test:time_range_error', -1)
        pop.assert_called_once_with()

    def test_time_range_as_decorator(self):
        """Decorator form uses the function name as the range message."""
        push, pop = self._patch_nvtx('RangePush')

        @profiler.time_range()
        def f():
            pass

        f()
        # Default value of color id is -1
        push.assert_called_once_with('f', -1)
        pop.assert_called_once_with()

    def test_time_range_as_decorator_with_ARGB(self):
        """Decorator form with ``argb_color`` goes through ``RangePushC``."""
        push, pop = self._patch_nvtx('RangePushC')

        @profiler.time_range(argb_color=0xFFFF0000)
        def f():
            pass

        f()
        push.assert_called_once_with('f', 0xFFFF0000)
        pop.assert_called_once_with()

    def test_time_range_as_decorator_err(self):
        """The range is still popped when the decorated function raises."""
        push, pop = self._patch_nvtx('RangePush')

        @profiler.time_range()
        def f():
            raise self._BodyError()

        # The error must reach the caller; a swallowed exception fails here.
        with self.assertRaises(self._BodyError):
            f()
        # Default value of color id is -1
        push.assert_called_once_with('f', -1)
        pop.assert_called_once_with()


class TestTimeRangeNVTXUnavailable(unittest.TestCase):
    """``time_range`` behaviour while ``cuda.nvtx.available`` is ``False``."""

    def setUp(self):
        # Force the flag off for the duration of one test; the previous
        # value is put back when the patch is stopped during cleanup.
        flag_patch = mock.patch.object(cuda.nvtx, 'available', False)
        flag_patch.start()
        self.addCleanup(flag_patch.stop)

    def test_time_range(self):
        # Context-manager form.
        with self.assertRaises(RuntimeError), profiler.time_range(''):
            pass

    def test_time_range_decorator(self):
        # Decorator-factory form: the bare call itself is expected to raise.
        self.assertRaises(RuntimeError, profiler.time_range)