# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import pathlib
import sys
import time
import traceback
from argparse import ArgumentParser

# ref: https://help.aliyun.com/document_detail/203290.html?spm=a2c4g.11186623.6.727.6f9b5db6bzJh4x
from alibabacloud_pai_dlc20201203.client import Client
from alibabacloud_pai_dlc20201203.models import *  # CreateJobRequest, DataSourceItem, JobSpec
from alibabacloud_tea_openapi.models import Config

# Command-line bridge to Alibaba Cloud PAI-DLC. It submits one job and prints
# ``job_id:<id>`` on stdout. Then it answers line-based commands read from stdin:
#   update_status -> prints ``status:<status>``
#   tracking_url  -> prints ``tracking_url:<console url>``
#   stop          -> stops the job and exits with code 0
# Failures are written to <log_dir>/dlc_exception.log instead of being raised.

# Seconds to wait before every DLC status/stop attempt, to avoid triggering
# the user's API flow control.
_RETRY_INTERVAL_SECONDS = 60


def _parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = ArgumentParser()
    parser.add_argument('--type', help='the type of job spec')
    parser.add_argument('--image', help='the docker image of job')
    parser.add_argument('--job_type', choices=['TFJob', 'PyTorchJob'], help='the job type')
    parser.add_argument('--pod_count', type=int, default=1, help='pod count')
    parser.add_argument('--ecs_spec', help='ecs spec')
    parser.add_argument('--region', help='region')
    parser.add_argument('--workspace_id', help='workspace id for your project')
    parser.add_argument('--nas_data_source_id', help='nas data_source_id of DLC dataset configuration')
    parser.add_argument('--oss_data_source_id', help='oss data_source_id of DLC dataset configuration')
    parser.add_argument('--access_key_id', help='access_key_id')
    parser.add_argument('--access_key_secret', help='access_key_secret')
    parser.add_argument('--experiment_name', help='the experiment name')
    parser.add_argument('--user_command', help='user command')
    # Required: the log file lives here, and pathlib.Path(None) would raise TypeError.
    parser.add_argument('--log_dir', required=True, help='exception log dir')
    return parser.parse_args(argv)


def _setup_logging(log_dir):
    """Create *log_dir* if needed and send INFO+ records to dlc_exception.log inside it."""
    pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)
    logging.basicConfig(filename=os.path.join(log_dir, 'dlc_exception.log'),
                        format='%(asctime)s %(message)s',
                        level=logging.INFO)


def _create_client(args):
    """Build the DLC API client.

    ``--region share`` selects the endpoint ``pai-dlc-share.aliyuncs.com``.
    Any other value selects ``pai-dlc.<region>.aliyuncs.com`` and passes the
    value as ``region_id``.
    """
    if args.region == 'share':
        config = Config(
            access_key_id=args.access_key_id,
            access_key_secret=args.access_key_secret,
            endpoint='pai-dlc-share.aliyuncs.com',
        )
    else:
        config = Config(
            access_key_id=args.access_key_id,
            access_key_secret=args.access_key_secret,
            region_id=args.region,
            endpoint=f'pai-dlc.{args.region}.aliyuncs.com',
        )
    return Client(config)


def _build_job_request(args):
    """Assemble the CreateJobRequest for this submission.

    The NAS data source is always attached. The OSS data source is attached
    only when ``--oss_data_source_id`` is non-empty. A ``--workspace_id`` equal
    to the literal string ``'None'`` is sent as no workspace (``None``).
    """
    data_sources = [
        DataSourceItem(data_source_type='nas', data_source_id=args.nas_data_source_id),
    ]
    if args.oss_data_source_id:
        data_sources.append(
            DataSourceItem(data_source_type='oss', data_source_id=args.oss_data_source_id))

    spec = JobSpec(
        type=args.type,
        image=args.image,
        pod_count=args.pod_count,
        ecs_spec=args.ecs_spec,
    )

    workspace_id = args.workspace_id
    # Presumably the caller stringifies a missing workspace as 'None' -- confirm upstream.
    if workspace_id == 'None':
        workspace_id = None
        logging.info("args.workspace_id %s %s", workspace_id, type(workspace_id))

    return CreateJobRequest(
        display_name=args.experiment_name,
        job_type=args.job_type,
        job_specs=[spec],
        data_sources=data_sources,
        user_command=args.user_command,
        workspace_id=workspace_id,
    )


def _tracking_url(job_id, region):
    """Return the PAI-DLC console URL for *job_id* in *region*."""
    # TODO: 1. get this url by api? 2. change this url in private dlc mode.
    return f'https://pai-dlc.console.aliyun.com/#/jobs/detail?jobId={job_id}&regionId={region}'


def _retry_forever(action, error_message):
    """Call *action* until it returns without raising, and return its result.

    Sleeps ``_RETRY_INTERVAL_SECONDS`` before every attempt, including the
    first. Each failure is logged with *error_message* plus its traceback, so
    transient DLC errors (e.g. HTTP 503) only delay the call.
    """
    while True:
        time.sleep(_RETRY_INTERVAL_SECONDS)
        try:
            return action()
        except Exception:
            logging.exception(error_message)


def _report_status(client, job_id):
    """Query the job status once, log it, and print ``status:<status>``."""
    status = client.get_job(job_id).body.status
    logging.info('job_id %s, client.get_job(job_id).body.status %s', job_id, status)
    print('status:' + status, flush=True)


def _serve_commands(client, job_id, region, stream=None):
    """Answer line-based commands read from *stream* (default ``sys.stdin``).

    Blank and unknown lines are ignored. The function returns when *stream*
    reaches EOF. The ``stop`` command instead ends the process via ``sys.exit(0)``.
    """
    if stream is None:
        stream = sys.stdin
    while True:
        line = stream.readline()
        if not line:
            # readline() returns '' only at EOF (the writer closed our stdin).
            # Without this check the loop would spin forever at full CPU.
            logging.info('stdin closed, exit job_id %s command loop', job_id)
            return
        command = line.rstrip()
        if command == 'update_status':
            _retry_forever(lambda: _report_status(client, job_id), 'dlc get status error: \n')
            logging.info("exit job_id %s update status", job_id)
        elif command == 'tracking_url':
            print('tracking_url:' + _tracking_url(job_id, region), flush=True)
        elif command == 'stop':
            _retry_forever(lambda: client.stop_job(job_id), 'dlc stop error: \n')
            sys.exit(0)


def main(argv=None):
    """Submit the DLC job described by the command line, then serve stdin commands."""
    args = _parse_args(argv)
    _setup_logging(args.log_dir)

    try:
        client = _create_client(args)
        response = client.create_job(_build_job_request(args))
        job_id = response.body.job_id
        # Flush explicitly: Python block-buffers stdout when it is a pipe, and
        # the reader needs each reply line immediately.
        print('job_id:' + job_id, flush=True)
        _serve_commands(client, job_id, args.region)
    except Exception:
        # Best effort: record the failure in the log file; the process still exits 0.
        logging.exception('DLC submit Exception: \n')


if __name__ == "__main__":
    main()