Unverified commit 1bd17a02, authored by Yuge Zhang, committed by GitHub

[Retiarii] Visualization (#3878)

parent 4b38e64f
@@ -182,6 +182,8 @@ Visualize the Experiment

 Users can visualize their experiment in the same way as visualizing a normal hyper-parameter tuning experiment. For example, open ``localhost:8081`` in your browser, where 8081 is the port that you set in ``exp.run``. Please refer to `here <../../Tutorial/WebUI.rst>`__ for details.

+We support visualizing models with third-party visualization engines (like `Netron <https://netron.app/>`__), by clicking ``Visualization`` in the detail panel of each trial. Note that the current visualization is based on `onnx <https://onnx.ai/>`__. Built-in evaluators (e.g., Classification) automatically export the model into a file; for your own evaluator, you need to save your model to ``$NNI_OUTPUT_DIR/model.onnx`` to make this work.
+
 Export Top Models
 -----------------
...
@@ -24,6 +24,8 @@ The simplest way to customize a new evaluator is with functional APIs, which is

 .. note:: Due to a limitation of our current implementation, the ``fit`` function should be put in another Python file instead of in the main file. This limitation will be fixed in a future release.

+.. note:: When using customized evaluators, if you want to visualize models, you need to export your model and save it into ``$NNI_OUTPUT_DIR/model.onnx`` in your evaluator.
+
 With PyTorch-Lightning
 ----------------------
...
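For reference, a minimal sketch of what that export could look like inside a customized functional evaluator. Only the ``$NNI_OUTPUT_DIR/model.onnx`` destination comes from the note above; the model instantiation, dummy input shape, and the use of ``torch.onnx.export`` are illustrative assumptions and should be adapted to your model::

    import os
    from pathlib import Path

    import nni
    import torch

    def fit(model_cls):
        # Instantiate the candidate model produced by Retiarii.
        model = model_cls()

        # ... train and evaluate the candidate model here ...

        # Export the trained model so the web UI can visualize it with Netron.
        # NNI_OUTPUT_DIR points to the trial's output directory; fall back to
        # the current directory when running outside of NNI.
        out_dir = Path(os.environ.get('NNI_OUTPUT_DIR', '.'))
        dummy_input = torch.randn(1, 3, 32, 32)  # assumed input shape
        torch.onnx.export(model, dummy_input, str(out_dir / 'model.onnx'))

        nni.report_final_result(0.9)  # placeholder metric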
@@ -8,7 +8,7 @@ import socket
 from subprocess import Popen
 import sys
 import time
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Any

 import colorama

@@ -43,7 +43,7 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
     _check_rest_server(port)
     platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
     _save_experiment_information(exp_id, port, start_time, platform,
-                                 config.experiment_name, proc.pid, str(config.experiment_working_directory))
+                                 config.experiment_name, proc.pid, str(config.experiment_working_directory), [])
     _logger.info('Setting up...')
     rest.post(port, '/experiment', config.json())
     return proc

@@ -78,7 +78,7 @@ def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int,
     _check_rest_server(port)
     platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
     _save_experiment_information(exp_id, port, start_time, platform,
-                                 config.experiment_name, proc.pid, config.experiment_working_directory)
+                                 config.experiment_name, proc.pid, config.experiment_working_directory, ['retiarii'])
     _logger.info('Setting up...')
     rest.post(port, '/experiment', config.json())
     return proc, pipe

@@ -156,9 +156,10 @@ def _check_rest_server(port: int, retry: int = 3) -> None:
     rest.get(port, '/check-status')

-def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None:
+def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str,
+                                 name: str, pid: int, logDir: str, tag: List[Any]) -> None:
     experiments_config = Experiments()
-    experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
+    experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir, tag=tag)

 def get_stopped_experiment_config(exp_id: str, mode: str) -> None:
...
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+import os
 import warnings
-from typing import Dict, Union, Optional, List
+from pathlib import Path
+from typing import Dict, NoReturn, Union, Optional, List, Type

 import pytorch_lightning as pl
 import torch.nn as nn

@@ -18,7 +20,13 @@ __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classificat
 class LightningModule(pl.LightningModule):
-    def set_model(self, model):
+    """
+    Basic wrapper of generated model.
+    Lightning modules used in NNI should inherit this class.
+    """
+    def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> NoReturn:
         if isinstance(model, type):
             self.model = model()
         else:

@@ -112,13 +120,23 @@ class _SupervisedLearningModule(LightningModule):
     def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam):
+                 optimizer: optim.Optimizer = optim.Adam,
+                 export_onnx: Union[Path, str, bool, None] = None):
         super().__init__()
         self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
         self.criterion = criterion()
         self.optimizer = optimizer
         self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
+        if export_onnx is None or export_onnx is True:
+            self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
+            self.export_onnx.parent.mkdir(exist_ok=True)
+        elif export_onnx:
+            self.export_onnx = Path(export_onnx)
+        else:
+            self.export_onnx = None
+        self._already_exported = False

     def forward(self, x):
         y_hat = self.model(x)
         return y_hat

@@ -135,6 +153,11 @@ class _SupervisedLearningModule(LightningModule):
     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
+        if not self._already_exported:
+            self.to_onnx(self.export_onnx, x, export_params=True)
+            self._already_exported = True
         self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
         for name, metric in self.metrics.items():
             self.log('val_' + name, metric(y_hat, y), prog_bar=True)

@@ -152,9 +175,8 @@ class _SupervisedLearningModule(LightningModule):
     def on_validation_epoch_end(self):
         nni.report_intermediate_result(self._get_validation_metrics())

-    def teardown(self, stage):
-        if stage == 'fit':
-            nni.report_final_result(self._get_validation_metrics())
+    def on_fit_end(self):
+        nni.report_final_result(self._get_validation_metrics())

     def _get_validation_metrics(self):
         if len(self.metrics) == 1:
@@ -175,9 +197,11 @@ class _ClassificationModule(_SupervisedLearningModule):
     def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam):
+                 optimizer: optim.Optimizer = optim.Adam,
+                 export_onnx: bool = True):
         super().__init__(criterion, {'acc': _AccuracyWithLogits},
-                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
+                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
+                         export_onnx=export_onnx)

 class Classification(Lightning):

@@ -200,6 +224,8 @@ class Classification(Lightning):
     val_dataloaders : DataLoader or List of DataLoader
         Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
         If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
+    export_onnx : bool
+        If true, the model will be exported to ``model.onnx`` before training starts. Default: true.
     trainer_kwargs : dict
         Optional keyword arguments passed to trainer. See
         `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.

@@ -211,9 +237,10 @@ class Classification(Lightning):
                  optimizer: optim.Optimizer = optim.Adam,
                  train_dataloader: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
+                 export_onnx: bool = True,
                  **trainer_kwargs):
         module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
-                                       weight_decay=weight_decay, optimizer=optimizer)
+                                       weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
         super().__init__(module, Trainer(**trainer_kwargs),
                          train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
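As a usage sketch (not part of this diff): with these changes, the built-in evaluators expose the export through a single flag. The module path and the MNIST dataset below follow the NNI examples of this release and are illustrative assumptions::

    import nni.retiarii.evaluator.pytorch.lightning as pl
    from torchvision import transforms
    from torchvision.datasets import MNIST

    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
    test_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)

    # export_onnx=True (the default) writes $NNI_OUTPUT_DIR/model.onnx on the first
    # validation batch; pass export_onnx=False to disable the export entirely.
    evaluator = pl.Classification(
        train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
        val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
        max_epochs=2,
        export_onnx=True,
    )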
@@ -223,9 +250,11 @@ class _RegressionModule(_SupervisedLearningModule):
     def __init__(self, criterion: nn.Module = nn.MSELoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam):
+                 optimizer: optim.Optimizer = optim.Adam,
+                 export_onnx: bool = True):
         super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
-                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
+                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
+                         export_onnx=export_onnx)

 class Regression(Lightning):

@@ -248,6 +277,8 @@ class Regression(Lightning):
     val_dataloaders : DataLoader or List of DataLoader
         Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
         If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
+    export_onnx : bool
+        If true, the model will be exported to ``model.onnx`` before training starts. Default: true.
     trainer_kwargs : dict
         Optional keyword arguments passed to trainer. See
         `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.

@@ -259,8 +290,9 @@ class Regression(Lightning):
                  optimizer: optim.Optimizer = optim.Adam,
                  train_dataloader: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
+                 export_onnx: bool = True,
                  **trainer_kwargs):
         module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
-                                   weight_decay=weight_decay, optimizer=optimizer)
+                                   weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
         super().__init__(module, Trainer(**trainer_kwargs),
                          train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@@ -156,7 +156,7 @@ def mock_get_latest_metric_data():

 def mock_get_trial_log():
     responses.add(
-        responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-log/:id/:type',
+        responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-file/:id/:filename',
        json={"status":"RUNNING","errors":[]},
        status=200,
        content_type='application/json',
...
@@ -4,7 +4,7 @@
 'use strict';

 import { MetricDataRecord, MetricType, TrialJobInfo } from './datastore';
-import { TrialJobStatus, LogType } from './trainingService';
+import { TrialJobStatus } from './trainingService';
 import { ExperimentConfig } from './experimentConfig';

 type ProfileUpdateType = 'TRIAL_CONCURRENCY' | 'MAX_EXEC_DURATION' | 'SEARCH_SPACE' | 'MAX_TRIAL_NUM';

@@ -59,7 +59,7 @@ abstract class Manager {
     public abstract getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]>;
     public abstract getLatestMetricData(): Promise<MetricDataRecord[]>;
-    public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
+    public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
     public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>;
     public abstract getStatus(): NNIManagerStatus;
...
@@ -8,8 +8,6 @@
  */
 type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED' | 'EARLY_STOPPED';

-type LogType = 'TRIAL_LOG' | 'TRIAL_STDOUT' | 'TRIAL_ERROR';
-
 interface TrainingServiceMetadata {
     readonly key: string;
     readonly value: string;

@@ -81,7 +79,7 @@ abstract class TrainingService {
     public abstract submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail>;
     public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail>;
     public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>;
-    public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
+    public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
     public abstract setClusterMetadata(key: string, value: string): Promise<void>;
     public abstract getClusterMetadata(key: string): Promise<string>;
     public abstract getTrialOutputLocalPath(trialJobId: string): Promise<string>;

@@ -103,5 +101,5 @@ class NNIManagerIpConfig {
 export {
     TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
     TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
-    NNIManagerIpConfig, LogType
+    NNIManagerIpConfig
 };
@@ -19,7 +19,7 @@ import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/exp
 import { ExperimentManager } from '../common/experimentManager';
 import { TensorboardManager } from '../common/tensorboardManager';
 import {
-    TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType
+    TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus
 } from '../common/trainingService';
 import { delay, getCheckpointDir, getExperimentRootDir, getLogDir, getMsgDispatcherCommand, mkDirP, getTunerProc, getLogLevel, isAlive, killPid } from '../common/utils';
 import {

@@ -402,8 +402,8 @@ class NNIManager implements Manager {
         // FIXME: unit test
     }

-    public async getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
-        return this.trainingService.getTrialLog(trialJobId, logType);
+    public async getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string> {
+        return this.trainingService.getTrialFile(trialJobId, fileName);
     }

     public getExperimentProfile(): Promise<ExperimentProfile> {
...
@@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred';
 import { Provider } from 'typescript-ioc';

 import { MethodNotImplementedError } from '../../common/errors';
-import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService';
+import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService';

 const testTrainingServiceProvider: Provider = {
     get: () => { return new MockedTrainingService(); }

@@ -63,7 +63,7 @@ class MockedTrainingService extends TrainingService {
         return deferred.promise;
     }

-    public getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
+    public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
         throw new MethodNotImplementedError();
     }
...
@@ -15,6 +15,7 @@
     "child-process-promise": "^2.2.1",
     "express": "^4.17.1",
     "express-joi-validator": "^2.0.1",
+    "http-proxy": "^1.18.1",
     "ignore": "^5.1.8",
     "js-base64": "^3.6.1",
     "kubernetes-client": "^6.12.1",

@@ -37,6 +38,7 @@
     "@types/chai-as-promised": "^7.1.0",
     "@types/express": "^4.17.2",
     "@types/glob": "^7.1.3",
+    "@types/http-proxy": "^1.17.7",
     "@types/js-base64": "^3.3.1",
     "@types/js-yaml": "^4.0.1",
     "@types/lockfile": "^1.0.0",
...
@@ -5,6 +5,7 @@
 import * as bodyParser from 'body-parser';
 import * as express from 'express';
+import * as httpProxy from 'http-proxy';
 import * as path from 'path';
 import * as component from '../common/component';
 import { RestServer } from '../common/restServer'

@@ -21,6 +22,7 @@ import { getAPIRootUrl } from '../common/experimentStartupInfo';
 @component.Singleton
 export class NNIRestServer extends RestServer {
     private readonly LOGS_ROOT_URL: string = '/logs';
+    protected netronProxy: any = null;
     protected API_ROOT_URL: string = '/api/v1/nni';

     /**

@@ -29,6 +31,7 @@ export class NNIRestServer extends RestServer {
     constructor() {
         super();
         this.API_ROOT_URL = getAPIRootUrl();
+        this.netronProxy = httpProxy.createProxyServer();
     }

     /**

@@ -39,6 +42,14 @@ export class NNIRestServer extends RestServer {
         this.app.use(bodyParser.json({limit: '50mb'}));
         this.app.use(this.API_ROOT_URL, createRestHandler(this));
         this.app.use(this.LOGS_ROOT_URL, express.static(getLogDir()));
+        this.app.all('/netron/*', (req: express.Request, res: express.Response) => {
+            delete req.headers.host;
+            req.url = req.url.replace('/netron', '/');
+            this.netronProxy.web(req, res, {
+                changeOrigin: true,
+                target: 'https://netron.app'
+            });
+        });
         this.app.get('*', (req: express.Request, res: express.Response) => {
             res.sendFile(path.resolve('static/index.html'));
         });
...
@@ -19,7 +19,7 @@ import { NNIRestServer } from './nniRestServer';
 import { getVersion } from '../common/utils';
 import { MetricType } from '../common/datastore';
 import { ProfileUpdateType } from '../common/manager';
-import { LogType, TrialJobStatus } from '../common/trainingService';
+import { TrialJobStatus } from '../common/trainingService';

 const expressJoi = require('express-joi-validator');

@@ -53,6 +53,7 @@ class NNIRestHandler {
         this.version(router);
         this.checkStatus(router);
         this.getExperimentProfile(router);
+        this.getExperimentMetadata(router);
         this.updateExperimentProfile(router);
         this.importData(router);
         this.getImportedData(router);

@@ -66,7 +67,7 @@ class NNIRestHandler {
         this.getMetricData(router);
         this.getMetricDataByRange(router);
         this.getLatestMetricData(router);
-        this.getTrialLog(router);
+        this.getTrialFile(router);
         this.exportData(router);
         this.getExperimentsInfo(router);
         this.startTensorboardTask(router);

@@ -296,13 +297,20 @@ class NNIRestHandler {
         });
     }

-    private getTrialLog(router: Router): void {
-        router.get('/trial-log/:id/:type', async(req: Request, res: Response) => {
-            this.nniManager.getTrialLog(req.params.id, req.params.type as LogType).then((log: string) => {
-                if (log === '') {
-                    log = 'No logs available.'
+    private getTrialFile(router: Router): void {
+        router.get('/trial-file/:id/:filename', async(req: Request, res: Response) => {
+            let encoding: string | null = null;
+            const filename = req.params.filename;
+            if (!filename.includes('.') || filename.match(/.*\.(txt|log)/g)) {
+                encoding = 'utf8';
+            }
+            this.nniManager.getTrialFile(req.params.id, filename).then((content: Buffer | string) => {
+                if (content instanceof Buffer) {
+                    res.header('Content-Type', 'application/octet-stream');
+                } else if (content === '') {
+                    content = `${filename} is empty.`;
                 }
-                res.send(log);
+                res.send(content);
             }).catch((err: Error) => {
                 this.handleError(err, res);
             });

@@ -319,6 +327,24 @@ class NNIRestHandler {
         });
     }

+    private getExperimentMetadata(router: Router): void {
+        router.get('/experiment-metadata', (req: Request, res: Response) => {
+            Promise.all([
+                this.nniManager.getExperimentProfile(),
+                this.experimentsManager.getExperimentsInfo()
+            ]).then(([profile, experimentInfo]) => {
+                for (const info of experimentInfo as any) {
+                    if (info.id === profile.id) {
+                        res.send(info);
+                        break;
+                    }
+                }
+            }).catch((err: Error) => {
+                this.handleError(err, res);
+            });
+        });
+    }
+
     private getExperimentsInfo(router: Router): void {
         router.get('/experiments-info', (req: Request, res: Response) => {
             this.experimentsManager.getExperimentsInfo().then((experimentInfo: JSON) => {
...
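For illustration, the new routes can be exercised directly over HTTP. The following is a minimal sketch using the Python ``requests`` library; the route prefix and port match the mocked URL in the test above, but the actual port depends on your experiment, and the trial id is a placeholder::

    import requests

    BASE = 'http://localhost:8080/api/v1/nni'   # placeholder host/port
    TRIAL_ID = 'abc123'                         # placeholder trial id

    # Files without an extension, or ending in .txt/.log, are returned as utf-8 text.
    log_text = requests.get(f'{BASE}/trial-file/{TRIAL_ID}/trial.log').text

    # Binary files such as the exported ONNX model come back as an octet stream.
    resp = requests.get(f'{BASE}/trial-file/{TRIAL_ID}/model.onnx')
    with open('model.onnx', 'wb') as f:
        f.write(resp.content)

    # Metadata of the running experiment, via the new /experiment-metadata route.
    metadata = requests.get(f'{BASE}/experiment-metadata').json()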
@@ -13,7 +13,7 @@ import {
     TrialJobStatistics, NNIManagerStatus
 } from '../../common/manager';
 import {
-    TrialJobApplicationForm, TrialJobDetail, TrialJobStatus, LogType
+    TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
 } from '../../common/trainingService';

 export const testManagerProvider: Provider = {

@@ -129,7 +129,7 @@ export class MockedNNIManager extends Manager {
     public getLatestMetricData(): Promise<MetricDataRecord[]> {
         throw new MethodNotImplementedError();
     }
-    public getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
+    public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
         throw new MethodNotImplementedError();
     }
     public getExperimentProfile(): Promise<ExperimentProfile> {
...
@@ -14,7 +14,7 @@ import {getExperimentId} from '../../common/experimentStartupInfo';
 import {getLogger, Logger} from '../../common/log';
 import {MethodNotImplementedError} from '../../common/errors';
 import {
-    NNIManagerIpConfig, TrialJobDetail, TrialJobMetric, LogType
+    NNIManagerIpConfig, TrialJobDetail, TrialJobMetric
 } from '../../common/trainingService';
 import {delay, getExperimentRootDir, getIPV4Address, getJobCancelStatus, getVersion, uniqueString} from '../../common/utils';
 import {AzureStorageClientUtility} from './azureStorageClientUtils';

@@ -99,7 +99,7 @@ abstract class KubernetesTrainingService {
         return Promise.resolve(kubernetesTrialJob);
     }

-    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
+    public async getTrialFile(_trialJobId: string, _filename: string): Promise<string | Buffer> {
         throw new MethodNotImplementedError();
     }
...
@@ -13,7 +13,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo';
 import { getLogger, Logger } from '../../common/log';
 import {
     HyperParameters, TrainingService, TrialJobApplicationForm,
-    TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType
+    TrialJobDetail, TrialJobMetric, TrialJobStatus
 } from '../../common/trainingService';
 import {
     delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString

@@ -170,18 +170,20 @@ class LocalTrainingService implements TrainingService {
         return trialJob;
     }

-    public async getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
-        let logPath: string;
-        if (logType === 'TRIAL_LOG') {
-            logPath = path.join(this.rootDir, 'trials', trialJobId, 'trial.log');
-        } else if (logType === 'TRIAL_STDOUT'){
-            logPath = path.join(this.rootDir, 'trials', trialJobId, 'stdout');
-        } else if (logType === 'TRIAL_ERROR') {
-            logPath = path.join(this.rootDir, 'trials', trialJobId, 'stderr');
-        } else {
-            throw new Error('unexpected log type');
+    public async getTrialFile(trialJobId: string, fileName: string): Promise<string | Buffer> {
+        // check filename here for security
+        if (!['trial.log', 'stderr', 'model.onnx', 'stdout'].includes(fileName)) {
+            throw new Error(`File unaccessible: ${fileName}`);
+        }
+        let encoding: string | null = null;
+        if (!fileName.includes('.') || fileName.match(/.*\.(txt|log)/g)) {
+            encoding = 'utf8';
+        }
+        const logPath = path.join(this.rootDir, 'trials', trialJobId, fileName);
+        if (!fs.existsSync(logPath)) {
+            throw new Error(`File not found: ${logPath}`);
         }
-        return fs.promises.readFile(logPath, 'utf8');
+        return fs.promises.readFile(logPath, {encoding: encoding as any});
     }

     public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
...
@@ -15,7 +15,7 @@ import { getLogger, Logger } from '../../common/log';
 import { MethodNotImplementedError } from '../../common/errors';
 import {
     HyperParameters, NNIManagerIpConfig, TrainingService,
-    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType
+    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
 } from '../../common/trainingService';
 import { delay } from '../../common/utils';
 import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../common/experimentConfig';

@@ -127,7 +127,7 @@ class PAITrainingService implements TrainingService {
         return jobs;
     }

-    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
+    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
         throw new MethodNotImplementedError();
     }
...
@@ -16,7 +16,7 @@ import { getLogger, Logger } from '../../common/log';
 import { ObservableTimer } from '../../common/observableTimer';
 import {
     HyperParameters, TrainingService, TrialJobApplicationForm,
-    TrialJobDetail, TrialJobMetric, LogType
+    TrialJobDetail, TrialJobMetric
 } from '../../common/trainingService';
 import {
     delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus,

@@ -204,7 +204,7 @@ class RemoteMachineTrainingService implements TrainingService {
      * @param _trialJobId ID of trial job
      * @param _logType 'TRIAL_LOG' | 'TRIAL_STDERR'
      */
-    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
+    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
         throw new MethodNotImplementedError();
     }
...
@@ -6,7 +6,7 @@
 import { getLogger, Logger } from '../../common/log';
 import { MethodNotImplementedError } from '../../common/errors';
 import { ExperimentConfig, RemoteConfig, OpenpaiConfig } from '../../common/experimentConfig';
-import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService';
+import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService';
 import { delay } from '../../common/utils';
 import { PAITrainingService } from '../pai/paiTrainingService';
 import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';

@@ -52,7 +52,7 @@ class RouterTrainingService implements TrainingService {
         return await this.internalTrainingService.getTrialJob(trialJobId);
     }

-    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
+    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
         throw new MethodNotImplementedError();
     }
...
@@ -13,7 +13,7 @@ import * as component from '../../common/component';
 import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common/errors';
 import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
 import { getLogger, Logger } from '../../common/log';
-import { TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus, LogType } from '../../common/trainingService';
+import { TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus } from '../../common/trainingService';
 import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from '../../common/utils';
 import { ExperimentConfig, SharedStorageConfig } from '../../common/experimentConfig';
 import { GPU_INFO, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, REPORT_METRIC_DATA, SEND_TRIAL_JOB_PARAMETER, STDOUT, TRIAL_END, VERSION_CHECK } from '../../core/commands';

@@ -157,7 +157,7 @@ class TrialDispatcher implements TrainingService {
         return trial;
     }

-    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
+    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
         throw new MethodNotImplementedError();
     }
...
@@ -100,8 +100,8 @@ describe('Unit Test for LocalTrainingService', () => {
         fs.mkdirSync(jobDetail.workingDirectory)
         fs.writeFileSync(path.join(jobDetail.workingDirectory, 'trial.log'), 'trial log')
         fs.writeFileSync(path.join(jobDetail.workingDirectory, 'stderr'), 'trial stderr')
-        chai.expect(await localTrainingService.getTrialLog(jobDetail.id, 'TRIAL_LOG')).to.be.equals('trial log');
-        chai.expect(await localTrainingService.getTrialLog(jobDetail.id, 'TRIAL_ERROR')).to.be.equals('trial stderr');
+        chai.expect(await localTrainingService.getTrialFile(jobDetail.id, 'trial.log')).to.be.equals('trial log');
+        chai.expect(await localTrainingService.getTrialFile(jobDetail.id, 'stderr')).to.be.equals('trial stderr');
         fs.unlinkSync(path.join(jobDetail.workingDirectory, 'trial.log'))
         fs.unlinkSync(path.join(jobDetail.workingDirectory, 'stderr'))
         fs.rmdirSync(jobDetail.workingDirectory)
...