Unverified Commit 1bd17a02 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Visualization (#3878)

parent 4b38e64f
......@@ -182,6 +182,8 @@ Visualize the Experiment
Users can visualize their experiment in the same way as visualizing a normal hyper-parameter tuning experiment. For example, open ``localhost::8081`` in your browser, 8081 is the port that you set in ``exp.run``. Please refer to `here <../../Tutorial/WebUI.rst>`__ for details.
We support visualizing models with 3rd-party visualization engines (like `Netron <https://netron.app/>`__). This can be used by clicking ``Visualization`` in detail panel for each trial. Note that current visualization is based on `onnx <https://onnx.ai/>`__ . Built-in evaluators (e.g., Classification) will automatically export the model into a file, for your own evaluator, you need to save your file into ``$NNI_OUTPUT_DIR/model.onnx`` to make this work.
Export Top Models
-----------------
......
......@@ -24,6 +24,8 @@ The simplest way to customize a new evaluator is with functional APIs, which is
.. note:: Due to our current implementation limitation, the ``fit`` function should be put in another python file instead of putting it in the main file. This limitation will be fixed in future release.
.. note:: When using customized evaluators, if you want to visualize models, you need to export your model and save it into ``$NNI_OUTPUT_DIR/model.onnx`` in your evaluator.
With PyTorch-Lightning
----------------------
......
......@@ -8,7 +8,7 @@ import socket
from subprocess import Popen
import sys
import time
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Any
import colorama
......@@ -43,7 +43,7 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
_check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, str(config.experiment_working_directory))
config.experiment_name, proc.pid, str(config.experiment_working_directory), [])
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
return proc
......@@ -78,7 +78,7 @@ def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int,
_check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, config.experiment_working_directory)
config.experiment_name, proc.pid, config.experiment_working_directory, ['retiarii'])
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
return proc, pipe
......@@ -156,9 +156,10 @@ def _check_rest_server(port: int, retry: int = 3) -> None:
rest.get(port, '/check-status')
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None:
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str,
name: str, pid: int, logDir: str, tag: List[Any]) -> None:
experiments_config = Experiments()
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir, tag=tag)
def get_stopped_experiment_config(exp_id: str, mode: str) -> None:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import warnings
from typing import Dict, Union, Optional, List
from pathlib import Path
from typing import Dict, NoReturn, Union, Optional, List, Type
import pytorch_lightning as pl
import torch.nn as nn
......@@ -18,7 +20,13 @@ __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classificat
class LightningModule(pl.LightningModule):
def set_model(self, model):
"""
Basic wrapper of generated model.
Lightning modules used in NNI should inherit this class.
"""
def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> NoReturn:
if isinstance(model, type):
self.model = model()
else:
......@@ -112,13 +120,23 @@ class _SupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: Union[Path, str, bool, None] = None):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
self.criterion = criterion()
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
if export_onnx is None or export_onnx is True:
self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
self.export_onnx.parent.mkdir(exist_ok=True)
elif export_onnx:
self.export_onnx = Path(export_onnx)
else:
self.export_onnx = None
self._already_exported = False
def forward(self, x):
y_hat = self.model(x)
return y_hat
......@@ -135,6 +153,11 @@ class _SupervisedLearningModule(LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
if not self._already_exported:
self.to_onnx(self.export_onnx, x, export_params=True)
self._already_exported = True
self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
for name, metric in self.metrics.items():
self.log('val_' + name, metric(y_hat, y), prog_bar=True)
......@@ -152,8 +175,7 @@ class _SupervisedLearningModule(LightningModule):
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())
def teardown(self, stage):
if stage == 'fit':
def on_fit_end(self):
nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self):
......@@ -175,9 +197,11 @@ class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)
class Classification(Lightning):
......@@ -200,6 +224,8 @@ class Classification(Lightning):
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, model will be exported to ``model.onnx`` before training starts. default true
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
......@@ -211,9 +237,10 @@ class Classification(Lightning):
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
**trainer_kwargs):
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
......@@ -223,9 +250,11 @@ class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)
class Regression(Lightning):
......@@ -248,6 +277,8 @@ class Regression(Lightning):
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, model will be exported to ``model.onnx`` before training starts. default: true
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
......@@ -259,8 +290,9 @@ class Regression(Lightning):
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
**trainer_kwargs):
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
......@@ -156,7 +156,7 @@ def mock_get_latest_metric_data():
def mock_get_trial_log():
responses.add(
responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-log/:id/:type',
responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-file/:id/:filename',
json={"status":"RUNNING","errors":[]},
status=200,
content_type='application/json',
......
......@@ -4,7 +4,7 @@
'use strict';
import { MetricDataRecord, MetricType, TrialJobInfo } from './datastore';
import { TrialJobStatus, LogType } from './trainingService';
import { TrialJobStatus } from './trainingService';
import { ExperimentConfig } from './experimentConfig';
type ProfileUpdateType = 'TRIAL_CONCURRENCY' | 'MAX_EXEC_DURATION' | 'SEARCH_SPACE' | 'MAX_TRIAL_NUM';
......@@ -59,7 +59,7 @@ abstract class Manager {
public abstract getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]>;
public abstract getLatestMetricData(): Promise<MetricDataRecord[]>;
public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>;
public abstract getStatus(): NNIManagerStatus;
......
......@@ -8,8 +8,6 @@
*/
type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED' | 'EARLY_STOPPED';
type LogType = 'TRIAL_LOG' | 'TRIAL_STDOUT' | 'TRIAL_ERROR';
interface TrainingServiceMetadata {
readonly key: string;
readonly value: string;
......@@ -81,7 +79,7 @@ abstract class TrainingService {
public abstract submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail>;
public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail>;
public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>;
public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
public abstract setClusterMetadata(key: string, value: string): Promise<void>;
public abstract getClusterMetadata(key: string): Promise<string>;
public abstract getTrialOutputLocalPath(trialJobId: string): Promise<string>;
......@@ -103,5 +101,5 @@ class NNIManagerIpConfig {
export {
TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
NNIManagerIpConfig, LogType
NNIManagerIpConfig
};
......@@ -19,7 +19,7 @@ import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/exp
import { ExperimentManager } from '../common/experimentManager';
import { TensorboardManager } from '../common/tensorboardManager';
import {
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../common/trainingService';
import { delay, getCheckpointDir, getExperimentRootDir, getLogDir, getMsgDispatcherCommand, mkDirP, getTunerProc, getLogLevel, isAlive, killPid } from '../common/utils';
import {
......@@ -402,8 +402,8 @@ class NNIManager implements Manager {
// FIXME: unit test
}
public async getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
return this.trainingService.getTrialLog(trialJobId, logType);
public async getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string> {
return this.trainingService.getTrialFile(trialJobId, fileName);
}
public getExperimentProfile(): Promise<ExperimentProfile> {
......
......@@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred';
import { Provider } from 'typescript-ioc';
import { MethodNotImplementedError } from '../../common/errors';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService';
const testTrainingServiceProvider: Provider = {
get: () => { return new MockedTrainingService(); }
......@@ -63,7 +63,7 @@ class MockedTrainingService extends TrainingService {
return deferred.promise;
}
public getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
throw new MethodNotImplementedError();
}
......
......@@ -15,6 +15,7 @@
"child-process-promise": "^2.2.1",
"express": "^4.17.1",
"express-joi-validator": "^2.0.1",
"http-proxy": "^1.18.1",
"ignore": "^5.1.8",
"js-base64": "^3.6.1",
"kubernetes-client": "^6.12.1",
......@@ -37,6 +38,7 @@
"@types/chai-as-promised": "^7.1.0",
"@types/express": "^4.17.2",
"@types/glob": "^7.1.3",
"@types/http-proxy": "^1.17.7",
"@types/js-base64": "^3.3.1",
"@types/js-yaml": "^4.0.1",
"@types/lockfile": "^1.0.0",
......
......@@ -5,6 +5,7 @@
import * as bodyParser from 'body-parser';
import * as express from 'express';
import * as httpProxy from 'http-proxy';
import * as path from 'path';
import * as component from '../common/component';
import { RestServer } from '../common/restServer'
......@@ -21,6 +22,7 @@ import { getAPIRootUrl } from '../common/experimentStartupInfo';
@component.Singleton
export class NNIRestServer extends RestServer {
private readonly LOGS_ROOT_URL: string = '/logs';
protected netronProxy: any = null;
protected API_ROOT_URL: string = '/api/v1/nni';
/**
......@@ -29,6 +31,7 @@ export class NNIRestServer extends RestServer {
constructor() {
super();
this.API_ROOT_URL = getAPIRootUrl();
this.netronProxy = httpProxy.createProxyServer();
}
/**
......@@ -39,6 +42,14 @@ export class NNIRestServer extends RestServer {
this.app.use(bodyParser.json({limit: '50mb'}));
this.app.use(this.API_ROOT_URL, createRestHandler(this));
this.app.use(this.LOGS_ROOT_URL, express.static(getLogDir()));
this.app.all('/netron/*', (req: express.Request, res: express.Response) => {
delete req.headers.host;
req.url = req.url.replace('/netron', '/');
this.netronProxy.web(req, res, {
changeOrigin: true,
target: 'https://netron.app'
});
});
this.app.get('*', (req: express.Request, res: express.Response) => {
res.sendFile(path.resolve('static/index.html'));
});
......
......@@ -19,7 +19,7 @@ import { NNIRestServer } from './nniRestServer';
import { getVersion } from '../common/utils';
import { MetricType } from '../common/datastore';
import { ProfileUpdateType } from '../common/manager';
import { LogType, TrialJobStatus } from '../common/trainingService';
import { TrialJobStatus } from '../common/trainingService';
const expressJoi = require('express-joi-validator');
......@@ -53,6 +53,7 @@ class NNIRestHandler {
this.version(router);
this.checkStatus(router);
this.getExperimentProfile(router);
this.getExperimentMetadata(router);
this.updateExperimentProfile(router);
this.importData(router);
this.getImportedData(router);
......@@ -66,7 +67,7 @@ class NNIRestHandler {
this.getMetricData(router);
this.getMetricDataByRange(router);
this.getLatestMetricData(router);
this.getTrialLog(router);
this.getTrialFile(router);
this.exportData(router);
this.getExperimentsInfo(router);
this.startTensorboardTask(router);
......@@ -296,13 +297,20 @@ class NNIRestHandler {
});
}
private getTrialLog(router: Router): void {
router.get('/trial-log/:id/:type', async(req: Request, res: Response) => {
this.nniManager.getTrialLog(req.params.id, req.params.type as LogType).then((log: string) => {
if (log === '') {
log = 'No logs available.'
private getTrialFile(router: Router): void {
router.get('/trial-file/:id/:filename', async(req: Request, res: Response) => {
let encoding: string | null = null;
const filename = req.params.filename;
if (!filename.includes('.') || filename.match(/.*\.(txt|log)/g)) {
encoding = 'utf8';
}
res.send(log);
this.nniManager.getTrialFile(req.params.id, filename).then((content: Buffer | string) => {
if (content instanceof Buffer) {
res.header('Content-Type', 'application/octet-stream');
} else if (content === '') {
content = `${filename} is empty.`;
}
res.send(content);
}).catch((err: Error) => {
this.handleError(err, res);
});
......@@ -319,6 +327,24 @@ class NNIRestHandler {
});
}
private getExperimentMetadata(router: Router): void {
router.get('/experiment-metadata', (req: Request, res: Response) => {
Promise.all([
this.nniManager.getExperimentProfile(),
this.experimentsManager.getExperimentsInfo()
]).then(([profile, experimentInfo]) => {
for (const info of experimentInfo as any) {
if (info.id === profile.id) {
res.send(info);
break;
}
}
}).catch((err: Error) => {
this.handleError(err, res);
});
});
}
private getExperimentsInfo(router: Router): void {
router.get('/experiments-info', (req: Request, res: Response) => {
this.experimentsManager.getExperimentsInfo().then((experimentInfo: JSON) => {
......
......@@ -13,7 +13,7 @@ import {
TrialJobStatistics, NNIManagerStatus
} from '../../common/manager';
import {
TrialJobApplicationForm, TrialJobDetail, TrialJobStatus, LogType
TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
} from '../../common/trainingService';
export const testManagerProvider: Provider = {
......@@ -129,7 +129,7 @@ export class MockedNNIManager extends Manager {
public getLatestMetricData(): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError();
}
public getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
throw new MethodNotImplementedError();
}
public getExperimentProfile(): Promise<ExperimentProfile> {
......
......@@ -14,7 +14,7 @@ import {getExperimentId} from '../../common/experimentStartupInfo';
import {getLogger, Logger} from '../../common/log';
import {MethodNotImplementedError} from '../../common/errors';
import {
NNIManagerIpConfig, TrialJobDetail, TrialJobMetric, LogType
NNIManagerIpConfig, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import {delay, getExperimentRootDir, getIPV4Address, getJobCancelStatus, getVersion, uniqueString} from '../../common/utils';
import {AzureStorageClientUtility} from './azureStorageClientUtils';
......@@ -99,7 +99,7 @@ abstract class KubernetesTrainingService {
return Promise.resolve(kubernetesTrialJob);
}
public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
public async getTrialFile(_trialJobId: string, _filename: string): Promise<string | Buffer> {
throw new MethodNotImplementedError();
}
......
......@@ -13,7 +13,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType
TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService';
import {
delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString
......@@ -170,18 +170,20 @@ class LocalTrainingService implements TrainingService {
return trialJob;
}
public async getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
let logPath: string;
if (logType === 'TRIAL_LOG') {
logPath = path.join(this.rootDir, 'trials', trialJobId, 'trial.log');
} else if (logType === 'TRIAL_STDOUT'){
logPath = path.join(this.rootDir, 'trials', trialJobId, 'stdout');
} else if (logType === 'TRIAL_ERROR') {
logPath = path.join(this.rootDir, 'trials', trialJobId, 'stderr');
} else {
throw new Error('unexpected log type');
public async getTrialFile(trialJobId: string, fileName: string): Promise<string | Buffer> {
// check filename here for security
if (!['trial.log', 'stderr', 'model.onnx', 'stdout'].includes(fileName)) {
throw new Error(`File unaccessible: ${fileName}`);
}
let encoding: string | null = null;
if (!fileName.includes('.') || fileName.match(/.*\.(txt|log)/g)) {
encoding = 'utf8';
}
const logPath = path.join(this.rootDir, 'trials', trialJobId, fileName);
if (!fs.existsSync(logPath)) {
throw new Error(`File not found: ${logPath}`);
}
return fs.promises.readFile(logPath, 'utf8');
return fs.promises.readFile(logPath, {encoding: encoding as any});
}
public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
......
......@@ -15,7 +15,7 @@ import { getLogger, Logger } from '../../common/log';
import { MethodNotImplementedError } from '../../common/errors';
import {
HyperParameters, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import { delay } from '../../common/utils';
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../common/experimentConfig';
......@@ -127,7 +127,7 @@ class PAITrainingService implements TrainingService {
return jobs;
}
public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
throw new MethodNotImplementedError();
}
......
......@@ -16,7 +16,7 @@ import { getLogger, Logger } from '../../common/log';
import { ObservableTimer } from '../../common/observableTimer';
import {
HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, LogType
TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import {
delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus,
......@@ -204,7 +204,7 @@ class RemoteMachineTrainingService implements TrainingService {
* @param _trialJobId ID of trial job
* @param _logType 'TRIAL_LOG' | 'TRIAL_STDERR'
*/
public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
throw new MethodNotImplementedError();
}
......
......@@ -6,7 +6,7 @@
import { getLogger, Logger } from '../../common/log';
import { MethodNotImplementedError } from '../../common/errors';
import { ExperimentConfig, RemoteConfig, OpenpaiConfig } from '../../common/experimentConfig';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService';
import { delay } from '../../common/utils';
import { PAITrainingService } from '../pai/paiTrainingService';
import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';
......@@ -52,7 +52,7 @@ class RouterTrainingService implements TrainingService {
return await this.internalTrainingService.getTrialJob(trialJobId);
}
public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
throw new MethodNotImplementedError();
}
......
......@@ -13,7 +13,7 @@ import * as component from '../../common/component';
import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common/errors';
import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import { TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus, LogType } from '../../common/trainingService';
import { TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus } from '../../common/trainingService';
import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from '../../common/utils';
import { ExperimentConfig, SharedStorageConfig } from '../../common/experimentConfig';
import { GPU_INFO, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, REPORT_METRIC_DATA, SEND_TRIAL_JOB_PARAMETER, STDOUT, TRIAL_END, VERSION_CHECK } from '../../core/commands';
......@@ -157,7 +157,7 @@ class TrialDispatcher implements TrainingService {
return trial;
}
public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
throw new MethodNotImplementedError();
}
......
......@@ -100,8 +100,8 @@ describe('Unit Test for LocalTrainingService', () => {
fs.mkdirSync(jobDetail.workingDirectory)
fs.writeFileSync(path.join(jobDetail.workingDirectory, 'trial.log'), 'trial log')
fs.writeFileSync(path.join(jobDetail.workingDirectory, 'stderr'), 'trial stderr')
chai.expect(await localTrainingService.getTrialLog(jobDetail.id, 'TRIAL_LOG')).to.be.equals('trial log');
chai.expect(await localTrainingService.getTrialLog(jobDetail.id, 'TRIAL_ERROR')).to.be.equals('trial stderr');
chai.expect(await localTrainingService.getTrialFile(jobDetail.id, 'trial.log')).to.be.equals('trial log');
chai.expect(await localTrainingService.getTrialFile(jobDetail.id, 'stderr')).to.be.equals('trial stderr');
fs.unlinkSync(path.join(jobDetail.workingDirectory, 'trial.log'))
fs.unlinkSync(path.join(jobDetail.workingDirectory, 'stderr'))
fs.rmdirSync(jobDetail.workingDirectory)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment