# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys
import time
import json
import warnings
from argparse import ArgumentParser

from azureml.core import Experiment, RunConfiguration, ScriptRunConfig, Workspace
from azureml.core.authentication import (
    AzureCliAuthentication, InteractiveLoginAuthentication, AuthenticationException
)
from azureml.core.compute import ComputeTarget
from azureml.core.run import RUNNING_STATES, RunStatus, Run
from azureml.core.conda_dependencies import CondaDependencies

def cancel_and_wait(run, max_polls=6, poll_interval=5):
    """Cancel *run* and poll until its status is ``'Canceled'``.

    Args:
        run: Object exposing ``cancel()`` and ``get_status()``.
        max_polls: Maximum number of status re-checks after the initial one.
        poll_interval: Seconds to sleep before each re-check.

    Returns:
        True if the status became ``'Canceled'`` within the polling budget,
        False otherwise.
    """
    run.cancel()
    status = run.get_status()
    for _ in range(max_polls):
        if status == 'Canceled':
            return True
        time.sleep(poll_interval)
        status = run.get_status()
    return status == 'Canceled'


def serve_commands(run, stdin, stdout):
    """Answer line-oriented commands read from *stdin* until ``stop`` or EOF.

    Recognised lines (trailing whitespace is ignored):

    * ``update_status``  -> replies ``status:<run status>``
    * ``tracking_url``   -> replies ``tracking_url:<portal url>``
    * ``receive``        -> replies ``receive:<run metrics as JSON>``
    * ``stop``           -> cancels the run, replies ``stop_result:success``
      or ``stop_result:failed``, then returns
    * ``command:<text>`` -> logs ``<text>`` on the run under ``nni_manager``

    Any other line is ignored.

    Args:
        run: Object exposing ``get_status``, ``get_portal_url``,
            ``get_metrics``, ``cancel`` and ``log``.
        stdin: Text stream the commands are read from.
        stdout: Text stream the replies are written to.
    """
    def reply(message):
        # Flush every reply so a reader on the other end of a pipe sees it
        # immediately instead of waiting for the buffer to fill.
        print(message, file=stdout, flush=True)

    while True:
        raw = stdin.readline()
        if not raw:
            # readline() returns '' only at EOF. Without this check the loop
            # would spin forever once the input stream is closed.
            return
        line = raw.rstrip()
        if line == 'update_status':
            reply('status:' + run.get_status())
        elif line == 'tracking_url':
            reply('tracking_url:' + run.get_portal_url())
        elif line == 'stop':
            outcome = 'success' if cancel_and_wait(run) else 'failed'
            reply('stop_result:' + outcome)
            return
        elif line == 'receive':
            reply('receive:' + json.dumps(run.get_metrics()))
        else:
            prefix, _, payload = line.partition(':')
            if prefix == 'command':
                run.log('nni_manager', payload)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--subscription_id', help='the subscription id of aml')
    parser.add_argument('--resource_group', help='the resource group of aml')
    parser.add_argument('--workspace_name', help='the workspace name of aml')
    parser.add_argument('--compute_target', help='the compute cluster name of aml')
    parser.add_argument('--docker_image', help='the docker image of job')
    parser.add_argument('--experiment_name', help='the experiment name')
    parser.add_argument('--script_dir', help='script directory')
    parser.add_argument('--script_name', help='script name')
    args = parser.parse_args()

    # Try the Azure CLI credentials first; fall back to an interactive login
    # when they are unavailable.
    try:
        auth = AzureCliAuthentication()
        auth.get_token()
    except AuthenticationException as e:
        warnings.warn(
            f'Azure-cli authentication failed: {e}',
            RuntimeWarning
        )
        warnings.warn('Falling back to interactive authentication.', RuntimeWarning)
        auth = InteractiveLoginAuthentication()

    ws = Workspace(args.subscription_id, args.resource_group, args.workspace_name, auth=auth)
    compute_target = ComputeTarget(workspace=ws, name=args.compute_target)
    experiment = Experiment(ws, args.experiment_name)

    # Run the script inside the given docker image on a single node, using the
    # Python environment that the image already provides.
    run_config = RunConfiguration()
    run_config.environment.python.user_managed_dependencies = True
    run_config.environment.docker.enabled = True
    run_config.environment.docker.base_image = args.docker_image
    run_config.target = compute_target
    run_config.node_count = 1
    config = ScriptRunConfig(source_directory=args.script_dir, script=args.script_name, run_config=run_config)
    run = experiment.submit(config)

    # The first line written to stdout is the id of the submitted run.
    print(run.get_details()["runId"], flush=True)

    serve_commands(run, sys.stdin, sys.stdout)
    sys.exit(0)