Unverified Commit 1f28d136 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix pai mode config (#3269)

parent 969c4c2f
...@@ -18,6 +18,15 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str, ...@@ -18,6 +18,15 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
data = config.json() data = config.json()
ts = data.pop('trainingService') ts = data.pop('trainingService')
data['trial'] = {
'command': data.pop('trialCommand'),
'codeDir': data.pop('trialCodeDirectory'),
}
if 'trialGpuNumber' in data:
data['trial']['gpuNum'] = data.pop('trialGpuNumber')
if isinstance(ts, list): if isinstance(ts, list):
hybrid_names = [] hybrid_names = []
for conf in ts: for conf in ts:
...@@ -70,14 +79,6 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str, ...@@ -70,14 +79,6 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
if tuner_gpu_indices is not None: if tuner_gpu_indices is not None:
data['tuner']['gpuIndicies'] = tuner_gpu_indices data['tuner']['gpuIndicies'] = tuner_gpu_indices
data['trial'] = {
'command': data.pop('trialCommand'),
'codeDir': data.pop('trialCodeDirectory'),
}
if 'trialGpuNumber' in data:
data['trial']['gpuNum'] = data.pop('trialGpuNumber')
return data return data
def _handle_training_service(ts, data): def _handle_training_service(ts, data):
...@@ -113,6 +114,9 @@ def _handle_training_service(ts, data): ...@@ -113,6 +114,9 @@ def _handle_training_service(ts, data):
data['trial']['image'] = ts['dockerImage'] data['trial']['image'] = ts['dockerImage']
data['trial']['nniManagerNFSMountPath'] = ts['localStorageMountPoint'] data['trial']['nniManagerNFSMountPath'] = ts['localStorageMountPoint']
data['trial']['containerNFSMountPath'] = ts['containerStorageMountPoint'] data['trial']['containerNFSMountPath'] = ts['containerStorageMountPoint']
data['trial']['paiStorageConfigName'] = ts['storageConfigName']
data['trial']['cpuNum'] = ts['trialCpuNumber']
data['trial']['memoryMB'] = ts['trialMemorySize']
data['paiConfig'] = { data['paiConfig'] = {
'userName': ts['username'], 'userName': ts['username'],
'token': ts['token'], 'token': ts['token'],
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path, PurePosixPath
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .base import PathLike from .base import PathLike
...@@ -17,6 +17,9 @@ class OpenpaiConfig(TrainingServiceConfig): ...@@ -17,6 +17,9 @@ class OpenpaiConfig(TrainingServiceConfig):
host: str host: str
username: str username: str
token: str token: str
trial_cpu_number: int
trial_memory_size: str
storage_config_name: str
docker_image: str = 'msranni/nni:latest' docker_image: str = 'msranni/nni:latest'
local_storage_mount_point: PathLike local_storage_mount_point: PathLike
container_storage_mount_point: str container_storage_mount_point: str
...@@ -34,7 +37,7 @@ class OpenpaiConfig(TrainingServiceConfig): ...@@ -34,7 +37,7 @@ class OpenpaiConfig(TrainingServiceConfig):
_validation_rules = { _validation_rules = {
'platform': lambda value: (value == 'openpai', 'cannot be modified'), 'platform': lambda value: (value == 'openpai', 'cannot be modified'),
'local_storage_mount_point': lambda value: Path(value).is_dir(), 'local_storage_mount_point': lambda value: Path(value).is_dir(),
'container_storage_mount_point': lambda value: (Path(value).is_absolute(), 'is not absolute'), 'container_storage_mount_point': lambda value: (PurePosixPath(value).is_absolute(), 'is not absolute'),
'openpai_config_file': lambda value: Path(value).is_file() 'openpai_config_file': lambda value: Path(value).is_file()
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment