# test_kill_command.py: tests for kill_command from nni.tools.nnictl.command_utils.

import argparse
import multiprocessing
import os
import subprocess
import signal
import sys
import signal
import time

import pytest

from nni.tools.nnictl.command_utils import kill_command, _check_pid_running

# Skip every test in this module on Windows: killing a process there sometimes
# fails with the "Terminate batch job (Y/N)?" confirmation prompt.
pytestmark = pytest.mark.skipif(sys.platform == 'win32', reason='Windows has confirmation upon process killing.')


def process_normal():
    """Do nothing but sleep for six minutes.

    Used as the target of a child process that has no custom signal handling.
    """
    idle_seconds = 360
    time.sleep(idle_seconds)


def process_kill_slow(kill_time=2):
    """Sleep for six minutes, but take a while to exit when asked to stop.

    Parameters
    ----------
    kill_time
        Seconds the signal handler sleeps before exiting with status 0.
        The handler is installed for both SIGINT and SIGTERM.
    """
    def _delayed_exit(signum, frame):
        # Simulate a slow shutdown: stall first, then exit cleanly.
        time.sleep(kill_time)
        sys.exit(0)

    for stop_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(stop_signal, _delayed_exit)

    time.sleep(360)


def process_patiently_kill():
    """Spawn a very slow-to-exit child and call ``kill_command`` on it.

    The child is this same file run with ``--mode kill_very_slow``. After giving
    it one second to start, ``kill_command`` is invoked without an explicit
    timeout.
    """
    child_cmd = [sys.executable, __file__, '--mode', 'kill_very_slow']
    child = subprocess.Popen(child_cmd)
    time.sleep(1)
    # wait long enough
    kill_command(child.pid)


def test_kill_process():
    """``kill_command`` on a plain sleeping child returns quickly and the child is gone."""
    child = multiprocessing.Process(target=process_normal)
    child.start()
    time.sleep(0.5)

    began = time.time()
    kill_command(child.pid)
    elapsed = time.time() - began

    assert not _check_pid_running(child.pid)
    assert elapsed < 2


def test_kill_process_slow_no_patience():
    """Kill a slow-to-exit child with a timeout shorter than its shutdown delay.

    The child runs this file in ``kill_slow`` mode and ``kill_command`` is given
    ``timeout=1``. The call must return within 2 seconds on every platform; on
    Linux the child is additionally expected to still be running right after
    the call returns.
    """
    process = subprocess.Popen([sys.executable, __file__, '--mode', 'kill_slow'])
    time.sleep(1)  # wait 1 second for the process to launch and register hooks
    start_time = time.time()
    kill_command(process.pid, timeout=1)  # didn't wait long enough
    elapsed = time.time() - start_time

    # Upper bound applies to every platform.
    assert elapsed < 2
    if sys.platform == 'linux':
        # There was also a lower bound here (0.5 < elapsed). It's not stable.
        assert process.poll() is None
        assert _check_pid_running(process.pid)

    # Wait more seconds and it will exit eventually.
    # NOTE(review): if the child is still alive after 20 polls the loop just
    # ends and the test passes -- confirm that is intended.
    for _ in range(20):
        time.sleep(1)
        if not _check_pid_running(process.pid):
            return


def test_kill_process_slow_patiently():
    """Kill a slow-to-exit child with a timeout longer than its shutdown delay.

    The child runs this file in ``kill_slow`` mode and ``kill_command`` is given
    ``timeout=3``; the call is expected to return in under 5 seconds.
    """
    process = subprocess.Popen([sys.executable, __file__, '--mode', 'kill_slow'])
    time.sleep(1)  # wait 1 second for the process to launch and register hooks
    start_time = time.time()
    kill_command(process.pid, timeout=3)  # wait long enough
    elapsed = time.time() - start_time
    assert elapsed < 5
    # A lower bound (elapsed > 1) used to be checked here as well; it is
    # disabled because it's not stable.


@pytest.mark.skipif(sys.platform != 'linux', reason='Signal issues on non-linux.')
def test_kill_process_interrupted():
    """Interrupt a process that is itself busy inside ``kill_command``.

    The first SIGINT is expected to leave the process alive; the second one is
    expected to end it with a non-zero exit code.
    """
    # Launch a subprocess that launches and kills another subprocess
    killer = multiprocessing.Process(target=process_patiently_kill)
    killer.start()
    time.sleep(3)

    # First Ctrl+C: it doesn't work
    os.kill(killer.pid, signal.SIGINT)
    assert killer.is_alive()  # Sometimes this is false on darwin.
    time.sleep(0.5)

    # Ctrl+C again.
    os.kill(killer.pid, signal.SIGINT)
    time.sleep(0.5)
    assert not killer.is_alive()

    if sys.platform == 'linux':
        # exit code could be different on non-linux platforms
        assert killer.exitcode != 0


def start_new_process_group(cmd):
    """Start ``cmd`` in its own process group and return the ``Popen`` object.

    Without a separate group, cmd will be killed after this process is killed.
    This mocks the behavior of an nni experiment launch.
    """
    if sys.platform == 'win32':
        group_kwargs = {'creationflags': subprocess.CREATE_NEW_PROCESS_GROUP}
    else:
        group_kwargs = {'preexec_fn': os.setpgrp}
    return subprocess.Popen(cmd, **group_kwargs)


def main():
    """Run the helper behaviour selected by ``--mode`` on the command line.

    ``kill_slow`` runs ``process_kill_slow`` with its default delay and
    ``kill_very_slow`` runs it with a 15-second delay. With no ``--mode`` the
    function does nothing.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--mode', choices=['kill_slow', 'kill_very_slow'])
    mode = cli.parse_args().mode

    if mode == 'kill_slow':
        process_kill_slow()
        return
    if mode == 'kill_very_slow':
        process_kill_slow(15)
        return
    # No mode given: placeholder branch for ad-hoc debugging.


# Script entry point: the tests in this file re-run it via
# ``[sys.executable, __file__, '--mode', ...]`` to get a child process whose
# signal handlers delay its exit.
if __name__ == '__main__':
    main()