Unverified Commit 1f28d136 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix pai mode config (#3269)

parent 969c4c2f
......@@ -18,6 +18,15 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
data = config.json()
ts = data.pop('trainingService')
data['trial'] = {
'command': data.pop('trialCommand'),
'codeDir': data.pop('trialCodeDirectory'),
}
if 'trialGpuNumber' in data:
data['trial']['gpuNum'] = data.pop('trialGpuNumber')
if isinstance(ts, list):
hybrid_names = []
for conf in ts:
......@@ -70,14 +79,6 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
if tuner_gpu_indices is not None:
data['tuner']['gpuIndicies'] = tuner_gpu_indices
data['trial'] = {
'command': data.pop('trialCommand'),
'codeDir': data.pop('trialCodeDirectory'),
}
if 'trialGpuNumber' in data:
data['trial']['gpuNum'] = data.pop('trialGpuNumber')
return data
def _handle_training_service(ts, data):
......@@ -113,6 +114,9 @@ def _handle_training_service(ts, data):
data['trial']['image'] = ts['dockerImage']
data['trial']['nniManagerNFSMountPath'] = ts['localStorageMountPoint']
data['trial']['containerNFSMountPath'] = ts['containerStorageMountPoint']
data['trial']['paiStorageConfigName'] = ts['storageConfigName']
data['trial']['cpuNum'] = ts['trialCpuNumber']
data['trial']['memoryMB'] = ts['trialMemorySize']
data['paiConfig'] = {
'userName': ts['username'],
'token': ts['token'],
......
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
from dataclasses import dataclass
from pathlib import Path
from pathlib import Path, PurePosixPath
from typing import Any, Dict, Optional
from .base import PathLike
......@@ -17,6 +17,9 @@ class OpenpaiConfig(TrainingServiceConfig):
host: str
username: str
token: str
trial_cpu_number: int
trial_memory_size: str
storage_config_name: str
docker_image: str = 'msranni/nni:latest'
local_storage_mount_point: PathLike
container_storage_mount_point: str
......@@ -34,7 +37,7 @@ class OpenpaiConfig(TrainingServiceConfig):
_validation_rules = {
'platform': lambda value: (value == 'openpai', 'cannot be modified'),
'local_storage_mount_point': lambda value: Path(value).is_dir(),
'container_storage_mount_point': lambda value: (Path(value).is_absolute(), 'is not absolute'),
'container_storage_mount_point': lambda value: (PurePosixPath(value).is_absolute(), 'is not absolute'),
'openpai_config_file': lambda value: Path(value).is_file()
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment