Unverified Commit f5b89bb6 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

Merge pull request #4776 from microsoft/v2.7

parents 7aa44612 1546962f
...@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Dict, Iterator ...@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Dict, Iterator
import torch.nn as nn import torch.nn as nn
from nni.common.serializer import is_traceable from nni.common.serializer import is_traceable, is_wrapped_with_trace
from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node, Evaluator from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node, Evaluator
from nni.retiarii.mutator import Mutator from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit, is_model_wrapped from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
...@@ -361,7 +361,7 @@ class EvaluatorValueChoiceMutator(Mutator): ...@@ -361,7 +361,7 @@ class EvaluatorValueChoiceMutator(Mutator):
# we only need one such mutator for one model/evaluator # we only need one such mutator for one model/evaluator
def _mutate_traceable_object(self, obj: Any, value_choice_decisions: Dict[str, Any]) -> Any: def _mutate_traceable_object(self, obj: Any, value_choice_decisions: Dict[str, Any]) -> Any:
if not is_traceable(obj): if not _is_traceable_object(obj):
return obj return obj
updates = {} updates = {}
...@@ -400,7 +400,7 @@ class EvaluatorValueChoiceMutator(Mutator): ...@@ -400,7 +400,7 @@ class EvaluatorValueChoiceMutator(Mutator):
def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]: def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]:
# take all the value choice in the kwargs of evaluator into a list # take all the value choice in the kwargs of evaluator into a list
# `existing_mutators` can be mutators generated from `model` # `existing_mutators` can be mutators generated from `model`
if not is_traceable(evaluator): if not _is_traceable_object(evaluator):
return [] return []
mutator_candidates = {} mutator_candidates = {}
for param in _expand_nested_trace_kwargs(evaluator): for param in _expand_nested_trace_kwargs(evaluator):
...@@ -464,9 +464,12 @@ def _expand_nested_trace_kwargs(obj: Any) -> Iterator[Any]: ...@@ -464,9 +464,12 @@ def _expand_nested_trace_kwargs(obj: Any) -> Iterator[Any]:
# Get items from `trace_kwargs`. # Get items from `trace_kwargs`.
# If some item is traceable itself, get items recursively. # If some item is traceable itself, get items recursively.
if not is_traceable(obj): if _is_traceable_object(obj):
return for param in obj.trace_kwargs.values():
yield param
yield from _expand_nested_trace_kwargs(param)
for param in obj.trace_kwargs.values():
yield param def _is_traceable_object(obj: Any) -> bool:
yield from _expand_nested_trace_kwargs(param) # Is it a traceable "object" (not class)?
return is_traceable(obj) and not is_wrapped_with_trace(obj)
...@@ -280,7 +280,7 @@ class NasBench101Cell(Mutable): ...@@ -280,7 +280,7 @@ class NasBench101Cell(Mutable):
Warnings Warnings
-------- --------
:class:`NasBench101Cell` is not supported in :ref:`graph-based execution engine <graph-based-exeuction-engine>`. :class:`NasBench101Cell` is not supported in :ref:`graph-based execution engine <graph-based-execution-engine>`.
""" """
@staticmethod @staticmethod
......
...@@ -11,7 +11,7 @@ from pathlib import Path ...@@ -11,7 +11,7 @@ from pathlib import Path
from nni.common.hpo_utils import ParameterSpec from nni.common.hpo_utils import ParameterSpec
__all__ = ['NoContextError', 'ContextStack', 'ModelNamespace'] __all__ = ['NoContextError', 'ContextStack', 'ModelNamespace', 'original_state_dict_hooks']
def import_(target: str, allow_none: bool = False) -> Any: def import_(target: str, allow_none: bool = False) -> Any:
......
...@@ -359,6 +359,7 @@ kubeflow_config_schema = { ...@@ -359,6 +359,7 @@ kubeflow_config_schema = {
'path': setType('path', str) 'path': setType('path', str)
}, },
Optional('reuse'): setType('reuse', bool), Optional('reuse'): setType('reuse', bool),
Optional('namespace'): setType('namespace', str),
}, { }, {
'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'), 'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
'apiVersion': setType('apiVersion', str), 'apiVersion': setType('apiVersion', str),
...@@ -377,6 +378,7 @@ kubeflow_config_schema = { ...@@ -377,6 +378,7 @@ kubeflow_config_schema = {
}, },
Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999), Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999),
Optional('reuse'): setType('reuse', bool), Optional('reuse'): setType('reuse', bool),
Optional('namespace'): setType('namespace', str),
}) })
} }
......
...@@ -42,6 +42,11 @@ stages: ...@@ -42,6 +42,11 @@ stages:
python tools/chineselink.py check python tools/chineselink.py check
displayName: Translation up-to-date displayName: Translation up-to-date
- script: |
cd docs
make -e SPHINXOPTS="-W -T -b linkcheck -q --keep-going" html
displayName: External links integrity check
- job: python - job: python
pool: pool:
vmImage: ubuntu-latest vmImage: ubuntu-latest
......
{ {
"lr":{"_type":"choice", "_value":[0.1, 0.01, 0.001, 0.0001]}, "lr":{"_type":"choice", "_value":[0.1, 0.01, 0.001, 0.0001]},
"optimizer":{"_type":"choice", "_value":["SGD", "Adadelta", "Adagrad", "Adam", "Adamax"]}, "optimizer":{"_type":"choice", "_value":["SGD", "Adadelta", "Adagrad", "Adam", "Adamax"]},
"model":{"_type":"choice", "_value":["vgg", "resnet18"]} "model":{"_type":"choice", "_value":["vgg"]}
} }
...@@ -18,6 +18,7 @@ kubeflow: ...@@ -18,6 +18,7 @@ kubeflow:
azureStorage: azureStorage:
accountName: accountName:
azureShare: azureShare:
namespace: kubeflow
trial: trial:
worker: worker:
replicas: 1 replicas: 1
...@@ -35,7 +36,7 @@ frameworkcontroller: ...@@ -35,7 +36,7 @@ frameworkcontroller:
maxTrialNum: 2 maxTrialNum: 2
trialConcurrency: 2 trialConcurrency: 2
frameworkcontrollerConfig: frameworkcontrollerConfig:
serviceAccountName: frameworkbarrier serviceAccountName: frameworkcontroller
storage: azureStorage storage: azureStorage
keyVault: keyVault:
vaultName: vaultName:
...@@ -43,6 +44,7 @@ frameworkcontroller: ...@@ -43,6 +44,7 @@ frameworkcontroller:
azureStorage: azureStorage:
accountName: accountName:
azureShare: azureShare:
namespace: kubeflow
trial: trial:
taskRoles: taskRoles:
- name: worker - name: worker
......
...@@ -20,6 +20,7 @@ kubeflow: ...@@ -20,6 +20,7 @@ kubeflow:
trainingService: trainingService:
reuseMode: true reuseMode: true
platform: kubeflow platform: kubeflow
namespace: kubeflow
worker: worker:
command: command:
code_directory: code_directory:
...@@ -44,6 +45,7 @@ frameworkcontroller: ...@@ -44,6 +45,7 @@ frameworkcontroller:
trainingService: trainingService:
reuseMode: true reuseMode: true
platform: frameworkcontroller platform: frameworkcontroller
namespace: kubeflow
serviceAccountName: frameworkcontroller serviceAccountName: frameworkcontroller
taskRoles: taskRoles:
- name: worker - name: worker
......
...@@ -122,12 +122,18 @@ def print_file_content(filepath): ...@@ -122,12 +122,18 @@ def print_file_content(filepath):
print(content, flush=True) print(content, flush=True)
def print_trial_job_log(training_service, trial_jobs_url): def print_trial_job_log(training_service, trial_jobs_url):
trial_jobs = get_trial_jobs(trial_jobs_url) trial_log_root = os.path.join(get_experiment_dir(EXPERIMENT_URL), 'trials')
for trial_job in trial_jobs: if not os.path.exists(trial_log_root):
trial_log_dir = os.path.join(get_experiment_dir(EXPERIMENT_URL), 'trials', trial_job['trialJobId']) print('trial log folder does not exist: {}'.format(trial_log_root), flush=True)
return
folders = os.listdir(trial_log_root)
for name in folders:
trial_log_dir = os.path.join(trial_log_root, name)
log_files = ['stderr', 'trial.log'] if training_service == 'local' else ['stdout_log_collection.log'] log_files = ['stderr', 'trial.log'] if training_service == 'local' else ['stdout_log_collection.log']
for log_file in log_files: for log_file in log_files:
print_file_content(os.path.join(trial_log_dir, log_file)) log_file_path = os.path.join(trial_log_dir, log_file)
if os.path.exists(log_file_path):
print_file_content(log_file_path)
def print_experiment_log(experiment_id): def print_experiment_log(experiment_id):
log_dir = get_nni_log_dir(experiment_id=experiment_id) log_dir = get_nni_log_dir(experiment_id=experiment_id)
......
...@@ -1229,6 +1229,11 @@ class Shared(unittest.TestCase): ...@@ -1229,6 +1229,11 @@ class Shared(unittest.TestCase):
assert len(set(values)) == 3 assert len(set(values)) == 3
@unittest.skipIf(pytorch_lightning.__version__ < '1.0', 'Legacy PyTorch-lightning not supported')
def test_valuechoice_classification(self):
evaluator = pl.Classification(criterion=nn.CrossEntropyLoss)
process_evaluator_mutations(evaluator, [])
def test_retiarii_nn_import(self): def test_retiarii_nn_import(self):
dummy = torch.zeros(1, 16, 32, 24) dummy = torch.zeros(1, 16, 32, 24)
nn.init.uniform_(dummy) nn.init.uniform_(dummy)
......
...@@ -132,6 +132,7 @@ export interface KubeflowConfig extends TrainingServiceConfig { ...@@ -132,6 +132,7 @@ export interface KubeflowConfig extends TrainingServiceConfig {
master?: KubeflowRoleConfig; master?: KubeflowRoleConfig;
reuseMode: boolean; reuseMode: boolean;
maxTrialNumberPerGpu?: number; maxTrialNumberPerGpu?: number;
namespace?: string;
} }
export interface FrameworkControllerTaskRoleConfig { export interface FrameworkControllerTaskRoleConfig {
...@@ -156,7 +157,7 @@ export interface FrameworkControllerConfig extends TrainingServiceConfig { ...@@ -156,7 +157,7 @@ export interface FrameworkControllerConfig extends TrainingServiceConfig {
taskRoles: FrameworkControllerTaskRoleConfig[]; taskRoles: FrameworkControllerTaskRoleConfig[];
reuseMode: boolean; reuseMode: boolean;
maxTrialNumberPerGpu?: number; maxTrialNumberPerGpu?: number;
namespace?: 'default'; namespace?: string;
apiVersion?: string; apiVersion?: string;
} }
......
...@@ -52,7 +52,8 @@ if __name__ == "__main__": ...@@ -52,7 +52,8 @@ if __name__ == "__main__":
print('stop_result:failed') print('stop_result:failed')
exit(0) exit(0)
loop_count += 1 loop_count += 1
time.sleep(500) time.sleep(5)
status = run.get_status()
print('stop_result:success') print('stop_result:success')
exit(0) exit(0)
elif line == 'receive': elif line == 'receive':
......
...@@ -11,7 +11,7 @@ class AdlClientV1 extends KubernetesCRDClient { ...@@ -11,7 +11,7 @@ class AdlClientV1 extends KubernetesCRDClient {
/** /**
* constructor, to initialize adl CRD definition * constructor, to initialize adl CRD definition
*/ */
protected readonly namespace: string; public readonly namespace: string;
public constructor(namespace: string) { public constructor(namespace: string) {
super(); super();
......
...@@ -118,7 +118,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -118,7 +118,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
} else { } else {
configTaskRoles = this.parseCustomTaskRoles(this.fcTemplate.spec.taskRoles) configTaskRoles = this.parseCustomTaskRoles(this.fcTemplate.spec.taskRoles)
} }
const namespace = this.fcClusterConfig.namespace ? this.fcClusterConfig.namespace : "default"; const namespace = this.fcClusterConfig.namespace ?? "default";
this.genericK8sClient.setNamespace = namespace; this.genericK8sClient.setNamespace = namespace;
if (this.kubernetesRestServerPort === undefined) { if (this.kubernetesRestServerPort === undefined) {
...@@ -134,7 +134,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -134,7 +134,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
const trialJobId: string = uniqueString(5); const trialJobId: string = uniqueString(5);
// Set trial's NFS working folder // Set trial's NFS working folder
const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId); const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId);
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials', trialJobId);
let frameworkcontrollerJobName: string = `nniexp${this.experimentId}trial${trialJobId}`.toLowerCase(); let frameworkcontrollerJobName: string = `nniexp${this.experimentId}trial${trialJobId}`.toLowerCase();
let frameworkcontrollerJobConfig: any; let frameworkcontrollerJobConfig: any;
...@@ -204,6 +204,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -204,6 +204,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
let namespace: string | undefined; let namespace: string | undefined;
this.fcClusterConfig = FrameworkControllerClusterConfigFactory this.fcClusterConfig = FrameworkControllerClusterConfigFactory
.generateFrameworkControllerClusterConfig(frameworkcontrollerClusterJsonObject); .generateFrameworkControllerClusterConfig(frameworkcontrollerClusterJsonObject);
this.genericK8sClient.setNamespace = this.fcClusterConfig.namespace ?? "default";
if (this.fcClusterConfig.storageType === 'azureStorage') { if (this.fcClusterConfig.storageType === 'azureStorage') {
const azureFrameworkControllerClusterConfig: FrameworkControllerClusterConfigAzure = const azureFrameworkControllerClusterConfig: FrameworkControllerClusterConfigAzure =
<FrameworkControllerClusterConfigAzure>this.fcClusterConfig; <FrameworkControllerClusterConfigAzure>this.fcClusterConfig;
...@@ -346,8 +347,8 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -346,8 +347,8 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
for (const taskRole of configTaskRoles) { for (const taskRole of configTaskRoles) {
const runScriptContent: string = const runScriptContent: string =
await this.generateRunScript('frameworkcontroller', trialJobId, trialWorkingFolder, await this.generateRunScript('frameworkcontroller', trialJobId, trialWorkingFolder,
this.generateCommandScript(configTaskRoles, taskRole.command), form.sequenceId.toString(), this.generateCommandScript(configTaskRoles, taskRole.command),
taskRole.name, taskRole.gpuNum ? taskRole.gpuNum : 0); form.sequenceId.toString(), taskRole.name, taskRole.gpuNum ? taskRole.gpuNum : 0);
await fs.promises.writeFile(path.join(trialLocalTempFolder, `run_${taskRole.name}.sh`), runScriptContent, {encoding: 'utf8'}); await fs.promises.writeFile(path.join(trialLocalTempFolder, `run_${taskRole.name}.sh`), runScriptContent, {encoding: 'utf8'});
} }
...@@ -439,7 +440,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -439,7 +440,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
kind: 'Framework', kind: 'Framework',
metadata: { metadata: {
name: frameworkcontrollerJobName, name: frameworkcontrollerJobName,
namespace: this.fcClusterConfig.namespace ? this.fcClusterConfig.namespace : "default", namespace: this.fcClusterConfig.namespace ?? "default",
labels: { labels: {
app: this.NNI_KUBERNETES_TRIAL_LABEL, app: this.NNI_KUBERNETES_TRIAL_LABEL,
expId: getExperimentId(), expId: getExperimentId(),
......
...@@ -17,7 +17,7 @@ class TFOperatorClientV1Alpha2 extends KubernetesCRDClient { ...@@ -17,7 +17,7 @@ class TFOperatorClientV1Alpha2 extends KubernetesCRDClient {
} }
protected get operator(): any { protected get operator(): any {
return this.client.apis['kubeflow.org'].v1alpha2.namespaces('default').tfjobs; return this.client.apis['kubeflow.org'].v1alpha2.namespaces(this.namespace).tfjobs;
} }
public get containerName(): string { public get containerName(): string {
...@@ -36,7 +36,7 @@ class TFOperatorClientV1Beta1 extends KubernetesCRDClient { ...@@ -36,7 +36,7 @@ class TFOperatorClientV1Beta1 extends KubernetesCRDClient {
} }
protected get operator(): any { protected get operator(): any {
return this.client.apis['kubeflow.org'].v1beta1.namespaces('default').tfjobs; return this.client.apis['kubeflow.org'].v1beta1.namespaces(this.namespace).tfjobs;
} }
public get containerName(): string { public get containerName(): string {
...@@ -55,7 +55,7 @@ class TFOperatorClientV1Beta2 extends KubernetesCRDClient { ...@@ -55,7 +55,7 @@ class TFOperatorClientV1Beta2 extends KubernetesCRDClient {
} }
protected get operator(): any { protected get operator(): any {
return this.client.apis['kubeflow.org'].v1beta2.namespaces('default').tfjobs; return this.client.apis['kubeflow.org'].v1beta2.namespaces(this.namespace).tfjobs;
} }
public get containerName(): string { public get containerName(): string {
...@@ -74,7 +74,7 @@ class TFOperatorClientV1 extends KubernetesCRDClient { ...@@ -74,7 +74,7 @@ class TFOperatorClientV1 extends KubernetesCRDClient {
} }
protected get operator(): any { protected get operator(): any {
return this.client.apis['kubeflow.org'].v1.namespaces('default').tfjobs; return this.client.apis['kubeflow.org'].v1.namespaces(this.namespace).tfjobs;
} }
public get containerName(): string { public get containerName(): string {
...@@ -92,7 +92,7 @@ class PyTorchOperatorClientV1 extends KubernetesCRDClient { ...@@ -92,7 +92,7 @@ class PyTorchOperatorClientV1 extends KubernetesCRDClient {
} }
protected get operator(): any { protected get operator(): any {
return this.client.apis['kubeflow.org'].v1.namespaces('default').pytorchjobs; return this.client.apis['kubeflow.org'].v1.namespaces(this.namespace).pytorchjobs;
} }
public get containerName(): string { public get containerName(): string {
...@@ -110,7 +110,7 @@ class PyTorchOperatorClientV1Alpha2 extends KubernetesCRDClient { ...@@ -110,7 +110,7 @@ class PyTorchOperatorClientV1Alpha2 extends KubernetesCRDClient {
} }
protected get operator(): any { protected get operator(): any {
return this.client.apis['kubeflow.org'].v1alpha2.namespaces('default').pytorchjobs; return this.client.apis['kubeflow.org'].v1alpha2.namespaces(this.namespace).pytorchjobs;
} }
public get containerName(): string { public get containerName(): string {
...@@ -129,7 +129,7 @@ class PyTorchOperatorClientV1Beta1 extends KubernetesCRDClient { ...@@ -129,7 +129,7 @@ class PyTorchOperatorClientV1Beta1 extends KubernetesCRDClient {
} }
protected get operator(): any { protected get operator(): any {
return this.client.apis['kubeflow.org'].v1beta1.namespaces('default').pytorchjobs; return this.client.apis['kubeflow.org'].v1beta1.namespaces(this.namespace).pytorchjobs;
} }
public get containerName(): string { public get containerName(): string {
...@@ -148,7 +148,7 @@ class PyTorchOperatorClientV1Beta2 extends KubernetesCRDClient { ...@@ -148,7 +148,7 @@ class PyTorchOperatorClientV1Beta2 extends KubernetesCRDClient {
} }
protected get operator(): any { protected get operator(): any {
return this.client.apis['kubeflow.org'].v1beta2.namespaces('default').pytorchjobs; return this.client.apis['kubeflow.org'].v1beta2.namespaces(this.namespace).pytorchjobs;
} }
public get containerName(): string { public get containerName(): string {
......
...@@ -18,8 +18,8 @@ export type OperatorApiVersion = 'v1alpha2' | 'v1beta1' | 'v1beta2' | 'v1'; ...@@ -18,8 +18,8 @@ export type OperatorApiVersion = 'v1alpha2' | 'v1beta1' | 'v1beta2' | 'v1';
*/ */
export class KubeflowClusterConfig extends KubernetesClusterConfig { export class KubeflowClusterConfig extends KubernetesClusterConfig {
public readonly operator: KubeflowOperator; public readonly operator: KubeflowOperator;
constructor(apiVersion: string, operator: KubeflowOperator) { constructor(apiVersion: string, operator: KubeflowOperator, namespace?: string) {
super(apiVersion); super(apiVersion, undefined, namespace);
this.operator = operator; this.operator = operator;
} }
} }
...@@ -30,9 +30,10 @@ export class KubeflowClusterConfigNFS extends KubernetesClusterConfigNFS { ...@@ -30,9 +30,10 @@ export class KubeflowClusterConfigNFS extends KubernetesClusterConfigNFS {
operator: KubeflowOperator, operator: KubeflowOperator,
apiVersion: string, apiVersion: string,
nfs: NFSConfig, nfs: NFSConfig,
storage?: KubernetesStorageKind storage?: KubernetesStorageKind,
namespace?: string
) { ) {
super(apiVersion, nfs, storage); super(apiVersion, nfs, storage, namespace);
this.operator = operator; this.operator = operator;
} }
...@@ -48,7 +49,8 @@ export class KubeflowClusterConfigNFS extends KubernetesClusterConfigNFS { ...@@ -48,7 +49,8 @@ export class KubeflowClusterConfigNFS extends KubernetesClusterConfigNFS {
kubeflowClusterConfigObjectNFS.operator, kubeflowClusterConfigObjectNFS.operator,
kubeflowClusterConfigObjectNFS.apiVersion, kubeflowClusterConfigObjectNFS.apiVersion,
kubeflowClusterConfigObjectNFS.nfs, kubeflowClusterConfigObjectNFS.nfs,
kubeflowClusterConfigObjectNFS.storage kubeflowClusterConfigObjectNFS.storage,
kubeflowClusterConfigObjectNFS.namespace
); );
} }
} }
...@@ -61,9 +63,10 @@ export class KubeflowClusterConfigAzure extends KubernetesClusterConfigAzure { ...@@ -61,9 +63,10 @@ export class KubeflowClusterConfigAzure extends KubernetesClusterConfigAzure {
apiVersion: string, apiVersion: string,
keyVault: KeyVaultConfig, keyVault: KeyVaultConfig,
azureStorage: AzureStorage, azureStorage: AzureStorage,
storage?: KubernetesStorageKind storage?: KubernetesStorageKind,
namespace?: string
) { ) {
super(apiVersion, keyVault, azureStorage, storage); super(apiVersion, keyVault, azureStorage, storage, undefined, namespace);
this.operator = operator; this.operator = operator;
} }
...@@ -79,7 +82,8 @@ export class KubeflowClusterConfigAzure extends KubernetesClusterConfigAzure { ...@@ -79,7 +82,8 @@ export class KubeflowClusterConfigAzure extends KubernetesClusterConfigAzure {
kubeflowClusterConfigObjectAzure.apiVersion, kubeflowClusterConfigObjectAzure.apiVersion,
kubeflowClusterConfigObjectAzure.keyVault, kubeflowClusterConfigObjectAzure.keyVault,
kubeflowClusterConfigObjectAzure.azureStorage, kubeflowClusterConfigObjectAzure.azureStorage,
kubeflowClusterConfigObjectAzure.storage kubeflowClusterConfigObjectAzure.storage,
kubeflowClusterConfigObjectAzure.namespace
); );
} }
} }
......
...@@ -14,7 +14,7 @@ export class KubeflowJobRestServer extends KubernetesJobRestServer { ...@@ -14,7 +14,7 @@ export class KubeflowJobRestServer extends KubernetesJobRestServer {
/** /**
* constructor to provide NNIRestServer's own rest property, e.g. port * constructor to provide NNIRestServer's own rest property, e.g. port
*/ */
constructor() { constructor(kubeflowTrainingService: KubeflowTrainingService) {
super(component.get(KubeflowTrainingService)); super(kubeflowTrainingService);
} }
} }
...@@ -69,7 +69,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -69,7 +69,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
} }
if (this.kubernetesRestServerPort === undefined) { if (this.kubernetesRestServerPort === undefined) {
const restServer: KubeflowJobRestServer = component.get(KubeflowJobRestServer); const restServer: KubeflowJobRestServer = new KubeflowJobRestServer(this);
this.kubernetesRestServerPort = restServer.clusterRestServerPort; this.kubernetesRestServerPort = restServer.clusterRestServerPort;
} }
...@@ -81,7 +81,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -81,7 +81,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
const trialJobId: string = uniqueString(5); const trialJobId: string = uniqueString(5);
const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId); const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId);
const kubeflowJobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase(); const kubeflowJobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase();
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials', trialJobId);
//prepare the runscript //prepare the runscript
await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, form); await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, form);
//upload script files to storage //upload script files to storage
...@@ -120,6 +120,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -120,6 +120,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
case TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG: { case TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG: {
const kubeflowClusterJsonObject: object = JSON.parse(value); const kubeflowClusterJsonObject: object = JSON.parse(value);
this.kubeflowClusterConfig = KubeflowClusterConfigFactory.generateKubeflowClusterConfig(kubeflowClusterJsonObject); this.kubeflowClusterConfig = KubeflowClusterConfigFactory.generateKubeflowClusterConfig(kubeflowClusterJsonObject);
this.genericK8sClient.setNamespace = this.kubeflowClusterConfig.namespace ?? "default";
if (this.kubeflowClusterConfig.storageType === 'azureStorage') { if (this.kubeflowClusterConfig.storageType === 'azureStorage') {
const azureKubeflowClusterConfig: KubeflowClusterConfigAzure = <KubeflowClusterConfigAzure>this.kubeflowClusterConfig; const azureKubeflowClusterConfig: KubeflowClusterConfigAzure = <KubeflowClusterConfigAzure>this.kubeflowClusterConfig;
this.azureStorageAccountName = azureKubeflowClusterConfig.azureStorage.accountName; this.azureStorageAccountName = azureKubeflowClusterConfig.azureStorage.accountName;
...@@ -137,6 +138,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -137,6 +138,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
} }
this.kubernetesCRDClient = KubeflowOperatorClientFactory.createClient( this.kubernetesCRDClient = KubeflowOperatorClientFactory.createClient(
this.kubeflowClusterConfig.operator, this.kubeflowClusterConfig.apiVersion); this.kubeflowClusterConfig.operator, this.kubeflowClusterConfig.apiVersion);
this.kubernetesCRDClient.namespace = this.kubeflowClusterConfig.namespace ?? "default";
break; break;
} }
case TrialConfigMetadataKey.TRIAL_CONFIG: { case TrialConfigMetadataKey.TRIAL_CONFIG: {
...@@ -310,7 +312,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -310,7 +312,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
// Generate kubeflow job resource config object // Generate kubeflow job resource config object
const kubeflowJobConfig: any = await this.generateKubeflowJobConfig(trialJobId, trialWorkingFolder, kubeflowJobName, workerPodResources, const kubeflowJobConfig: any = await this.generateKubeflowJobConfig(trialJobId, trialWorkingFolder, kubeflowJobName, workerPodResources,
nonWorkerResources); nonWorkerResources);
this.log.info('kubeflowJobConfig:', kubeflowJobConfig);
return Promise.resolve(kubeflowJobConfig); return Promise.resolve(kubeflowJobConfig);
} }
...@@ -368,7 +370,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -368,7 +370,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
kind: this.kubernetesCRDClient.jobKind, kind: this.kubernetesCRDClient.jobKind,
metadata: { metadata: {
name: kubeflowJobName, name: kubeflowJobName,
namespace: 'default', namespace: this.kubernetesCRDClient.namespace,
labels: { labels: {
app: this.NNI_KUBERNETES_TRIAL_LABEL, app: this.NNI_KUBERNETES_TRIAL_LABEL,
expId: getExperimentId(), expId: getExperimentId(),
......
...@@ -150,6 +150,7 @@ abstract class KubernetesCRDClient { ...@@ -150,6 +150,7 @@ abstract class KubernetesCRDClient {
protected readonly client: any; protected readonly client: any;
protected readonly log: Logger = getLogger('KubernetesCRDClient'); protected readonly log: Logger = getLogger('KubernetesCRDClient');
protected crdSchema: any; protected crdSchema: any;
public namespace: string = 'default';
constructor() { constructor() {
this.client = new Client1_10({config: getKubernetesConfig()}); this.client = new Client1_10({config: getKubernetesConfig()});
......
...@@ -230,7 +230,7 @@ abstract class KubernetesTrainingService { ...@@ -230,7 +230,7 @@ abstract class KubernetesTrainingService {
this.azureStorageSecretName = String.Format('nni-secret-{0}', uniqueString(8) this.azureStorageSecretName = String.Format('nni-secret-{0}', uniqueString(8)
.toLowerCase()); .toLowerCase());
const namespace = this.genericK8sClient.getNamespace ? this.genericK8sClient.getNamespace : "default" const namespace = this.genericK8sClient.getNamespace ?? "default";
await this.genericK8sClient.createSecret( await this.genericK8sClient.createSecret(
{ {
apiVersion: 'v1', apiVersion: 'v1',
...@@ -330,7 +330,7 @@ abstract class KubernetesTrainingService { ...@@ -330,7 +330,7 @@ abstract class KubernetesTrainingService {
const body = fs.readFileSync(filePath).toString('base64'); const body = fs.readFileSync(filePath).toString('base64');
const registrySecretName = String.Format('nni-secret-{0}', uniqueString(8) const registrySecretName = String.Format('nni-secret-{0}', uniqueString(8)
.toLowerCase()); .toLowerCase());
const namespace = this.genericK8sClient.getNamespace ? this.genericK8sClient.getNamespace : "default" const namespace = this.genericK8sClient.getNamespace ?? "default";
await this.genericK8sClient.createSecret( await this.genericK8sClient.createSecret(
{ {
apiVersion: 'v1', apiVersion: 'v1',
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment