Unverified Commit 12410686 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Merge pull request #20 from microsoft/master

pull code
parents 611a45fc 61fec446
......@@ -241,9 +241,7 @@ def get_id(word_dict, word):
'''
Given word, return word id.
'''
if word in word_dict.keys():
return word_dict[word]
return word_dict['<unk>']
return word_dict.get(word, word_dict['<unk>'])
def get_buckets(min_length, max_length, bucket_count):
......
import numpy as np
from nni.tuner import Tuner
def random_archi_generator(nas_ss, random_state):
'''random
'''
chosen_archi = {}
print("zql: nas search space: ", nas_ss)
for block_name, block in nas_ss.items():
tmp_block = {}
for layer_name, layer in block.items():
tmp_layer = {}
for key, value in layer.items():
if key == 'layer_choice':
index = random_state.randint(len(value))
tmp_layer['chosen_layer'] = value[index]
elif key == 'optional_inputs':
tmp_layer['chosen_inputs'] = []
print("zql: optional_inputs", layer['optional_inputs'])
if layer['optional_inputs']:
if isinstance(layer['optional_input_size'], int):
choice_num = layer['optional_input_size']
else:
choice_range = layer['optional_input_size']
choice_num = random_state.randint(choice_range[0], choice_range[1]+1)
for _ in range(choice_num):
index = random_state.randint(len(layer['optional_inputs']))
tmp_layer['chosen_inputs'].append(layer['optional_inputs'][index])
elif key == 'optional_input_size':
pass
else:
raise ValueError('Unknown field %s in layer %s of block %s' % (key, layer_name, block_name))
tmp_block[layer_name] = tmp_layer
chosen_archi[block_name] = tmp_block
return chosen_archi
class RandomNASTuner(Tuner):
'''RandomNASTuner
'''
def __init__(self):
self.searchspace_json = None
self.random_state = None
def update_search_space(self, search_space):
'''update
'''
self.searchspace_json = search_space
self.random_state = np.random.RandomState()
def generate_parameters(self, parameter_id):
'''generate
'''
return random_archi_generator(self.searchspace_json, self.random_state)
def receive_trial_result(self, parameter_id, parameters, value):
'''receive
'''
pass
......@@ -56,7 +56,8 @@ setup(
'scipy',
'schema',
'PythonWebHDFS',
'colorama'
'colorama',
'sklearn'
],
entry_points = {
......
......@@ -91,6 +91,7 @@ interface TrialJobMetric {
* define TrainingServiceError
*/
class TrainingServiceError extends Error {
private errCode: number;
constructor(errorCode: number, errorMessage: string) {
......@@ -136,5 +137,3 @@ export {
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
HostJobApplicationForm, JobApplicationForm, JobType, NNIManagerIpConfig
};
......@@ -374,6 +374,40 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number):
});
}
function validateFileName(fileName: string): boolean {
let pattern: string = '^[a-z0-9A-Z\.-_]+$';
const validateResult = fileName.match(pattern);
if(validateResult) {
return true;
}
return false;
}
async function validateFileNameRecursively(directory: string): Promise<boolean> {
if(!fs.existsSync(directory)) {
throw Error(`Direcotory ${directory} doesn't exist`);
}
const fileNameArray: string[] = fs.readdirSync(directory);
let result = true;
for(var name of fileNameArray){
const fullFilePath: string = path.join(directory, name);
try {
// validate file names and directory names
result = validateFileName(name);
if (fs.lstatSync(fullFilePath).isDirectory()) {
result = result && await validateFileNameRecursively(fullFilePath);
}
if(!result) {
return Promise.reject(new Error(`file name in ${fullFilePath} is not valid!`));
}
} catch(error) {
return Promise.reject(error);
}
}
return Promise.resolve(result);
}
/**
* get the version of current package
*/
......@@ -474,6 +508,6 @@ function unixPathJoin(...paths: any[]): string {
return dir;
}
export {countFilesRecursively, getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir,
export {countFilesRecursively, validateFileNameRecursively, getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin,
mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment