"comfy/vscode:/vscode.git/clone" did not exist on "fcb25d37dbd0bab5c1c1936778b41ab614dd3a6d"
Unverified Commit 32efaa36 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #219 from microsoft/master

merge master
parents cd3a912a 97b258b0
......@@ -6,7 +6,7 @@ Assessor receives the intermediate result from Trial and decides whether the Tri
Here is an experimental result of MNIST after using 'Curvefitting' Assessor in 'maximize' mode, you can see that assessor successfully **early stopped** many trials with bad hyperparameters in advance. If you use assessor, we may get better hyperparameters under the same computing resources.
*Implemented code directory: config_assessor.yml <https://github.com/Microsoft/nni/blob/master/examples/trials/mnist/config_assessor.yml>*
*Implemented code directory: config_assessor.yml <https://github.com/Microsoft/nni/blob/master/examples/trials/mnist-tfv1/config_assessor.yml>*
.. image:: ../img/Assessor.png
......
......@@ -17,3 +17,4 @@ Builtin-Tuners
Network Morphism <Tuner/NetworkmorphismTuner>
Hyperband <Tuner/HyperbandAdvisor>
BOHB <Tuner/BohbAdvisor>
PPO Tuner <Tuner/PPOTuner>
......@@ -28,7 +28,7 @@ author = 'Microsoft'
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = 'v1.1'
release = 'v1.2'
# -- General configuration ---------------------------------------------------
......
#################
###################
Feature Engineering
#################
###################
We are glad to announce the alpha release for Feature Engineering toolkit on top of NNI,
it's still in the experiment phase which might evolve based on usage feedback.
......
#################
##############
NAS Algorithms
#################
##############
Automatic neural architecture search is taking an increasingly important role on finding better models.
Recent research works have proved the feasibility of automatic NAS, and also found some models that could beat manually designed and tuned models.
......
......@@ -9,4 +9,4 @@ json_tricks
numpy
scipy
coverage
sklearn
scikit-learn==0.20
......@@ -94,7 +94,7 @@ if __name__ == "__main__":
print()
test_benchmark.run_all_test(pipeline1)
pipeline2 = make_pipeline(FeatureGradientSelector(n_features=20), LogisticRegression())
pipeline2 = make_pipeline(FeatureGradientSelector(), LogisticRegression())
print("Test data selected by FeatureGradientSelector in LogisticRegression.")
print()
test_benchmark.run_all_test(pipeline2)
......@@ -103,5 +103,10 @@ if __name__ == "__main__":
print("Test data selected by TreeClssifier in LogisticRegression.")
print()
test_benchmark.run_all_test(pipeline3)
pipeline4 = make_pipeline(FeatureGradientSelector(n_features=20), LogisticRegression())
print("Test data selected by FeatureGradientSelector top 20 in LogisticRegression.")
print()
test_benchmark.run_all_test(pipeline4)
print("Done.")
\ No newline at end of file
from nni.compression.torch import FPGMPruner
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import FPGMPruner
class Mnist(torch.nn.Module):
def __init__(self):
......@@ -23,8 +22,8 @@ class Mnist(torch.nn.Module):
return F.log_softmax(x, dim=1)
def _get_conv_weight_sparsity(self, conv_layer):
num_zero_filters = (conv_layer.weight.data.sum((2,3)) == 0).sum()
num_filters = conv_layer.weight.data.size(0) * conv_layer.weight.data.size(1)
num_zero_filters = (conv_layer.weight.data.sum((1, 2, 3)) == 0).sum()
num_filters = conv_layer.weight.data.size(0)
return num_zero_filters, num_filters, float(num_zero_filters)/num_filters
def print_conv_filter_sparsity(self):
......@@ -41,7 +40,8 @@ def train(model, device, train_loader, optimizer):
output = model(data)
loss = F.nll_loss(output, target)
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
print('{:.2f}% Loss {:.4f}'.format(100 * batch_idx / len(train_loader), loss.item()))
if batch_idx == 0:
model.print_conv_filter_sparsity()
loss.backward()
optimizer.step()
......@@ -59,7 +59,7 @@ def test(model, device, test_loader):
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('Loss: {} Accuracy: {}%)\n'.format(
print('Loss: {:.4f} Accuracy: {}%)\n'.format(
test_loss, 100 * correct / len(test_loader.dataset)))
......@@ -78,9 +78,6 @@ def main():
model = Mnist()
model.print_conv_filter_sparsity()
'''you can change this to LevelPruner to implement it
pruner = LevelPruner(configure_list)
'''
configure_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
......@@ -96,6 +93,7 @@ def main():
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
pruner.export_model('model.pth', 'mask.pth')
if __name__ == '__main__':
main()
......@@ -26,7 +26,7 @@ class fc1(nn.Module):
def train(model, train_loader, optimizer, criterion):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.train()
for batch_idx, (imgs, targets) in enumerate(train_loader):
for imgs, targets in train_loader:
optimizer.zero_grad()
imgs, targets = imgs.to(device), targets.to(device)
output = model(imgs)
......@@ -64,7 +64,7 @@ if __name__ == '__main__':
criterion = nn.CrossEntropyLoss()
configure_list = [{
'prune_iterations': 10,
'prune_iterations': 5,
'sparsity': 0.96,
'op_types': ['default']
}]
......@@ -75,7 +75,7 @@ if __name__ == '__main__':
pruner.prune_iteration_start()
loss = 0
accuracy = 0
for epoch in range(50):
for epoch in range(10):
loss = train(model, train_loader, optimizer, criterion)
accuracy = test(model, test_loader, criterion)
print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
......
from nni.compression.tensorflow import AGP_Pruner
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def weight_variable(shape):
return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
def bias_variable(shape):
return tf.Variable(tf.constant(0.1, shape=shape))
def conv2d(x_input, w_matrix):
return tf.nn.conv2d(x_input, w_matrix, strides=[1, 1, 1, 1], padding='SAME')
def max_pool(x_input, pool_size):
size = [1, pool_size, pool_size, 1]
return tf.nn.max_pool(x_input, ksize=size, strides=size, padding='SAME')
class Mnist:
def __init__(self):
images = tf.placeholder(tf.float32, [None, 784], name='input_x')
labels = tf.placeholder(tf.float32, [None, 10], name='input_y')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
self.images = images
self.labels = labels
self.keep_prob = keep_prob
self.train_step = None
self.accuracy = None
self.w1 = None
self.b1 = None
self.fcw1 = None
self.cross = None
with tf.name_scope('reshape'):
x_image = tf.reshape(images, [-1, 28, 28, 1])
with tf.name_scope('conv1'):
w_conv1 = weight_variable([5, 5, 1, 32])
self.w1 = w_conv1
b_conv1 = bias_variable([32])
self.b1 = b_conv1
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
with tf.name_scope('pool1'):
h_pool1 = max_pool(h_conv1, 2)
with tf.name_scope('conv2'):
w_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
with tf.name_scope('pool2'):
h_pool2 = max_pool(h_conv2, 2)
with tf.name_scope('fc1'):
w_fc1 = weight_variable([7 * 7 * 64, 1024])
self.fcw1 = w_fc1
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
with tf.name_scope('dropout'):
h_fc1_drop = tf.nn.dropout(h_fc1, 0.5)
with tf.name_scope('fc2'):
w_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
with tf.name_scope('loss'):
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y_conv))
self.cross = cross_entropy
with tf.name_scope('adam_optimizer'):
self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
with tf.name_scope('accuracy'):
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1))
self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
def main():
tf.set_random_seed(0)
data = input_data.read_data_sets('data', one_hot=True)
model = Mnist()
'''you can change this to LevelPruner to implement it
pruner = LevelPruner(configure_list)
'''
configure_list = [{
'initial_sparsity': 0,
'final_sparsity': 0.8,
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
'op_types': ['default']
}]
pruner = AGP_Pruner(tf.get_default_graph(), configure_list)
# if you want to load from yaml file
# configure_file = nni.compressors.tf_compressor._nnimc_tf._tf_default_load_configure_file('configure_example.yaml','AGPruner')
# configure_list = configure_file.get('config',[])
# pruner.load_configure(configure_list)
# you can also handle it yourself and input an configure list in json
pruner.compress()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for batch_idx in range(2000):
if batch_idx % 10 == 0:
pruner.update_epoch(batch_idx / 10, sess)
batch = data.train.next_batch(2000)
model.train_step.run(feed_dict={
model.images: batch[0],
model.labels: batch[1],
model.keep_prob: 0.5
})
if batch_idx % 10 == 0:
test_acc = model.accuracy.eval(feed_dict={
model.images: data.test.images,
model.labels: data.test.labels,
model.keep_prob: 1.0
})
print('test accuracy', test_acc)
test_acc = model.accuracy.eval(feed_dict={
model.images: data.test.images,
model.labels: data.test.labels,
model.keep_prob: 1.0
})
print('final result is', test_acc)
if __name__ == '__main__':
main()
from nni.compression.tensorflow import QAT_Quantizer
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def weight_variable(shape):
return tf.Variable(tf.truncated_normal(shape, stddev = 0.1))
def bias_variable(shape):
return tf.Variable(tf.constant(0.1, shape = shape))
def conv2d(x_input, w_matrix):
return tf.nn.conv2d(x_input, w_matrix, strides = [ 1, 1, 1, 1 ], padding = 'SAME')
def max_pool(x_input, pool_size):
size = [ 1, pool_size, pool_size, 1 ]
return tf.nn.max_pool(x_input, ksize = size, strides = size, padding = 'SAME')
class Mnist:
def __init__(self):
images = tf.placeholder(tf.float32, [ None, 784 ], name = 'input_x')
labels = tf.placeholder(tf.float32, [ None, 10 ], name = 'input_y')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
self.images = images
self.labels = labels
self.keep_prob = keep_prob
self.train_step = None
self.accuracy = None
self.w1 = None
self.b1 = None
self.fcw1 = None
self.cross = None
with tf.name_scope('reshape'):
x_image = tf.reshape(images, [ -1, 28, 28, 1 ])
with tf.name_scope('conv1'):
w_conv1 = weight_variable([ 5, 5, 1, 32 ])
self.w1 = w_conv1
b_conv1 = bias_variable([ 32 ])
self.b1 = b_conv1
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
with tf.name_scope('pool1'):
h_pool1 = max_pool(h_conv1, 2)
with tf.name_scope('conv2'):
w_conv2 = weight_variable([ 5, 5, 32, 64 ])
b_conv2 = bias_variable([ 64 ])
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
with tf.name_scope('pool2'):
h_pool2 = max_pool(h_conv2, 2)
with tf.name_scope('fc1'):
w_fc1 = weight_variable([ 7 * 7 * 64, 1024 ])
self.fcw1 = w_fc1
b_fc1 = bias_variable([ 1024 ])
h_pool2_flat = tf.reshape(h_pool2, [ -1, 7 * 7 * 64 ])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
with tf.name_scope('dropout'):
h_fc1_drop = tf.nn.dropout(h_fc1, 0.5)
with tf.name_scope('fc2'):
w_fc2 = weight_variable([ 1024, 10 ])
b_fc2 = bias_variable([ 10 ])
y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
with tf.name_scope('loss'):
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = labels, logits = y_conv))
self.cross = cross_entropy
with tf.name_scope('adam_optimizer'):
self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
with tf.name_scope('accuracy'):
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1))
self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
def main():
tf.set_random_seed(0)
data = input_data.read_data_sets('data', one_hot = True)
model = Mnist()
'''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(tf.get_default_graph())
'''
configure_list = [{'q_bits':8, 'op_types':['default']}]
quantizer = QAT_Quantizer(tf.get_default_graph(), configure_list)
quantizer.compress()
# you can also use compress(model) or compress_default_graph()
# method like QATquantizer(q_bits = 8).compress_default_graph()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for batch_idx in range(2000):
batch = data.train.next_batch(2000)
model.train_step.run(feed_dict = {
model.images: batch[0],
model.labels: batch[1],
model.keep_prob: 0.5
})
if batch_idx % 10 == 0:
test_acc = model.accuracy.eval(feed_dict = {
model.images: data.test.images,
model.labels: data.test.labels,
model.keep_prob: 1.0
})
print('test accuracy', test_acc)
test_acc = model.accuracy.eval(feed_dict = {
model.images: data.test.images,
model.labels: data.test.labels,
model.keep_prob: 1.0
})
print('final result is', test_acc)
if __name__ == '__main__':
main()
......@@ -82,8 +82,6 @@ def main():
pruner = AGP_Pruner(model, configure_list)
model = pruner.compress()
# you can also use compress(model) method
# like that pruner.compress(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
......
from nni.compression.torch import QAT_Quantizer
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim = 1)
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction = 'sum').item()
pred = output.argmax(dim = 1, keepdim = True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('Loss: {} Accuracy: {}%)\n'.format(
test_loss, 100 * correct / len(test_loader.dataset)))
def main():
torch.manual_seed(0)
device = torch.device('cpu')
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train = True, download = True, transform = trans),
batch_size = 64, shuffle = True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train = False, transform = trans),
batch_size = 1000, shuffle = True)
model = Mnist()
'''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(model)
'''
configure_list = [{'q_bits':8, 'op_types':['default']}]
quantizer = QAT_Quantizer(model, configure_list)
quantizer.compress()
# you can also use compress(model) method
# like thaht quantizer.compress(model)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.5)
for epoch in range(10):
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
if __name__ == '__main__':
main()
......@@ -27,11 +27,14 @@ if __name__ == "__main__":
parser = ArgumentParser("pdarts")
parser.add_argument('--add_layers', action='append',
default=[0, 6, 12], help='add layers')
parser.add_argument('--dropped_ops', action='append',
default=[3, 2, 1], help='drop ops')
parser.add_argument("--nodes", default=4, type=int)
parser.add_argument("--layers", default=5, type=int)
parser.add_argument("--init_layers", default=5, type=int)
parser.add_argument("--batch-size", default=64, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
parser.add_argument("--epochs", default=50, type=int)
parser.add_argument("--unrolled", default=False, action="store_true")
args = parser.parse_args()
logger.info("loading data")
......@@ -48,15 +51,16 @@ if __name__ == "__main__":
logger.info("initializing trainer")
trainer = PdartsTrainer(model_creator,
layers=args.layers,
init_layers=args.init_layers,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
pdarts_num_layers=[0, 6, 12],
pdarts_num_to_drop=[3, 2, 2],
pdarts_num_layers=args.add_layers,
pdarts_num_to_drop=args.dropped_ops,
num_epochs=args.epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=args.batch_size,
log_frequency=args.log_frequency,
unrolled=args.unrolled,
callbacks=[ArchitectureCheckpoint("./checkpoints")])
logger.info("training")
trainer.train()
......@@ -40,7 +40,7 @@ setup(
'schema',
'PythonWebHDFS',
'colorama',
'sklearn'
'scikit-learn==0.20'
],
entry_points = {
......
{
"env": {
"browser": true,
"node": true,
"es6": true
},
"parser": "@typescript-eslint/parser",
"parserOptions": {
"ecmaVersion": 2018,
"sourceType": "module"
},
"plugins": [
"@typescript-eslint"
],
"extends": [
"eslint:recommended",
"plugin:@typescript-eslint/eslint-recommended",
"plugin:@typescript-eslint/recommended"
],
"rules": {
"@typescript-eslint/no-explicit-any": 0
},
"ignorePatterns": [
"node_modules/",
"test/",
"dist/",
"types/",
"**/*.js"
]
}
......@@ -6,7 +6,7 @@
"build": "tsc",
"test": "nyc mocha -r ts-node/register -t 15000 --recursive **/*.test.ts --exclude node_modules/**/**/*.test.ts --colors",
"start": "node dist/main.js",
"tslint": "tslint -p ."
"eslint": "npx eslint ./ --ext .ts"
},
"license": "MIT",
"dependencies": {
......@@ -42,9 +42,13 @@
"@types/ssh2": "^0.5.35",
"@types/stream-buffers": "^3.0.2",
"@types/tmp": "^0.0.33",
"@typescript-eslint/eslint-plugin": "^2.10.0",
"@typescript-eslint/parser": "^2.10.0",
"chai": "^4.1.2",
"eslint": "^6.7.2",
"glob": "^7.1.3",
"mocha": "^5.2.0",
"npx": "^10.2.0",
"nyc": "^13.1.0",
"request": "^2.87.0",
"rmdir": "^1.2.0",
......
......@@ -9,12 +9,12 @@ import { GeneralK8sClient, KubernetesCRDClient } from '../kubernetesApiClient';
/**
* FrameworkController Client
*/
abstract class FrameworkControllerClient extends KubernetesCRDClient {
class FrameworkControllerClientFactory {
/**
* Factory method to generate operator client
*/
// tslint:disable-next-line:function-name
public static generateFrameworkControllerClient(): KubernetesCRDClient {
public static createClient(): KubernetesCRDClient {
return new FrameworkControllerClientV1();
}
}
......@@ -22,7 +22,7 @@ abstract class FrameworkControllerClient extends KubernetesCRDClient {
/**
* FrameworkController ClientV1
*/
class FrameworkControllerClientV1 extends FrameworkControllerClient {
class FrameworkControllerClientV1 extends KubernetesCRDClient {
/**
* constructor, to initialize frameworkcontroller CRD definition
*/
......@@ -43,4 +43,4 @@ class FrameworkControllerClientV1 extends FrameworkControllerClient {
}
}
export { FrameworkControllerClient, GeneralK8sClient };
export { FrameworkControllerClientFactory, GeneralK8sClient };
......@@ -19,7 +19,7 @@ import { AzureStorageClientUtility } from '../azureStorageClientUtils';
import { NFSConfig } from '../kubernetesConfig';
import { KubernetesTrialJobDetail } from '../kubernetesData';
import { KubernetesTrainingService } from '../kubernetesTrainingService';
import { FrameworkControllerClient } from './frameworkcontrollerApiClient';
import { FrameworkControllerClientFactory } from './frameworkcontrollerApiClient';
import { FrameworkControllerClusterConfig, FrameworkControllerClusterConfigAzure, FrameworkControllerClusterConfigFactory,
FrameworkControllerClusterConfigNFS, FrameworkControllerTrialConfig} from './frameworkcontrollerConfig';
import { FrameworkControllerJobInfoCollector } from './frameworkcontrollerJobInfoCollector';
......@@ -142,7 +142,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
nfsFrameworkControllerClusterConfig.nfs.path
);
}
this.kubernetesCRDClient = FrameworkControllerClient.generateFrameworkControllerClient();
this.kubernetesCRDClient = FrameworkControllerClientFactory.createClient();
break;
case TrialConfigMetadataKey.TRIAL_CONFIG:
const frameworkcontrollerTrialJsonObjsect: any = JSON.parse(value);
......
......@@ -7,55 +7,9 @@ import * as fs from 'fs';
import { GeneralK8sClient, KubernetesCRDClient } from '../kubernetesApiClient';
import { KubeflowOperator } from './kubeflowConfig';
/**
* KubeflowOperator Client
*/
abstract class KubeflowOperatorClient extends KubernetesCRDClient {
/**
* Factory method to generate operator client
*/
// tslint:disable-next-line:function-name
public static generateOperatorClient(kubeflowOperator: KubeflowOperator,
operatorApiVersion: string): KubernetesCRDClient {
switch (kubeflowOperator) {
case 'tf-operator': {
switch (operatorApiVersion) {
case 'v1alpha2': {
return new TFOperatorClientV1Alpha2();
}
case 'v1beta1': {
return new TFOperatorClientV1Beta1();
}
case 'v1beta2': {
return new TFOperatorClientV1Beta2();
}
default:
throw new Error(`Invalid tf-operator apiVersion ${operatorApiVersion}`);
}
}
case 'pytorch-operator': {
switch (operatorApiVersion) {
case 'v1alpha2': {
return new PyTorchOperatorClientV1Alpha2();
}
case 'v1beta1': {
return new PyTorchOperatorClientV1Beta1();
}
case 'v1beta2': {
return new PyTorchOperatorClientV1Beta2();
}
default:
throw new Error(`Invalid pytorch-operator apiVersion ${operatorApiVersion}`);
}
}
default:
throw new Error(`Invalid operator ${kubeflowOperator}`);
}
}
}
// tslint:disable: no-unsafe-any no-any completed-docs
class TFOperatorClientV1Alpha2 extends KubeflowOperatorClient {
class TFOperatorClientV1Alpha2 extends KubernetesCRDClient {
/**
* constructor, to initialize tfjob CRD definition
*/
......@@ -112,7 +66,7 @@ class TFOperatorClientV1Beta2 extends KubernetesCRDClient {
}
}
class PyTorchOperatorClientV1Alpha2 extends KubeflowOperatorClient {
class PyTorchOperatorClientV1Alpha2 extends KubernetesCRDClient {
/**
* constructor, to initialize tfjob CRD definition
*/
......@@ -169,5 +123,51 @@ class PyTorchOperatorClientV1Beta2 extends KubernetesCRDClient {
}
}
/**
* KubeflowOperator Client
*/
class KubeflowOperatorClientFactory {
/**
* Factory method to generate operator client
*/
// tslint:disable-next-line:function-name
public static createClient(kubeflowOperator: KubeflowOperator, operatorApiVersion: string): KubernetesCRDClient {
switch (kubeflowOperator) {
case 'tf-operator': {
switch (operatorApiVersion) {
case 'v1alpha2': {
return new TFOperatorClientV1Alpha2();
}
case 'v1beta1': {
return new TFOperatorClientV1Beta1();
}
case 'v1beta2': {
return new TFOperatorClientV1Beta2();
}
default:
throw new Error(`Invalid tf-operator apiVersion ${operatorApiVersion}`);
}
}
case 'pytorch-operator': {
switch (operatorApiVersion) {
case 'v1alpha2': {
return new PyTorchOperatorClientV1Alpha2();
}
case 'v1beta1': {
return new PyTorchOperatorClientV1Beta1();
}
case 'v1beta2': {
return new PyTorchOperatorClientV1Beta2();
}
default:
throw new Error(`Invalid pytorch-operator apiVersion ${operatorApiVersion}`);
}
}
default:
throw new Error(`Invalid operator ${kubeflowOperator}`);
}
}
}
// tslint:enable: no-unsafe-any
export { KubeflowOperatorClient, GeneralK8sClient };
export { KubeflowOperatorClientFactory, GeneralK8sClient };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment