"docs/archive_en_US/NAS/NasGuide.md" did not exist on "bc0f8f338ba8a7e42e29cbbf47a0edca8244cfcd"
Unverified Commit 403195f0 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Merge branch 'master' into nn-meter

parents 99aa8226 a7278d2d
...@@ -5,7 +5,7 @@ from collections import Counter ...@@ -5,7 +5,7 @@ from collections import Counter
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from nni.retiarii import Sampler, basic_unit from nni.retiarii import InvalidMutation, Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution.python import _unpack_if_only_one from nni.retiarii.execution.python import _unpack_if_only_one
...@@ -520,6 +520,45 @@ class GraphIR(unittest.TestCase): ...@@ -520,6 +520,45 @@ class GraphIR(unittest.TestCase):
model = mutator.bind_sampler(sampler).apply(model) model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64])) self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
def test_nasbench201_cell(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.NasBench201Cell([
lambda x, y: nn.Linear(x, y),
lambda x, y: nn.Linear(x, y, bias=False)
], 10, 16)
def forward(self, x):
return self.cell(x)
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))
def test_autoactivation(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
self.act = nn.AutoActivation()
def forward(self, x):
return self.act(x)
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 10]))
class Python(GraphIR): class Python(GraphIR):
def _get_converted_pytorch_model(self, model_ir): def _get_converted_pytorch_model(self, model_ir):
...@@ -545,3 +584,29 @@ class Python(GraphIR): ...@@ -545,3 +584,29 @@ class Python(GraphIR):
@unittest.skip @unittest.skip
def test_valuechoice_access_functional_expression(self): ... def test_valuechoice_access_functional_expression(self): ...
def test_nasbench101_cell(self):
# this is only supported in python engine for now.
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.NasBench101Cell([lambda x: nn.Linear(x, x), lambda x: nn.Linear(x, x, bias=False)],
10, 16, lambda x, y: nn.Linear(x, y), max_num_nodes=5, max_num_edges=7)
def forward(self, x):
return self.cell(x)
raw_model, mutators = self._get_model_with_mutators(Net())
succeeded = 0
sampler = RandomSampler()
while succeeded <= 10:
try:
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
succeeded += 1
except InvalidMutation:
continue
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))
...@@ -116,7 +116,7 @@ class AnalysisUtilsTest(TestCase): ...@@ -116,7 +116,7 @@ class AnalysisUtilsTest(TestCase):
pruner.export_model(ck_file, mask_file) pruner.export_model(ck_file, mask_file)
pruner._unwrap_model() pruner._unwrap_model()
# Fix the mask conflict # Fix the mask conflict
fixed_mask, _ = fix_mask_conflict(mask_file, net, dummy_input) fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
# use the channel dependency groud truth to check if # use the channel dependency groud truth to check if
# fix the mask conflict successfully # fix the mask conflict successfully
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import logging
import os import os
import gc
import psutil import psutil
import sys import sys
import numpy as np import numpy as np
...@@ -9,18 +11,20 @@ import torch ...@@ -9,18 +11,20 @@ import torch
import torchvision.models as models import torchvision.models as models
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision.models.vgg import vgg16 from torchvision.models.vgg import vgg16, vgg11
from torchvision.models.resnet import resnet18 from torchvision.models.resnet import resnet18
from torchvision.models.mobilenet import mobilenet_v2
import unittest import unittest
from unittest import TestCase, main from unittest import TestCase, main
from nni.compression.pytorch import ModelSpeedup, apply_compression_results from nni.compression.pytorch import ModelSpeedup, apply_compression_results
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner
from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker
from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2 BATCH_SIZE = 2
# the relative distance # the relative distance
RELATIVE_THRESHOLD = 0.01 RELATIVE_THRESHOLD = 0.01
...@@ -105,6 +109,55 @@ class TransposeModel(torch.nn.Module): ...@@ -105,6 +109,55 @@ class TransposeModel(torch.nn.Module):
return x return x
class TupleUnpack_backbone(nn.Module):
def __init__(self, width):
super(TupleUnpack_backbone, self).__init__()
self.model_backbone = mobilenet_v2(
pretrained=False, width_mult=width, num_classes=3)
def forward(self, x):
x1 = self.model_backbone.features[:7](x)
x2 = self.model_backbone.features[7:14](x1)
x3 = self.model_backbone.features[14:18](x2)
return [x1, x2, x3]
class TupleUnpack_FPN(nn.Module):
def __init__(self):
super(TupleUnpack_FPN, self).__init__()
self.conv1 = nn.Conv2d(32, 48, kernel_size=(
1, 1), stride=(1, 1), bias=False)
self.conv2 = nn.Conv2d(96, 48, kernel_size=(
1, 1), stride=(1, 1), bias=False)
self.conv3 = nn.Conv2d(320, 48, kernel_size=(
1, 1), stride=(1, 1), bias=False)
# self.init_weights()
def forward(self, inputs):
"""Forward function."""
laterals = []
laterals.append(self.conv1(inputs[0])) # inputs[0]==x1
laterals.append(self.conv2(inputs[1])) # inputs[1]==x2
laterals.append(self.conv3(inputs[2])) # inputs[2]==x3
return laterals
class TupleUnpack_Model(nn.Module):
def __init__(self):
super(TupleUnpack_Model, self).__init__()
self.backbone = TupleUnpack_backbone(1.0)
self.fpn = TupleUnpack_FPN()
def forward(self, x):
x1 = self.backbone(x)
out = self.fpn(x1)
return out
dummy_input = torch.randn(2, 1, 28, 28) dummy_input = torch.randn(2, 1, 28, 28)
SPARSITY = 0.5 SPARSITY = 0.5
MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth' MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'
...@@ -129,6 +182,7 @@ def generate_random_sparsity(model): ...@@ -129,6 +182,7 @@ def generate_random_sparsity(model):
'sparsity': sparsity}) 'sparsity': sparsity})
return cfg_list return cfg_list
def generate_random_sparsity_v2(model): def generate_random_sparsity_v2(model):
""" """
Only select 50% layers to prune. Only select 50% layers to prune.
...@@ -139,9 +193,10 @@ def generate_random_sparsity_v2(model): ...@@ -139,9 +193,10 @@ def generate_random_sparsity_v2(model):
if np.random.uniform(0, 1.0) > 0.5: if np.random.uniform(0, 1.0) > 0.5:
sparsity = np.random.uniform(0.5, 0.99) sparsity = np.random.uniform(0.5, 0.99)
cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name], cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
'sparsity': sparsity}) 'sparsity': sparsity})
return cfg_list return cfg_list
def zero_bn_bias(model): def zero_bn_bias(model):
with torch.no_grad(): with torch.no_grad():
for name, module in model.named_modules(): for name, module in model.named_modules():
...@@ -231,19 +286,6 @@ def channel_prune(model): ...@@ -231,19 +286,6 @@ def channel_prune(model):
class SpeedupTestCase(TestCase): class SpeedupTestCase(TestCase):
def test_speedup_vgg16(self):
prune_model_l1(vgg16())
model = vgg16()
model.train()
ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
ms.speedup_model()
orig_model = vgg16()
assert model.training
assert model.features[2].out_channels == int(
orig_model.features[2].out_channels * SPARSITY)
assert model.classifier[0].in_features == int(
orig_model.classifier[0].in_features * SPARSITY)
def test_speedup_bigmodel(self): def test_speedup_bigmodel(self):
prune_model_l1(BigModel()) prune_model_l1(BigModel())
...@@ -253,7 +295,7 @@ class SpeedupTestCase(TestCase): ...@@ -253,7 +295,7 @@ class SpeedupTestCase(TestCase):
mask_out = model(dummy_input) mask_out = model(dummy_input)
model.train() model.train()
ms = ModelSpeedup(model, dummy_input, MASK_FILE) ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
ms.speedup_model() ms.speedup_model()
assert model.training assert model.training
...@@ -289,7 +331,7 @@ class SpeedupTestCase(TestCase): ...@@ -289,7 +331,7 @@ class SpeedupTestCase(TestCase):
new_model = TransposeModel() new_model = TransposeModel()
state_dict = torch.load(MODEL_FILE) state_dict = torch.load(MODEL_FILE)
new_model.load_state_dict(state_dict) new_model.load_state_dict(state_dict)
ms = ModelSpeedup(new_model, dummy_input, MASK_FILE) ms = ModelSpeedup(new_model, dummy_input, MASK_FILE, confidence=2)
ms.speedup_model() ms.speedup_model()
zero_bn_bias(ori_model) zero_bn_bias(ori_model)
zero_bn_bias(new_model) zero_bn_bias(new_model)
...@@ -297,26 +339,38 @@ class SpeedupTestCase(TestCase): ...@@ -297,26 +339,38 @@ class SpeedupTestCase(TestCase):
new_out = new_model(dummy_input) new_out = new_model(dummy_input)
ori_sum = torch.sum(ori_out) ori_sum = torch.sum(ori_out)
speeded_sum = torch.sum(new_out) speeded_sum = torch.sum(new_out)
print('Tanspose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum)) print('Tanspose Speedup Test: ori_sum={} speedup_sum={}'.format(
ori_sum, speeded_sum))
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \ assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD) (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
# FIXME: This test case might fail randomly, no idea why def test_speedup_integration_small(self):
# Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282 model_list = ['resnet18', 'mobilenet_v2', 'alexnet']
self.speedup_integration(model_list)
def test_speedup_integration_big(self):
model_list = ['vgg11', 'vgg16', 'resnet34', 'squeezenet1_1',
'densenet121', 'resnet50', 'wide_resnet50_2']
mem_info = psutil.virtual_memory()
ava_gb = mem_info.available/1024.0/1024/1024
print('Avaliable memory size: %.2f GB' % ava_gb)
if ava_gb < 8.0:
# memory size is too small that we may run into an OOM exception
# Skip this test in the pipeline test due to memory limitation
return
self.speedup_integration(model_list)
def test_speedup_integration(self): def speedup_integration(self, model_list, speedup_cfg=None):
# skip this test on windows(7GB mem available) due to memory limit
# Note: hack trick, may be updated in the future # Note: hack trick, may be updated in the future
if 'win' in sys.platform or 'Win'in sys.platform: if 'win' in sys.platform or 'Win'in sys.platform:
print('Skip test_speedup_integration on windows due to memory limit!') print('Skip test_speedup_integration on windows due to memory limit!')
return return
Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2] Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]
for model_name in ['resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121' , 'densenet169', # for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121',
# 'inception_v3' inception is too large and may fail the pipeline # # 'inception_v3' inception is too large and may fail the pipeline
'resnet50']: # 'resnet50']:
for model_name in model_list:
for gen_cfg_func in Gen_cfg_funcs: for gen_cfg_func in Gen_cfg_funcs:
kwargs = { kwargs = {
'pretrained': True 'pretrained': True
...@@ -334,7 +388,10 @@ class SpeedupTestCase(TestCase): ...@@ -334,7 +388,10 @@ class SpeedupTestCase(TestCase):
speedup_model.eval() speedup_model.eval()
# random generate the prune config for the pruner # random generate the prune config for the pruner
cfgs = gen_cfg_func(net) cfgs = gen_cfg_func(net)
print("Testing {} with compression config \n {}".format(model_name, cfgs)) print("Testing {} with compression config \n {}".format(
model_name, cfgs))
if len(cfgs) == 0:
continue
pruner = L1FilterPruner(net, cfgs) pruner = L1FilterPruner(net, cfgs)
pruner.compress() pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE) pruner.export_model(MODEL_FILE, MASK_FILE)
...@@ -345,7 +402,10 @@ class SpeedupTestCase(TestCase): ...@@ -345,7 +402,10 @@ class SpeedupTestCase(TestCase):
zero_bn_bias(speedup_model) zero_bn_bias(speedup_model)
data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device) data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
ms = ModelSpeedup(speedup_model, data, MASK_FILE) if speedup_cfg is None:
speedup_cfg = {}
ms = ModelSpeedup(speedup_model, data,
MASK_FILE, confidence=2, **speedup_cfg)
ms.speedup_model() ms.speedup_model()
speedup_model.eval() speedup_model.eval()
...@@ -355,12 +415,13 @@ class SpeedupTestCase(TestCase): ...@@ -355,12 +415,13 @@ class SpeedupTestCase(TestCase):
ori_sum = torch.sum(ori_out).item() ori_sum = torch.sum(ori_out).item()
speeded_sum = torch.sum(speeded_out).item() speeded_sum = torch.sum(speeded_out).item()
print('Sum of the output of %s (before speedup):' % print('Sum of the output of %s (before speedup):' %
model_name, ori_sum) model_name, ori_sum)
print('Sum of the output of %s (after speedup):' % print('Sum of the output of %s (after speedup):' %
model_name, speeded_sum) model_name, speeded_sum)
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \ assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD) (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
print("Collecting Garbage")
gc.collect(2)
def test_channel_prune(self): def test_channel_prune(self):
orig_net = resnet18(num_classes=10).to(device) orig_net = resnet18(num_classes=10).to(device)
...@@ -378,7 +439,7 @@ class SpeedupTestCase(TestCase): ...@@ -378,7 +439,7 @@ class SpeedupTestCase(TestCase):
net.eval() net.eval()
data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device) data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device)
ms = ModelSpeedup(net, data, MASK_FILE) ms = ModelSpeedup(net, data, MASK_FILE, confidence=2)
ms.speedup_model() ms.speedup_model()
ms.bound_model(data) ms.bound_model(data)
...@@ -391,11 +452,56 @@ class SpeedupTestCase(TestCase): ...@@ -391,11 +452,56 @@ class SpeedupTestCase(TestCase):
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \ assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD) (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
def test_speedup_tupleunpack(self):
"""This test is reported in issue3645"""
model = TupleUnpack_Model()
cfg_list = [{'op_types': ['Conv2d'], 'sparsity':0.5}]
dummy_input = torch.rand(2, 3, 224, 224)
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
model(dummy_input)
pruner.export_model(MODEL_FILE, MASK_FILE)
ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
ms.speedup_model()
def test_finegrained_speedup(self):
""" Test the speedup on the fine-grained sparsity"""
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.fc1 = nn.Linear(1024, 1024)
self.fc2 = nn.Linear(1024, 1024)
self.fc3 = nn.Linear(1024, 512)
self.fc4 = nn.Linear(512, 10)
def forward(self, x):
x = x.view(-1, 1024)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)
return x
model = MLP().to(device)
dummy_input = torch.rand(16, 1, 32, 32).to(device)
cfg_list = [{'op_types': ['Linear'], 'sparsity':0.99}]
pruner = LevelPruner(model, cfg_list)
pruner.compress()
print('Original Arch')
print(model)
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=4)
ms.speedup_model()
print("Fine-grained speeduped model")
print(model)
def tearDown(self): def tearDown(self):
if os.path.exists(MODEL_FILE): if os.path.exists(MODEL_FILE):
os.remove(MODEL_FILE) os.remove(MODEL_FILE)
if os.path.exists(MASK_FILE): if os.path.exists(MASK_FILE):
os.remove(MASK_FILE) os.remove(MASK_FILE)
# GC to release memory
gc.collect(2)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -156,7 +156,7 @@ def mock_get_latest_metric_data(): ...@@ -156,7 +156,7 @@ def mock_get_latest_metric_data():
def mock_get_trial_log(): def mock_get_trial_log():
responses.add( responses.add(
responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-log/:id/:type', responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-file/:id/:filename',
json={"status":"RUNNING","errors":[]}, json={"status":"RUNNING","errors":[]},
status=200, status=200,
content_type='application/json', content_type='application/json',
......
...@@ -161,6 +161,7 @@ export interface ExperimentConfig { ...@@ -161,6 +161,7 @@ export interface ExperimentConfig {
trialConcurrency: number; trialConcurrency: number;
trialGpuNumber?: number; trialGpuNumber?: number;
maxExperimentDuration?: string; maxExperimentDuration?: string;
maxTrialDuration?: string;
maxTrialNumber?: number; maxTrialNumber?: number;
nniManagerIp?: string; nniManagerIp?: string;
//useAnnotation: boolean; // dealed inside nnictl //useAnnotation: boolean; // dealed inside nnictl
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
'use strict'; 'use strict';
import { MetricDataRecord, MetricType, TrialJobInfo } from './datastore'; import { MetricDataRecord, MetricType, TrialJobInfo } from './datastore';
import { TrialJobStatus, LogType } from './trainingService'; import { TrialJobStatus } from './trainingService';
import { ExperimentConfig } from './experimentConfig'; import { ExperimentConfig } from './experimentConfig';
type ProfileUpdateType = 'TRIAL_CONCURRENCY' | 'MAX_EXEC_DURATION' | 'SEARCH_SPACE' | 'MAX_TRIAL_NUM'; type ProfileUpdateType = 'TRIAL_CONCURRENCY' | 'MAX_EXEC_DURATION' | 'SEARCH_SPACE' | 'MAX_TRIAL_NUM';
...@@ -59,7 +59,7 @@ abstract class Manager { ...@@ -59,7 +59,7 @@ abstract class Manager {
public abstract getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]>; public abstract getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]>;
public abstract getLatestMetricData(): Promise<MetricDataRecord[]>; public abstract getLatestMetricData(): Promise<MetricDataRecord[]>;
public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>; public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>; public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>;
public abstract getStatus(): NNIManagerStatus; public abstract getStatus(): NNIManagerStatus;
......
...@@ -8,8 +8,6 @@ ...@@ -8,8 +8,6 @@
*/ */
type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED' | 'EARLY_STOPPED'; type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED' | 'EARLY_STOPPED';
type LogType = 'TRIAL_LOG' | 'TRIAL_STDOUT' | 'TRIAL_ERROR';
interface TrainingServiceMetadata { interface TrainingServiceMetadata {
readonly key: string; readonly key: string;
readonly value: string; readonly value: string;
...@@ -81,7 +79,7 @@ abstract class TrainingService { ...@@ -81,7 +79,7 @@ abstract class TrainingService {
public abstract submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail>; public abstract submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail>;
public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail>; public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail>;
public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>; public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>;
public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>; public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
public abstract setClusterMetadata(key: string, value: string): Promise<void>; public abstract setClusterMetadata(key: string, value: string): Promise<void>;
public abstract getClusterMetadata(key: string): Promise<string>; public abstract getClusterMetadata(key: string): Promise<string>;
public abstract getTrialOutputLocalPath(trialJobId: string): Promise<string>; public abstract getTrialOutputLocalPath(trialJobId: string): Promise<string>;
...@@ -103,5 +101,5 @@ class NNIManagerIpConfig { ...@@ -103,5 +101,5 @@ class NNIManagerIpConfig {
export { export {
TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm, TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters, TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
NNIManagerIpConfig, LogType NNIManagerIpConfig
}; };
...@@ -223,7 +223,7 @@ let cachedIpv4Address: string | null = null; ...@@ -223,7 +223,7 @@ let cachedIpv4Address: string | null = null;
/** /**
* Get IPv4 address of current machine. * Get IPv4 address of current machine.
*/ */
function getIPV4Address(): string { async function getIPV4Address(): Promise<string> {
if (cachedIpv4Address !== null) { if (cachedIpv4Address !== null) {
return cachedIpv4Address; return cachedIpv4Address;
} }
...@@ -232,12 +232,20 @@ function getIPV4Address(): string { ...@@ -232,12 +232,20 @@ function getIPV4Address(): string {
// since udp is connectionless, this does not send actual packets. // since udp is connectionless, this does not send actual packets.
const socket = dgram.createSocket('udp4'); const socket = dgram.createSocket('udp4');
socket.connect(1, '192.0.2.0'); socket.connect(1, '192.0.2.0');
cachedIpv4Address = socket.address().address; for (let i = 0; i < 10; i++) { // wait the system to initialize "connection"
await yield_();
try { cachedIpv4Address = socket.address().address; } catch (error) { /* retry */ }
}
cachedIpv4Address = socket.address().address; // if it still fails, throw the error
socket.close(); socket.close();
return cachedIpv4Address; return cachedIpv4Address;
} }
async function yield_(): Promise<void> {
/* trigger the scheduler, do nothing */
}
/** /**
* Get the status of canceled jobs according to the hint isEarlyStopped * Get the status of canceled jobs according to the hint isEarlyStopped
*/ */
......
...@@ -19,7 +19,7 @@ import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/exp ...@@ -19,7 +19,7 @@ import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/exp
import { ExperimentManager } from '../common/experimentManager'; import { ExperimentManager } from '../common/experimentManager';
import { TensorboardManager } from '../common/tensorboardManager'; import { TensorboardManager } from '../common/tensorboardManager';
import { import {
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../common/trainingService'; } from '../common/trainingService';
import { delay, getCheckpointDir, getExperimentRootDir, getLogDir, getMsgDispatcherCommand, mkDirP, getTunerProc, getLogLevel, isAlive, killPid } from '../common/utils'; import { delay, getCheckpointDir, getExperimentRootDir, getLogDir, getMsgDispatcherCommand, mkDirP, getTunerProc, getLogLevel, isAlive, killPid } from '../common/utils';
import { import {
...@@ -189,7 +189,6 @@ class NNIManager implements Manager { ...@@ -189,7 +189,6 @@ class NNIManager implements Manager {
this.log.debug(`dispatcher command: ${dispatcherCommand}`); this.log.debug(`dispatcher command: ${dispatcherCommand}`);
const checkpointDir: string = await this.createCheckpointDir(); const checkpointDir: string = await this.createCheckpointDir();
this.setupTuner(dispatcherCommand, undefined, 'start', checkpointDir); this.setupTuner(dispatcherCommand, undefined, 'start', checkpointDir);
this.setStatus('RUNNING'); this.setStatus('RUNNING');
await this.storeExperimentProfile(); await this.storeExperimentProfile();
this.run().catch((err: Error) => { this.run().catch((err: Error) => {
...@@ -403,8 +402,8 @@ class NNIManager implements Manager { ...@@ -403,8 +402,8 @@ class NNIManager implements Manager {
// FIXME: unit test // FIXME: unit test
} }
public async getTrialLog(trialJobId: string, logType: LogType): Promise<string> { public async getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string> {
return this.trainingService.getTrialLog(trialJobId, logType); return this.trainingService.getTrialFile(trialJobId, fileName);
} }
public getExperimentProfile(): Promise<ExperimentProfile> { public getExperimentProfile(): Promise<ExperimentProfile> {
...@@ -433,6 +432,11 @@ class NNIManager implements Manager { ...@@ -433,6 +432,11 @@ class NNIManager implements Manager {
return (value === undefined ? Infinity : value); return (value === undefined ? Infinity : value);
} }
private get maxTrialDuration(): number {
const value = this.experimentProfile.params.maxTrialDuration;
return (value === undefined ? Infinity : toSeconds(value));
}
private async initTrainingService(config: ExperimentConfig): Promise<TrainingService> { private async initTrainingService(config: ExperimentConfig): Promise<TrainingService> {
let platform: string; let platform: string;
if (Array.isArray(config.trainingService)) { if (Array.isArray(config.trainingService)) {
...@@ -539,6 +543,17 @@ class NNIManager implements Manager { ...@@ -539,6 +543,17 @@ class NNIManager implements Manager {
} }
} }
private async stopTrialJobIfOverMaxDurationTimer(trialJobId: string): Promise<void> {
const trialJobDetail: TrialJobDetail | undefined = this.trialJobs.get(trialJobId);
if(undefined !== trialJobDetail &&
trialJobDetail.status === 'RUNNING' &&
trialJobDetail.startTime !== undefined){
const isEarlyStopped = true;
await this.trainingService.cancelTrialJob(trialJobId, isEarlyStopped);
this.log.info(`Trial job ${trialJobId} has stoped because it is over maxTrialDuration.`);
}
}
private async requestTrialJobsStatus(): Promise<number> { private async requestTrialJobsStatus(): Promise<number> {
let finishedTrialJobNum: number = 0; let finishedTrialJobNum: number = 0;
if (this.dispatcher === undefined) { if (this.dispatcher === undefined) {
...@@ -662,6 +677,7 @@ class NNIManager implements Manager { ...@@ -662,6 +677,7 @@ class NNIManager implements Manager {
this.currSubmittedTrialNum++; this.currSubmittedTrialNum++;
this.log.info('submitTrialJob: form:', form); this.log.info('submitTrialJob: form:', form);
const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(form); const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(form);
setTimeout(async ()=> this.stopTrialJobIfOverMaxDurationTimer(trialJobDetail.id), 1000 * this.maxTrialDuration);
const Snapshot: TrialJobDetail = Object.assign({}, trialJobDetail); const Snapshot: TrialJobDetail = Object.assign({}, trialJobDetail);
await this.storeExperimentProfile(); await this.storeExperimentProfile();
this.trialJobs.set(trialJobDetail.id, Snapshot); this.trialJobs.set(trialJobDetail.id, Snapshot);
......
...@@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred'; ...@@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred';
import { Provider } from 'typescript-ioc'; import { Provider } from 'typescript-ioc';
import { MethodNotImplementedError } from '../../common/errors'; import { MethodNotImplementedError } from '../../common/errors';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService'; import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService';
const testTrainingServiceProvider: Provider = { const testTrainingServiceProvider: Provider = {
get: () => { return new MockedTrainingService(); } get: () => { return new MockedTrainingService(); }
...@@ -63,7 +63,7 @@ class MockedTrainingService extends TrainingService { ...@@ -63,7 +63,7 @@ class MockedTrainingService extends TrainingService {
return deferred.promise; return deferred.promise;
} }
public getTrialLog(trialJobId: string, logType: LogType): Promise<string> { public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
throw new MethodNotImplementedError(); throw new MethodNotImplementedError();
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"child-process-promise": "^2.2.1", "child-process-promise": "^2.2.1",
"express": "^4.17.1", "express": "^4.17.1",
"express-joi-validator": "^2.0.1", "express-joi-validator": "^2.0.1",
"http-proxy": "^1.18.1",
"ignore": "^5.1.8", "ignore": "^5.1.8",
"js-base64": "^3.6.1", "js-base64": "^3.6.1",
"kubernetes-client": "^6.12.1", "kubernetes-client": "^6.12.1",
...@@ -37,6 +38,7 @@ ...@@ -37,6 +38,7 @@
"@types/chai-as-promised": "^7.1.0", "@types/chai-as-promised": "^7.1.0",
"@types/express": "^4.17.2", "@types/express": "^4.17.2",
"@types/glob": "^7.1.3", "@types/glob": "^7.1.3",
"@types/http-proxy": "^1.17.7",
"@types/js-base64": "^3.3.1", "@types/js-base64": "^3.3.1",
"@types/js-yaml": "^4.0.1", "@types/js-yaml": "^4.0.1",
"@types/lockfile": "^1.0.0", "@types/lockfile": "^1.0.0",
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import * as bodyParser from 'body-parser'; import * as bodyParser from 'body-parser';
import * as express from 'express'; import * as express from 'express';
import * as httpProxy from 'http-proxy';
import * as path from 'path'; import * as path from 'path';
import * as component from '../common/component'; import * as component from '../common/component';
import { RestServer } from '../common/restServer' import { RestServer } from '../common/restServer'
...@@ -21,6 +22,7 @@ import { getAPIRootUrl } from '../common/experimentStartupInfo'; ...@@ -21,6 +22,7 @@ import { getAPIRootUrl } from '../common/experimentStartupInfo';
@component.Singleton @component.Singleton
export class NNIRestServer extends RestServer { export class NNIRestServer extends RestServer {
private readonly LOGS_ROOT_URL: string = '/logs'; private readonly LOGS_ROOT_URL: string = '/logs';
protected netronProxy: any = null;
protected API_ROOT_URL: string = '/api/v1/nni'; protected API_ROOT_URL: string = '/api/v1/nni';
/** /**
...@@ -29,6 +31,7 @@ export class NNIRestServer extends RestServer { ...@@ -29,6 +31,7 @@ export class NNIRestServer extends RestServer {
constructor() { constructor() {
super(); super();
this.API_ROOT_URL = getAPIRootUrl(); this.API_ROOT_URL = getAPIRootUrl();
this.netronProxy = httpProxy.createProxyServer();
} }
/** /**
...@@ -39,6 +42,14 @@ export class NNIRestServer extends RestServer { ...@@ -39,6 +42,14 @@ export class NNIRestServer extends RestServer {
this.app.use(bodyParser.json({limit: '50mb'})); this.app.use(bodyParser.json({limit: '50mb'}));
this.app.use(this.API_ROOT_URL, createRestHandler(this)); this.app.use(this.API_ROOT_URL, createRestHandler(this));
this.app.use(this.LOGS_ROOT_URL, express.static(getLogDir())); this.app.use(this.LOGS_ROOT_URL, express.static(getLogDir()));
this.app.all('/netron/*', (req: express.Request, res: express.Response) => {
delete req.headers.host;
req.url = req.url.replace('/netron', '/');
this.netronProxy.web(req, res, {
changeOrigin: true,
target: 'https://netron.app'
});
});
this.app.get('*', (req: express.Request, res: express.Response) => { this.app.get('*', (req: express.Request, res: express.Response) => {
res.sendFile(path.resolve('static/index.html')); res.sendFile(path.resolve('static/index.html'));
}); });
......
...@@ -19,7 +19,7 @@ import { NNIRestServer } from './nniRestServer'; ...@@ -19,7 +19,7 @@ import { NNIRestServer } from './nniRestServer';
import { getVersion } from '../common/utils'; import { getVersion } from '../common/utils';
import { MetricType } from '../common/datastore'; import { MetricType } from '../common/datastore';
import { ProfileUpdateType } from '../common/manager'; import { ProfileUpdateType } from '../common/manager';
import { LogType, TrialJobStatus } from '../common/trainingService'; import { TrialJobStatus } from '../common/trainingService';
const expressJoi = require('express-joi-validator'); const expressJoi = require('express-joi-validator');
...@@ -53,6 +53,7 @@ class NNIRestHandler { ...@@ -53,6 +53,7 @@ class NNIRestHandler {
this.version(router); this.version(router);
this.checkStatus(router); this.checkStatus(router);
this.getExperimentProfile(router); this.getExperimentProfile(router);
this.getExperimentMetadata(router);
this.updateExperimentProfile(router); this.updateExperimentProfile(router);
this.importData(router); this.importData(router);
this.getImportedData(router); this.getImportedData(router);
...@@ -66,7 +67,7 @@ class NNIRestHandler { ...@@ -66,7 +67,7 @@ class NNIRestHandler {
this.getMetricData(router); this.getMetricData(router);
this.getMetricDataByRange(router); this.getMetricDataByRange(router);
this.getLatestMetricData(router); this.getLatestMetricData(router);
this.getTrialLog(router); this.getTrialFile(router);
this.exportData(router); this.exportData(router);
this.getExperimentsInfo(router); this.getExperimentsInfo(router);
this.startTensorboardTask(router); this.startTensorboardTask(router);
...@@ -296,13 +297,20 @@ class NNIRestHandler { ...@@ -296,13 +297,20 @@ class NNIRestHandler {
}); });
} }
private getTrialLog(router: Router): void { private getTrialFile(router: Router): void {
router.get('/trial-log/:id/:type', async(req: Request, res: Response) => { router.get('/trial-file/:id/:filename', async(req: Request, res: Response) => {
this.nniManager.getTrialLog(req.params.id, req.params.type as LogType).then((log: string) => { let encoding: string | null = null;
if (log === '') { const filename = req.params.filename;
log = 'No logs available.' if (!filename.includes('.') || filename.match(/.*\.(txt|log)/g)) {
encoding = 'utf8';
}
this.nniManager.getTrialFile(req.params.id, filename).then((content: Buffer | string) => {
if (content instanceof Buffer) {
res.header('Content-Type', 'application/octet-stream');
} else if (content === '') {
content = `${filename} is empty.`;
} }
res.send(log); res.send(content);
}).catch((err: Error) => { }).catch((err: Error) => {
this.handleError(err, res); this.handleError(err, res);
}); });
...@@ -319,6 +327,24 @@ class NNIRestHandler { ...@@ -319,6 +327,24 @@ class NNIRestHandler {
}); });
} }
private getExperimentMetadata(router: Router): void {
router.get('/experiment-metadata', (req: Request, res: Response) => {
Promise.all([
this.nniManager.getExperimentProfile(),
this.experimentsManager.getExperimentsInfo()
]).then(([profile, experimentInfo]) => {
for (const info of experimentInfo as any) {
if (info.id === profile.id) {
res.send(info);
break;
}
}
}).catch((err: Error) => {
this.handleError(err, res);
});
});
}
private getExperimentsInfo(router: Router): void { private getExperimentsInfo(router: Router): void {
router.get('/experiments-info', (req: Request, res: Response) => { router.get('/experiments-info', (req: Request, res: Response) => {
this.experimentsManager.getExperimentsInfo().then((experimentInfo: JSON) => { this.experimentsManager.getExperimentsInfo().then((experimentInfo: JSON) => {
......
...@@ -13,7 +13,7 @@ import { ...@@ -13,7 +13,7 @@ import {
TrialJobStatistics, NNIManagerStatus TrialJobStatistics, NNIManagerStatus
} from '../../common/manager'; } from '../../common/manager';
import { import {
TrialJobApplicationForm, TrialJobDetail, TrialJobStatus, LogType TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
} from '../../common/trainingService'; } from '../../common/trainingService';
export const testManagerProvider: Provider = { export const testManagerProvider: Provider = {
...@@ -129,7 +129,7 @@ export class MockedNNIManager extends Manager { ...@@ -129,7 +129,7 @@ export class MockedNNIManager extends Manager {
public getLatestMetricData(): Promise<MetricDataRecord[]> { public getLatestMetricData(): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError(); throw new MethodNotImplementedError();
} }
public getTrialLog(trialJobId: string, logType: LogType): Promise<string> { public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
throw new MethodNotImplementedError(); throw new MethodNotImplementedError();
} }
public getExperimentProfile(): Promise<ExperimentProfile> { public getExperimentProfile(): Promise<ExperimentProfile> {
......
...@@ -14,7 +14,7 @@ import {getExperimentId} from '../../common/experimentStartupInfo'; ...@@ -14,7 +14,7 @@ import {getExperimentId} from '../../common/experimentStartupInfo';
import {getLogger, Logger} from '../../common/log'; import {getLogger, Logger} from '../../common/log';
import {MethodNotImplementedError} from '../../common/errors'; import {MethodNotImplementedError} from '../../common/errors';
import { import {
NNIManagerIpConfig, TrialJobDetail, TrialJobMetric, LogType NNIManagerIpConfig, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import {delay, getExperimentRootDir, getIPV4Address, getJobCancelStatus, getVersion, uniqueString} from '../../common/utils'; import {delay, getExperimentRootDir, getIPV4Address, getJobCancelStatus, getVersion, uniqueString} from '../../common/utils';
import {AzureStorageClientUtility} from './azureStorageClientUtils'; import {AzureStorageClientUtility} from './azureStorageClientUtils';
...@@ -99,7 +99,7 @@ abstract class KubernetesTrainingService { ...@@ -99,7 +99,7 @@ abstract class KubernetesTrainingService {
return Promise.resolve(kubernetesTrialJob); return Promise.resolve(kubernetesTrialJob);
} }
public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> { public async getTrialFile(_trialJobId: string, _filename: string): Promise<string | Buffer> {
throw new MethodNotImplementedError(); throw new MethodNotImplementedError();
} }
...@@ -277,7 +277,7 @@ abstract class KubernetesTrainingService { ...@@ -277,7 +277,7 @@ abstract class KubernetesTrainingService {
if (gpuNum === 0) { if (gpuNum === 0) {
nvidiaScript = 'export CUDA_VISIBLE_DEVICES='; nvidiaScript = 'export CUDA_VISIBLE_DEVICES=';
} }
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address(); const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : await getIPV4Address();
const version: string = this.versionCheck ? await getVersion() : ''; const version: string = this.versionCheck ? await getVersion() : '';
const runScript: string = String.Format( const runScript: string = String.Format(
kubernetesScriptFormat, kubernetesScriptFormat,
......
...@@ -13,7 +13,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo'; ...@@ -13,7 +13,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { import {
HyperParameters, TrainingService, TrialJobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService'; } from '../../common/trainingService';
import { import {
delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString
...@@ -170,18 +170,20 @@ class LocalTrainingService implements TrainingService { ...@@ -170,18 +170,20 @@ class LocalTrainingService implements TrainingService {
return trialJob; return trialJob;
} }
public async getTrialLog(trialJobId: string, logType: LogType): Promise<string> { public async getTrialFile(trialJobId: string, fileName: string): Promise<string | Buffer> {
let logPath: string; // check filename here for security
if (logType === 'TRIAL_LOG') { if (!['trial.log', 'stderr', 'model.onnx', 'stdout'].includes(fileName)) {
logPath = path.join(this.rootDir, 'trials', trialJobId, 'trial.log'); throw new Error(`File unaccessible: ${fileName}`);
} else if (logType === 'TRIAL_STDOUT'){ }
logPath = path.join(this.rootDir, 'trials', trialJobId, 'stdout'); let encoding: string | null = null;
} else if (logType === 'TRIAL_ERROR') { if (!fileName.includes('.') || fileName.match(/.*\.(txt|log)/g)) {
logPath = path.join(this.rootDir, 'trials', trialJobId, 'stderr'); encoding = 'utf8';
} else { }
throw new Error('unexpected log type'); const logPath = path.join(this.rootDir, 'trials', trialJobId, fileName);
if (!fs.existsSync(logPath)) {
throw new Error(`File not found: ${logPath}`);
} }
return fs.promises.readFile(logPath, 'utf8'); return fs.promises.readFile(logPath, {encoding: encoding as any});
} }
public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void { public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
......
...@@ -15,7 +15,7 @@ import { getLogger, Logger } from '../../common/log'; ...@@ -15,7 +15,7 @@ import { getLogger, Logger } from '../../common/log';
import { MethodNotImplementedError } from '../../common/errors'; import { MethodNotImplementedError } from '../../common/errors';
import { import {
HyperParameters, NNIManagerIpConfig, TrainingService, HyperParameters, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay } from '../../common/utils'; import { delay } from '../../common/utils';
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../common/experimentConfig'; import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../common/experimentConfig';
...@@ -23,10 +23,7 @@ import { PAIJobInfoCollector } from './paiJobInfoCollector'; ...@@ -23,10 +23,7 @@ import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { PAIJobRestServer } from './paiJobRestServer'; import { PAIJobRestServer } from './paiJobRestServer';
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig'; import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import { import { generateParamFileName, getIPV4Address, uniqueString } from '../../common/utils';
generateParamFileName,
getIPV4Address, uniqueString
} from '../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { execMkdir, validateCodeDir, execCopydir } from '../common/util'; import { execMkdir, validateCodeDir, execCopydir } from '../common/util';
...@@ -127,7 +124,7 @@ class PAITrainingService implements TrainingService { ...@@ -127,7 +124,7 @@ class PAITrainingService implements TrainingService {
return jobs; return jobs;
} }
public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> { public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
throw new MethodNotImplementedError(); throw new MethodNotImplementedError();
} }
...@@ -332,7 +329,7 @@ class PAITrainingService implements TrainingService { ...@@ -332,7 +329,7 @@ class PAITrainingService implements TrainingService {
return trialJobDetail; return trialJobDetail;
} }
private generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): string { private async generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): Promise<string> {
const containerNFSExpCodeDir = `${this.config.containerStorageMountPoint}/${this.experimentId}/nni-code`; const containerNFSExpCodeDir = `${this.config.containerStorageMountPoint}/${this.experimentId}/nni-code`;
const containerWorkingDir: string = `${this.config.containerStorageMountPoint}/${this.experimentId}/${trialJobDetail.id}`; const containerWorkingDir: string = `${this.config.containerStorageMountPoint}/${this.experimentId}/${trialJobDetail.id}`;
const nniPaiTrialCommand: string = String.Format( const nniPaiTrialCommand: string = String.Format(
...@@ -345,7 +342,7 @@ class PAITrainingService implements TrainingService { ...@@ -345,7 +342,7 @@ class PAITrainingService implements TrainingService {
false, // multi-phase false, // multi-phase
containerNFSExpCodeDir, containerNFSExpCodeDir,
command, command,
this.config.nniManagerIp || getIPV4Address(), this.config.nniManagerIp || await getIPV4Address(),
this.paiRestServerPort, this.paiRestServerPort,
this.nniVersion, this.nniVersion,
this.logCollection this.logCollection
...@@ -356,7 +353,7 @@ class PAITrainingService implements TrainingService { ...@@ -356,7 +353,7 @@ class PAITrainingService implements TrainingService {
} }
private generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): any { private async generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): Promise<any> {
const jobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}` const jobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}`
let nniJobConfig: any = undefined; let nniJobConfig: any = undefined;
...@@ -367,7 +364,7 @@ class PAITrainingService implements TrainingService { ...@@ -367,7 +364,7 @@ class PAITrainingService implements TrainingService {
// Each command will be formatted to NNI style // Each command will be formatted to NNI style
for (const taskRoleIndex in nniJobConfig.taskRoles) { for (const taskRoleIndex in nniJobConfig.taskRoles) {
const commands = nniJobConfig.taskRoles[taskRoleIndex].commands const commands = nniJobConfig.taskRoles[taskRoleIndex].commands
const nniTrialCommand = this.generateNNITrialCommand(trialJobDetail, commands.join(" && ").replace(/(["'$`\\])/g, '\\$1')); const nniTrialCommand = await this.generateNNITrialCommand(trialJobDetail, commands.join(" && ").replace(/(["'$`\\])/g, '\\$1'));
nniJobConfig.taskRoles[taskRoleIndex].commands = [nniTrialCommand] nniJobConfig.taskRoles[taskRoleIndex].commands = [nniTrialCommand]
} }
...@@ -399,7 +396,7 @@ class PAITrainingService implements TrainingService { ...@@ -399,7 +396,7 @@ class PAITrainingService implements TrainingService {
memoryMB: toMegaBytes(this.config.trialMemorySize) memoryMB: toMegaBytes(this.config.trialMemorySize)
}, },
commands: [ commands: [
this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand) await this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand)
] ]
} }
}, },
...@@ -456,7 +453,7 @@ class PAITrainingService implements TrainingService { ...@@ -456,7 +453,7 @@ class PAITrainingService implements TrainingService {
} }
//Generate Job Configuration in yaml format //Generate Job Configuration in yaml format
const paiJobConfig = this.generateJobConfigInYamlFormat(trialJobDetail); const paiJobConfig = await this.generateJobConfigInYamlFormat(trialJobDetail);
this.log.debug(paiJobConfig); this.log.debug(paiJobConfig);
// Step 2. Submit PAI job via Rest call // Step 2. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
......
...@@ -16,7 +16,7 @@ import { getLogger, Logger } from '../../common/log'; ...@@ -16,7 +16,7 @@ import { getLogger, Logger } from '../../common/log';
import { ObservableTimer } from '../../common/observableTimer'; import { ObservableTimer } from '../../common/observableTimer';
import { import {
HyperParameters, TrainingService, TrialJobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, LogType TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { import {
delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus, delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus,
...@@ -204,7 +204,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -204,7 +204,7 @@ class RemoteMachineTrainingService implements TrainingService {
* @param _trialJobId ID of trial job * @param _trialJobId ID of trial job
* @param _logType 'TRIAL_LOG' | 'TRIAL_STDERR' * @param _logType 'TRIAL_LOG' | 'TRIAL_STDERR'
*/ */
public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> { public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
throw new MethodNotImplementedError(); throw new MethodNotImplementedError();
} }
...@@ -491,7 +491,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -491,7 +491,7 @@ class RemoteMachineTrainingService implements TrainingService {
cudaVisible = `CUDA_VISIBLE_DEVICES=" "`; cudaVisible = `CUDA_VISIBLE_DEVICES=" "`;
} }
} }
const nniManagerIp: string = this.config.nniManagerIp ? this.config.nniManagerIp : getIPV4Address(); const nniManagerIp: string = this.config.nniManagerIp ? this.config.nniManagerIp : await getIPV4Address();
if (this.remoteRestServerPort === undefined) { if (this.remoteRestServerPort === undefined) {
const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer); const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer);
this.remoteRestServerPort = restServer.clusterRestServerPort; this.remoteRestServerPort = restServer.clusterRestServerPort;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { MethodNotImplementedError } from '../../common/errors'; import { MethodNotImplementedError } from '../../common/errors';
import { ExperimentConfig, RemoteConfig, OpenpaiConfig } from '../../common/experimentConfig'; import { ExperimentConfig, RemoteConfig, OpenpaiConfig } from '../../common/experimentConfig';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService'; import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService';
import { delay } from '../../common/utils'; import { delay } from '../../common/utils';
import { PAITrainingService } from '../pai/paiTrainingService'; import { PAITrainingService } from '../pai/paiTrainingService';
import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService'; import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';
...@@ -52,7 +52,7 @@ class RouterTrainingService implements TrainingService { ...@@ -52,7 +52,7 @@ class RouterTrainingService implements TrainingService {
return await this.internalTrainingService.getTrialJob(trialJobId); return await this.internalTrainingService.getTrialJob(trialJobId);
} }
public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> { public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
throw new MethodNotImplementedError(); throw new MethodNotImplementedError();
} }
......
...@@ -13,7 +13,7 @@ import * as component from '../../common/component'; ...@@ -13,7 +13,7 @@ import * as component from '../../common/component';
import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common/errors'; import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common/errors';
import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo'; import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus, LogType } from '../../common/trainingService'; import { TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus } from '../../common/trainingService';
import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from '../../common/utils'; import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from '../../common/utils';
import { ExperimentConfig, SharedStorageConfig } from '../../common/experimentConfig'; import { ExperimentConfig, SharedStorageConfig } from '../../common/experimentConfig';
import { GPU_INFO, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, REPORT_METRIC_DATA, SEND_TRIAL_JOB_PARAMETER, STDOUT, TRIAL_END, VERSION_CHECK } from '../../core/commands'; import { GPU_INFO, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, REPORT_METRIC_DATA, SEND_TRIAL_JOB_PARAMETER, STDOUT, TRIAL_END, VERSION_CHECK } from '../../core/commands';
...@@ -157,7 +157,7 @@ class TrialDispatcher implements TrainingService { ...@@ -157,7 +157,7 @@ class TrialDispatcher implements TrainingService {
return trial; return trial;
} }
public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> { public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
throw new MethodNotImplementedError(); throw new MethodNotImplementedError();
} }
...@@ -216,7 +216,7 @@ class TrialDispatcher implements TrainingService { ...@@ -216,7 +216,7 @@ class TrialDispatcher implements TrainingService {
for(const environmentService of this.environmentServiceList) { for(const environmentService of this.environmentServiceList) {
const runnerSettings: RunnerSettings = new RunnerSettings(); const runnerSettings: RunnerSettings = new RunnerSettings();
runnerSettings.nniManagerIP = this.config.nniManagerIp === undefined? getIPV4Address() : this.config.nniManagerIp; runnerSettings.nniManagerIP = this.config.nniManagerIp === undefined? await getIPV4Address() : this.config.nniManagerIp;
runnerSettings.nniManagerPort = getBasePort() + 1; runnerSettings.nniManagerPort = getBasePort() + 1;
runnerSettings.commandChannel = environmentService.getCommandChannel.channelName; runnerSettings.commandChannel = environmentService.getCommandChannel.channelName;
runnerSettings.enableGpuCollector = this.enableGpuScheduler; runnerSettings.enableGpuCollector = this.enableGpuScheduler;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment