"""Launching tool for DGL distributed training"""
import argparse
import json
import logging
import multiprocessing
import os
import queue
import re
import signal
import stat
import subprocess
import sys
import time
from functools import partial
from threading import Thread
from typing import Optional


def cleanup_proc(get_all_remote_pids, conn):
    """This process tries to clean up the remote training tasks."""
    print("cleanupu process runs")
    # This process should not handle SIGINT.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    data = conn.recv()
    # If the launch process exits normally, this process doesn't need to do anything.
    if data == "exit":
        sys.exit(0)
    else:
        remote_pids = get_all_remote_pids()
        # Otherwise, we need to ssh to each machine and kill the training jobs.
        for (ip, port), pids in remote_pids.items():
            kill_process(ip, port, pids)
    print("cleanup process exits")


def kill_process(ip, port, pids):
    """ssh to a remote machine and kill the specified processes."""
    curr_pid = os.getpid()
    killed_pids = []
    # If we kill child processes first, the parent process may spawn new ones. This happens
    # with Python's process pool. After sorting, we always kill parent processes first.
    pids.sort()
    for pid in pids:
        assert curr_pid != pid
        print("kill process {} on {}:{}".format(pid, ip, port), flush=True)
        kill_cmd = (
            "ssh -o StrictHostKeyChecking=no -p "
            + str(port)
            + " "
            + ip
            + " 'kill {}'".format(pid)
        )
        subprocess.run(kill_cmd, shell=True)
        killed_pids.append(pid)
    # It's possible that some of the processes are not killed. Let's try again.
    for i in range(3):
        killed_pids = get_killed_pids(ip, port, killed_pids)
        if len(killed_pids) == 0:
            break
        else:
            killed_pids.sort()
            for pid in killed_pids:
                print(
                    "kill process {} on {}:{}".format(pid, ip, port), flush=True
                )
                kill_cmd = (
                    "ssh -o StrictHostKeyChecking=no -p "
                    + str(port)
                    + " "
                    + ip
                    + " 'kill -9 {}'".format(pid)
                )
                subprocess.run(kill_cmd, shell=True)


def get_killed_pids(ip, port, killed_pids):
    """Get the process IDs that we want to kill but are still alive."""
    killed_pids = [str(pid) for pid in killed_pids]
    killed_pids = ",".join(killed_pids)
    ps_cmd = (
        "ssh -o StrictHostKeyChecking=no -p "
        + str(port)
        + " "
        + ip
        + " 'ps -p {} -h'".format(killed_pids)
    )
    res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)
    pids = []
    for p in res.stdout.decode("utf-8").split("\n"):
        l = p.split()
        if len(l) > 0:
            pids.append(int(l[0]))
    return pids


def execute_remote(
    cmd: str,
    state_q: queue.Queue,
    ip: str,
    port: int,
    username: Optional[str] = "",
) -> Thread:
    """Execute command line on remote machine via ssh.

    Args:
        cmd: User-defined command (udf) to execute on the remote host.
        state_q: A queue collecting Thread exit states.
        ip: The ip-address of the host to run the command on.
        port: Port number that the host is listening on.
        username: Optional. If given, this will specify a username to use when issuing commands over SSH.
            Useful when your infra requires you to explicitly specify a username to avoid permission issues.

    Returns:
        thread: The daemon Thread that runs `cmd` on the remote host via ssh.
            The thread has already been started when this function returns; its
            run() finishes once the remote command completes.
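
    Example (illustrative values):
        execute_remote("python3 train.py", state_q, "172.31.0.1", 22, username="ubuntu")
        runs, in a daemon thread, roughly:
        ssh -o StrictHostKeyChecking=no -p 22 ubuntu@172.31.0.1 'python3 train.py'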
    """
    ip_prefix = ""
    if username:
        ip_prefix += "{username}@".format(username=username)

    # Construct ssh command that executes `cmd` on the remote host
    ssh_cmd = "ssh -o StrictHostKeyChecking=no -p {port} {ip_prefix}{ip} '{cmd}'".format(
        port=str(port),
        ip_prefix=ip_prefix,
        ip=ip,
        cmd=cmd,
    )

    # thread func to run the job
    def run(ssh_cmd, state_q):
        try:
            subprocess.check_call(ssh_cmd, shell=True)
            state_q.put(0)
        except subprocess.CalledProcessError as err:
            print(f"Called process error {err}")
            state_q.put(err.returncode)
        except Exception:
            state_q.put(-1)

    thread = Thread(
        target=run,
        args=(
            ssh_cmd,
            state_q,
        ),
    )
    thread.daemon = True
    thread.start()
    # sleep for a while in case ssh is rejected by the peer due to a busy connection
    time.sleep(0.2)
    return thread


def get_remote_pids(ip, port, cmd_regex):
    """Get the process IDs that run the command in the remote machine."""
    pids = []
    curr_pid = os.getpid()
    # Here we want to get the python processes. We may get some ssh processes, so we should filter them out.
    ps_cmd = (
        "ssh -o StrictHostKeyChecking=no -p "
        + str(port)
        + " "
        + ip
        + " 'ps -aux | grep python | grep -v StrictHostKeyChecking'"
    )
    res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)
    for p in res.stdout.decode("utf-8").split("\n"):
        l = p.split()
        if len(l) < 2:
            continue
        # We only get the processes that run the specified command.
        res = re.search(cmd_regex, p)
        if res is not None and int(l[1]) != curr_pid:
            pids.append(l[1])

    pid_str = ",".join([str(pid) for pid in pids])
    ps_cmd = (
        "ssh -o StrictHostKeyChecking=no -p "
        + str(port)
        + " "
        + ip
        + " 'pgrep -P {}'".format(pid_str)
    )
    res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)
    pids1 = res.stdout.decode("utf-8").split("\n")
    all_pids = []
    for pid in set(pids + pids1):
        if pid == "" or int(pid) == curr_pid:
            continue
        all_pids.append(int(pid))
    all_pids.sort()
    return all_pids


def get_all_remote_pids(hosts, ssh_port, udf_command):
    """Get all remote processes."""
    remote_pids = {}
    for node_id, host in enumerate(hosts):
        ip, _ = host
        # When creating training processes on remote machines, we may insert some arguments
        # into the commands. We need to use regular expressions to match the modified command.
        cmds = udf_command.split()
        new_udf_command = " .*".join(cmds)
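        # For example (illustrative): "python3 train.py --epochs 5" becomes the
        # regex "python3 .*train.py .*--epochs .*5", which still matches after
        # the launcher inserts extra torch.distributed.launch arguments.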
        pids = get_remote_pids(ip, ssh_port, new_udf_command)
        remote_pids[(ip, ssh_port)] = pids
    return remote_pids


def construct_torch_dist_launcher_cmd(
    num_trainers: int,
    num_nodes: int,
    node_rank: int,
    master_addr: str,
    master_port: int,
) -> str:
    """Constructs the torch distributed launcher command.
    Helper function.

    Args:
        num_trainers:
        num_nodes:
        node_rank:
        master_addr:
        master_port:

    Returns:
        cmd_str.
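
    Example (illustrative values):
        construct_torch_dist_launcher_cmd(
            num_trainers=4, num_nodes=2, node_rank=0,
            master_addr="172.31.0.1", master_port=1234,
        )
        returns the string:
        "-m torch.distributed.launch --nproc_per_node=4 --nnodes=2 "
        "--node_rank=0 --master_addr=172.31.0.1 --master_port=1234"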
    """
    torch_cmd_template = (
        "-m torch.distributed.launch "
        "--nproc_per_node={nproc_per_node} "
        "--nnodes={nnodes} "
        "--node_rank={node_rank} "
        "--master_addr={master_addr} "
        "--master_port={master_port}"
    )
    return torch_cmd_template.format(
        nproc_per_node=num_trainers,
        nnodes=num_nodes,
        node_rank=node_rank,
        master_addr=master_addr,
        master_port=master_port,
    )


def wrap_udf_in_torch_dist_launcher(
    udf_command: str,
    num_trainers: int,
    num_nodes: int,
    node_rank: int,
    master_addr: str,
    master_port: int,
) -> str:
    """Wraps the user-defined function (udf_command) with the torch.distributed.launch module.

     Example: if udf_command is "python3 run/some/trainer.py arg1 arg2", then new_udf_command becomes:
         "python3 -m torch.distributed.launch <TORCH DIST ARGS> run/some/trainer.py arg1 arg2"

    udf_command is assumed to consist of pre-commands (optional) followed by the python launcher script (required):
    Examples:
        # simple
        python3.7 path/to/some/trainer.py arg1 arg2

        # multi-commands
        (cd some/dir && python3.7 path/to/some/trainer.py arg1 arg2)

    IMPORTANT: If udf_command consists of multiple python commands, then this will result in undefined behavior.

    Args:
        udf_command:
        num_trainers:
        num_nodes:
        node_rank:
        master_addr:
        master_port:

    Returns:

    """
    torch_dist_cmd = construct_torch_dist_launcher_cmd(
        num_trainers=num_trainers,
        num_nodes=num_nodes,
        node_rank=node_rank,
        master_addr=master_addr,
        master_port=master_port,
    )
    # Auto-detect the python binary that kicks off the distributed trainer code.
    # Note: This allowlist order matters; the FIRST matching entry wins. Thus, please add names to this
    #       from most-specific to least-specific order eg:
    #           (python3.7, python3.8) -> (python3)
    # The allowed python versions are from this: https://www.dgl.ai/pages/start.html
    python_bin_allowlist = (
        "python3.6",
        "python3.7",
        "python3.8",
        "python3.9",
        "python3",
        # for backwards compatibility, accept python2 but technically DGL is a py3 library, so this is not recommended
        "python2.7",
        "python2",
    )
    # If none of the candidate python bins match, then we go with the default `python`
    python_bin = "python"
    for candidate_python_bin in python_bin_allowlist:
        if candidate_python_bin in udf_command:
            python_bin = candidate_python_bin
            break

    # transforms the udf_command from:
    #     python path/to/dist_trainer.py arg0 arg1
    # to:
    #     python -m torch.distributed.launch [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
    # Note: if there are multiple python commands in `udf_command`, this may do the Wrong Thing, eg launch each
    #       python command within the torch distributed launcher.
    new_udf_command = udf_command.replace(
        python_bin, f"{python_bin} {torch_dist_cmd}"
    )

    return new_udf_command


def construct_dgl_server_env_vars(
    num_samplers: int,
    num_server_threads: int,
    tot_num_clients: int,
    part_config: str,
    ip_config: str,
    num_servers: int,
    graph_format: str,
    keep_alive: bool,
    pythonpath: Optional[str] = "",
) -> str:
    """Constructs the DGL server-specific env vars string that are required for DGL code to behave in the correct
    server role.
    Convenience function.

    Args:
        num_samplers:
        num_server_threads:
        tot_num_clients:
        part_config: Partition config.
            Relative path to workspace.
        ip_config: IP config file containing IP addresses of cluster hosts.
            Relative path to workspace.
        num_servers:
        graph_format:
        keep_alive:
            Whether to keep server alive when clients exit
        pythonpath: Optional. If given, this will pass this as PYTHONPATH.

    Returns:
        server_env_vars: The server-specific env-vars in a string format, friendly for CLI execution.
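
    Example (illustrative values):
        construct_dgl_server_env_vars(
            num_samplers=2, num_server_threads=4, tot_num_clients=12,
            part_config="data/graph.json", ip_config="ip_config.txt",
            num_servers=1, graph_format="csc", keep_alive=False,
        )
        returns a string like:
        "DGL_ROLE=server DGL_NUM_SAMPLER=2 OMP_NUM_THREADS=4 DGL_NUM_CLIENT=12 "
        "DGL_CONF_PATH=data/graph.json DGL_IP_CONFIG=ip_config.txt "
        "DGL_NUM_SERVER=1 DGL_GRAPH_FORMAT=csc DGL_KEEP_ALIVE=0 "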

    """
    server_env_vars_template = (
        "DGL_ROLE={DGL_ROLE} "
        "DGL_NUM_SAMPLER={DGL_NUM_SAMPLER} "
        "OMP_NUM_THREADS={OMP_NUM_THREADS} "
        "DGL_NUM_CLIENT={DGL_NUM_CLIENT} "
        "DGL_CONF_PATH={DGL_CONF_PATH} "
        "DGL_IP_CONFIG={DGL_IP_CONFIG} "
        "DGL_NUM_SERVER={DGL_NUM_SERVER} "
        "DGL_GRAPH_FORMAT={DGL_GRAPH_FORMAT} "
        "DGL_KEEP_ALIVE={DGL_KEEP_ALIVE} "
        "{suffix_optional_envvars}"
    )
    suffix_optional_envvars = ""
    if pythonpath:
        suffix_optional_envvars += f"PYTHONPATH={pythonpath} "
    return server_env_vars_template.format(
        DGL_ROLE="server",
        DGL_NUM_SAMPLER=num_samplers,
        OMP_NUM_THREADS=num_server_threads,
        DGL_NUM_CLIENT=tot_num_clients,
        DGL_CONF_PATH=part_config,
        DGL_IP_CONFIG=ip_config,
        DGL_NUM_SERVER=num_servers,
        DGL_GRAPH_FORMAT=graph_format,
        DGL_KEEP_ALIVE=int(keep_alive),
        suffix_optional_envvars=suffix_optional_envvars,
    )


def construct_dgl_client_env_vars(
    num_samplers: int,
    tot_num_clients: int,
    part_config: str,
    ip_config: str,
    num_servers: int,
    graph_format: str,
    num_omp_threads: int,
    group_id: int,
    pythonpath: Optional[str] = "",
) -> str:
    """Constructs the DGL client-specific env vars string that are required for DGL code to behave in the correct
    client role.
    Convenience function.

    Args:
        num_samplers:
        tot_num_clients:
        part_config: Partition config.
            Relative path to workspace.
        ip_config: IP config file containing IP addresses of cluster hosts.
            Relative path to workspace.
        num_servers:
        graph_format:
        num_omp_threads:
        group_id:
            Used in client processes to indicate which group the client belongs to.
        pythonpath: Optional. If given, this will pass this as PYTHONPATH.

    Returns:
        client_env_vars: The client-specific env-vars in a string format, friendly for CLI execution.
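
    Example (illustrative values):
        construct_dgl_client_env_vars(
            num_samplers=2, tot_num_clients=12, part_config="data/graph.json",
            ip_config="ip_config.txt", num_servers=1, graph_format="csc",
            num_omp_threads=4, group_id=0,
        )
        returns a string like:
        "DGL_DIST_MODE=distributed DGL_ROLE=client DGL_NUM_SAMPLER=2 "
        "DGL_NUM_CLIENT=12 DGL_CONF_PATH=data/graph.json DGL_IP_CONFIG=ip_config.txt "
        "DGL_NUM_SERVER=1 DGL_GRAPH_FORMAT=csc OMP_NUM_THREADS=4 DGL_GROUP_ID=0 "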

    """
    client_env_vars_template = (
        "DGL_DIST_MODE={DGL_DIST_MODE} "
        "DGL_ROLE={DGL_ROLE} "
        "DGL_NUM_SAMPLER={DGL_NUM_SAMPLER} "
        "DGL_NUM_CLIENT={DGL_NUM_CLIENT} "
        "DGL_CONF_PATH={DGL_CONF_PATH} "
        "DGL_IP_CONFIG={DGL_IP_CONFIG} "
        "DGL_NUM_SERVER={DGL_NUM_SERVER} "
        "DGL_GRAPH_FORMAT={DGL_GRAPH_FORMAT} "
        "OMP_NUM_THREADS={OMP_NUM_THREADS} "
        "DGL_GROUP_ID={DGL_GROUP_ID} "
        "{suffix_optional_envvars}"
    )
    # append optional additional env-vars
    suffix_optional_envvars = ""
    if pythonpath:
        suffix_optional_envvars += f"PYTHONPATH={pythonpath} "
    return client_env_vars_template.format(
        DGL_DIST_MODE="distributed",
        DGL_ROLE="client",
        DGL_NUM_SAMPLER=num_samplers,
        DGL_NUM_CLIENT=tot_num_clients,
        DGL_CONF_PATH=part_config,
        DGL_IP_CONFIG=ip_config,
        DGL_NUM_SERVER=num_servers,
        DGL_GRAPH_FORMAT=graph_format,
        OMP_NUM_THREADS=num_omp_threads,
        DGL_GROUP_ID=group_id,
        suffix_optional_envvars=suffix_optional_envvars,
    )


def wrap_cmd_with_local_envvars(cmd: str, env_vars: str) -> str:
    """Wraps a CLI command with desired env vars with the following properties:
    (1) env vars persist for the entire `cmd`, even if it consists of multiple "chained" commands like:
        cmd = "ls && pwd && python run/something.py"
    (2) env vars don't pollute the environment after `cmd` completes.

    Example:
        >>> cmd = "ls && pwd"
        >>> env_vars = "VAR1=value1 VAR2=value2"
        >>> wrap_cmd_with_local_envvars(cmd, env_vars)
        "(export VAR1=value1 VAR2=value2; ls && pwd)"

    Args:
        cmd:
        env_vars: A string containing env vars, eg "VAR1=val1 VAR2=val2"

    Returns:
        cmd_with_env_vars:

    """
    # use `export` to persist env vars for entire cmd block. required if udf_command is a chain of commands
    # also: wrap in parens to not pollute env:
    #     https://stackoverflow.com/a/45993803
    return f"(export {env_vars}; {cmd})"


def wrap_cmd_with_extra_envvars(cmd: str, env_vars: list) -> str:
    """Wraps a CLI command with extra env vars

    Example:
        >>> cmd = "ls && pwd"
        >>> env_vars = ["VAR1=value1", "VAR2=value2"]
        >>> wrap_cmd_with_extra_envvars(cmd, env_vars)
        "(export VAR1=value1 VAR2=value2; ls && pwd)"

    Args:
        cmd:
        env_vars: A list of strings containing env vars, e.g., ["VAR1=value1", "VAR2=value2"]

    Returns:
        cmd_with_env_vars:
    """
    env_vars = " ".join(env_vars)
    return wrap_cmd_with_local_envvars(cmd, env_vars)


g_monitor_file = None
g_group_id = 0


def has_alive_servers(args):
    """Check whether there exists alive servers.

    For each group of long-lived servers, a monitor file named
    'dgl_dist_monitor_{args.server_name}' is created under '/tmp/' directory.
    We check the existence of this monitor file to determine whether to
    launch new servers or utilize the existing alive ones. If alive
    servers exist, we obtain the available group ID from the monitor
    file, which is used for the current group of clients.

    Returns
    -------
    bool
        Indicates whether alive servers exist.
    """
    if args.server_name is None:
        return False
    global g_monitor_file
    global g_group_id
    monitor_file = "/tmp/dgl_dist_monitor_" + args.server_name
    from filelock import FileLock

    lock = FileLock(monitor_file + ".lock")
    with lock:
        next_group_id = None
        ret = os.path.exists(monitor_file)
        if ret:
            print(
                "Monitor file for alive servers already exist: {}.".format(
                    monitor_file
                )
            )
            lines = [line.rstrip("\n") for line in open(monitor_file)]
            g_group_id = int(lines[0])
            next_group_id = g_group_id + 1
        if not ret and args.keep_alive:
            next_group_id = 1
            print(
                "Monitor file for alive servers is created: {}.".format(
                    monitor_file
                )
            )
            g_monitor_file = monitor_file
        if next_group_id is not None:
            with open(monitor_file, "w") as f:
                f.write(str(next_group_id))
    return ret


def clean_alive_servers():
    """Remove keep alive related files"""
    global g_monitor_file
    try:
        if g_monitor_file is not None:
            os.remove(g_monitor_file)
            os.remove(g_monitor_file + ".lock")
            print(
                "Monitor file for alive servers is removed: {}.".format(
                    g_monitor_file
                )
            )
    except:
        print(
            "Failed to delete monitor file for alive servers: {}.".format(
                g_monitor_file
            )
        )


def get_available_port(ip):
    """Get available port with specified ip."""
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
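    # Probe candidate ports by trying to connect to them; if the connection
    # attempt fails, assume nothing is listening on that port and return it.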
    for port in range(1234, 65535):
        try:
            sock.connect((ip, port))
        except:
            return port
    raise RuntimeError("Failed to get available port for ip~{}".format(ip))


def submit_jobs(args, udf_command, dry_run=False):
    """Submit distributed jobs (server and client processes) via ssh"""
    if dry_run:
        print(
            "Currently it's in dry run mode which means no jobs will be launched."
        )
    servers_cmd = []
    clients_cmd = []
    hosts = []
    thread_list = []
    server_count_per_machine = 0

    # Get the IP addresses of the cluster.
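    # Each line of the IP config file is either "IP PORT" or just "IP"; in the
    # latter case an available port is picked automatically (illustrative
    # example of a line: "172.31.0.1 30050").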
    ip_config = os.path.join(args.workspace, args.ip_config)
    with open(ip_config) as f:
        for line in f:
            result = line.strip().split()
            if len(result) == 2:
                ip = result[0]
                port = int(result[1])
                hosts.append((ip, port))
            elif len(result) == 1:
                ip = result[0]
                port = get_available_port(ip)
                hosts.append((ip, port))
            else:
                raise RuntimeError("Format error of ip_config.")
            server_count_per_machine = args.num_servers
    # Get partition info of the graph data
    part_config = os.path.join(args.workspace, args.part_config)
    with open(part_config) as conf_f:
        part_metadata = json.load(conf_f)
    assert "num_parts" in part_metadata, "num_parts does not exist."
    # The number of partitions must match the number of machines in the cluster.
    assert part_metadata["num_parts"] == len(
        hosts
    ), "The number of graph partitions has to match the number of machines in the cluster."

    state_q = queue.Queue()
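    # Each trainer process is a client, and every trainer additionally spawns
    # `num_samplers` sampler clients, on every machine in the cluster.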
    tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)
    # launch server tasks
    if not has_alive_servers(args):
        server_env_vars = construct_dgl_server_env_vars(
            num_samplers=args.num_samplers,
            num_server_threads=args.num_server_threads,
            tot_num_clients=tot_num_clients,
            part_config=args.part_config,
            ip_config=args.ip_config,
            num_servers=args.num_servers,
            graph_format=args.graph_format,
            keep_alive=args.keep_alive,
            pythonpath=os.environ.get("PYTHONPATH", ""),
        )
        for i in range(len(hosts) * server_count_per_machine):
            ip, _ = hosts[int(i / server_count_per_machine)]
            server_env_vars_cur = f"{server_env_vars} DGL_SERVER_ID={i}"
            cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur)
            cmd = (
                wrap_cmd_with_extra_envvars(cmd, args.extra_envs)
                if len(args.extra_envs) > 0
                else cmd
            )
            cmd = "cd " + str(args.workspace) + "; " + cmd
            servers_cmd.append(cmd)
            if not dry_run:
                thread_list.append(
                    execute_remote(
                        cmd,
                        state_q,
                        ip,
                        args.ssh_port,
                        username=args.ssh_username,
                    )
                )
    else:
        print(f"Use running server {args.server_name}.")

    # launch client tasks
    client_env_vars = construct_dgl_client_env_vars(
        num_samplers=args.num_samplers,
        tot_num_clients=tot_num_clients,
        part_config=args.part_config,
        ip_config=args.ip_config,
        num_servers=args.num_servers,
        graph_format=args.graph_format,
        num_omp_threads=os.environ.get(
            "OMP_NUM_THREADS", str(args.num_omp_threads)
        ),
        group_id=g_group_id,
        pythonpath=os.environ.get("PYTHONPATH", ""),
    )

    master_addr = hosts[0][0]
    master_port = get_available_port(master_addr)
    for node_id, host in enumerate(hosts):
        ip, _ = host
        # Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.launch ... UDF`
        torch_dist_udf_command = wrap_udf_in_torch_dist_launcher(
            udf_command=udf_command,
            num_trainers=args.num_trainers,
            num_nodes=len(hosts),
            node_rank=node_id,
            master_addr=master_addr,
            master_port=master_port,
        )
        cmd = wrap_cmd_with_local_envvars(
            torch_dist_udf_command, client_env_vars
        )
        cmd = (
            wrap_cmd_with_extra_envvars(cmd, args.extra_envs)
            if len(args.extra_envs) > 0
            else cmd
        )
        cmd = "cd " + str(args.workspace) + "; " + cmd
        clients_cmd.append(cmd)
        if not dry_run:
            thread_list.append(
                execute_remote(
                    cmd, state_q, ip, args.ssh_port, username=args.ssh_username
                )
            )

    # return commands of clients/servers directly if in dry run mode
    if dry_run:
        return clients_cmd, servers_cmd

    # Start a cleanup process dedicated to cleaning up remote training jobs.
    conn1, conn2 = multiprocessing.Pipe()
    func = partial(get_all_remote_pids, hosts, args.ssh_port, udf_command)
    process = multiprocessing.Process(target=cleanup_proc, args=(func, conn1))
    process.start()

    def signal_handler(signal, frame):
        logging.info("Stop launcher")
        # We need to tell the cleanup process to kill remote training jobs.
        conn2.send("cleanup")
        clean_alive_servers()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    err = 0
    for thread in thread_list:
        thread.join()
        err_code = state_q.get()
        if err_code != 0:
            # Record the error code. If multiple threads fail,
            # only one of the error codes is kept.
            err = err_code

    # The training processes have completed. Tell the cleanup process to exit.
    conn2.send("exit")
    process.join()
    if err != 0:
        print("Task failed")
        sys.exit(-1)


def main():
    parser = argparse.ArgumentParser(description="Launch a distributed job")
    parser.add_argument("--ssh_port", type=int, default=22, help="SSH Port.")
    parser.add_argument(
        "--ssh_username",
        default="",
        help="Optional. When issuing commands (via ssh) to cluster, use the provided username in the ssh cmd. "
        "Example: If you provide --ssh_username=bob, then the ssh command will be like: 'ssh bob@1.2.3.4 CMD' "
        "instead of 'ssh 1.2.3.4 CMD'",
    )
    parser.add_argument(
        "--workspace",
        type=str,
        help="Path of user directory of distributed tasks. \
                        This is used to specify a destination location where \
                        the contents of the current directory will be rsynced",
    )
    parser.add_argument(
        "--num_trainers",
        type=int,
        help="The number of trainer processes per machine",
    )
    parser.add_argument(
        "--num_omp_threads",
        type=int,
        help="The number of OMP threads per trainer",
    )
    parser.add_argument(
        "--num_samplers",
        type=int,
        default=0,
        help="The number of sampler processes per trainer process",
    )
    parser.add_argument(
        "--num_servers",
        type=int,
        help="The number of server processes per machine",
    )
    parser.add_argument(
        "--part_config",
        type=str,
        help="The file (in workspace) of the partition config",
    )
    parser.add_argument(
        "--ip_config",
        type=str,
        help="The file (in workspace) of IP configuration for server processes",
    )
    parser.add_argument(
        "--num_server_threads",
        type=int,
        default=1,
        help="The number of OMP threads in the server process. \
                        It should be small if server processes and trainer processes run on \
                        the same machine. By default, it is 1.",
    )
    parser.add_argument(
        "--graph_format",
        type=str,
        default="csc",
        help='The format of the graph structure of each partition. \
                        The allowed formats are csr, csc and coo. A user can specify multiple \
                        formats, separated by ",". For example, the graph format is "csr,csc".',
    )
    parser.add_argument(
        "--extra_envs",
        nargs="+",
        type=str,
        default=[],
        help="Extra environment parameters need to be set. For example, \
                        you can set the LD_LIBRARY_PATH and NCCL_DEBUG by adding: \
                        --extra_envs LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH NCCL_DEBUG=INFO ",
    )
    parser.add_argument(
        "--keep_alive",
        action="store_true",
        help="Servers keep alive when clients exit",
    )
    parser.add_argument(
        "--server_name",
        type=str,
        help="Used to check whether there exist alive servers",
    )
    args, udf_command = parser.parse_known_args()
    if args.keep_alive:
        assert (
            args.server_name is not None
        ), "Server name is required if '--keep_alive' is enabled."
        print("Servers will keep alive even clients exit...")
    assert len(udf_command) == 1, "Please provide user command line."
    assert (
        args.num_trainers is not None and args.num_trainers > 0
    ), "--num_trainers must be a positive number."
    assert (
        args.num_samplers is not None and args.num_samplers >= 0
    ), "--num_samplers must be a non-negative number."
    assert (
        args.num_servers is not None and args.num_servers > 0
    ), "--num_servers must be a positive number."
    assert (
        args.num_server_threads > 0
    ), "--num_server_threads must be a positive number."
    assert (
        args.workspace is not None
    ), "A user has to specify a workspace with --workspace."
    assert (
        args.part_config is not None
    ), "A user has to specify a partition configuration file with --part_config."
    assert (
        args.ip_config is not None
    ), "A user has to specify an IP configuration file with --ip_config."
    if args.num_omp_threads is None:
        # Here we assume all machines have the same number of CPU cores as the machine
        # where the launch script runs.
        args.num_omp_threads = max(
            multiprocessing.cpu_count() // 2 // args.num_trainers, 1
        )
        print(
            "The number of OMP threads per trainer is set to",
            args.num_omp_threads,
        )

    udf_command = str(udf_command[0])
    if "python" not in udf_command:
        raise RuntimeError(
            "DGL launching script can only support Python executable file."
        )
    submit_jobs(args, udf_command)


if __name__ == "__main__":
    fmt = "%(asctime)s %(levelname)s %(message)s"
    logging.basicConfig(format=fmt, level=logging.INFO)
    main()