"docs/vscode:/vscode.git/clone" did not exist on "97390468c7104c9b3255050ce22ea382f59fba5e"
test_run.py 3.8 KB
Newer Older
Elton Zheng's avatar
Elton Zheng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import pytest

from deepspeed.pt import deepspeed_run as dsrun


def test_parser_mutual_exclusive():
    '''Check that dsrun.parse_resource_filter() rejects a call that supplies
    both include_str and exclude_str by raising a ValueError.
    '''
    empty_pool = {}
    with pytest.raises(ValueError):
        dsrun.parse_resource_filter(empty_pool, include_str='A', exclude_str='B')


def test_parser_local():
    ''' Single-node scenarios for dsrun.parse_resource_filter().

    Every scenario pairs the filter keyword arguments with the resource
    mapping the parser is expected to return. The scenarios run in order
    against the same ``hosts`` mapping.
    '''
    hosts = {'worker-0': [0, 1, 2, 3]}

    scenarios = [
        # no include/exclude: the pool is returned as given
        ({}, hosts),
        # exclude slots
        ({'exclude_str': 'worker-0:1'}, {'worker-0': [0, 2, 3]}),
        ({'exclude_str': 'worker-0:1,2'}, {'worker-0': [0, 3]}),
        # only use one slot
        ({'include_str': 'worker-0:1'}, {'worker-0': [1]}),
        # including slots multiple times shouldn't break things
        ({'include_str': 'worker-0:1,1'}, {'worker-0': [1]}),
        ({'include_str': 'worker-0:1@worker-0:0,1'}, {'worker-0': [0, 1]}),
        # including just 'worker-0' without : should still use all GPUs
        ({'include_str': 'worker-0'}, hosts),
        # excluding just 'worker-0' without : should eliminate everything
        ({'exclude_str': 'worker-0'}, {}),
        # exclude all slots manually
        ({'exclude_str': 'worker-0:0,1,2,3'}, {}),
    ]

    for filter_kwargs, expected in scenarios:
        assert dsrun.parse_resource_filter(hosts, **filter_kwargs) == expected


def test_parser_multinode():
    ''' Two-node scenarios for dsrun.parse_resource_filter().

    Every scenario pairs the filter keyword arguments with the resource
    mapping the parser is expected to return.
    '''
    hosts = {'worker-0': [0, 1, 2, 3], 'worker-1': [0, 1, 2, 3]}

    scenarios = [
        # no include/exclude: the pool is returned as given
        ({}, hosts),
        # include a node
        ({'include_str': 'worker-1:0,3'}, {'worker-1': [0, 3]}),
        # exclude a node
        ({'exclude_str': 'worker-1'}, {'worker-0': [0, 1, 2, 3]}),
        # exclude part of each node
        ({'exclude_str': 'worker-0:0,1@worker-1:3'},
         {'worker-0': [2, 3], 'worker-1': [0, 1, 2]}),
    ]

    for filter_kwargs, expected in scenarios:
        assert dsrun.parse_resource_filter(hosts, **filter_kwargs) == expected


def test_parser_errors():
    '''Each malformed or unsatisfiable filter must raise a ValueError. '''
    hosts = {'worker-0': [0, 1, 2, 3], 'worker-1': [0, 1, 2, 3]}

    bad_filters = [
        # host does not exist
        {'include_str': 'jeff'},
        {'exclude_str': 'jeff'},
        # slot does not exist
        {'include_str': 'worker-1:4'},
        {'exclude_str': 'worker-1:4'},
        # formatting
        {'exclude_str': 'worker-1@worker-0:1@5'},
    ]

    for filter_kwargs in bad_filters:
        with pytest.raises(ValueError):
            dsrun.parse_resource_filter(hosts, **filter_kwargs)


def test_num_plus_parser():
    ''' Combining num_nodes/num_gpus with -i/-e must make dsrun.main() raise a ValueError.'''
    command_lines = [
        # inclusion
        "--num_nodes 1 -i localhost foo.py",
        "--num_nodes 1 --num_gpus 1 -i localhost foo.py",
        "--num_gpus 1 -i localhost foo.py",
        # exclusion
        "--num_nodes 1 -e localhost foo.py",
        "--num_nodes 1 --num_gpus 1 -e localhost foo.py",
        "--num_gpus 1 -e localhost foo.py",
    ]

    for command_line in command_lines:
        with pytest.raises(ValueError):
            dsrun.main(args=command_line.split())