Unverified Commit 358efb26 authored by Yan Ni's avatar Yan Ni Committed by GitHub
Browse files

Dev weight sharing (#568) (#576)

* Dev weight sharing (#568)

* add pycharm project files to .gitignore list

* update pylintrc to conform vscode settings

* fix RemoteMachineMode for wrong trainingServicePlatform

* simple weight sharing

* update gitignore file

* change tuner codedir to relative path

* add python cache files to gitignore list

* move extract scalar reward logic from dispatcher to tuner

* update tuner code corresponding to last commit

* update doc for receive_trial_result api change

* add numpy to package whitelist of pylint

* distinguish param value from return reward for tuner.extract_scalar_reward

* update pylintrc

* add comments to dispatcher.handle_report_metric_data

* update install for mac support

* fix root mode bug on Makefile

* Quick fix bug: nnictl port value error (#245)

* fix port bug

* Dev exp stop more (#221)

* Exp stop refactor (#161)

* Update RemoteMachineMode.md (#63)

* Remove unused classes for SQuAD QA exampl...
parent e6eb6eab
......@@ -97,6 +97,7 @@ class MsgDispatcher(MsgDispatcherBase):
def handle_request_trial_jobs(self, data):
# data: number of trial jobs
ids = [_create_parameter_id() for _ in range(data)]
_logger.debug("requesting for generating params of {}".format(ids))
params_list = self.tuner.generate_multiple_parameters(ids)
for i, _ in enumerate(params_list):
......
......@@ -19,10 +19,14 @@
# ==================================================================================================
#import json_tricks
import os
import logging
import json_tricks
import os
from queue import Queue
import sys
from multiprocessing.dummy import Pool as ThreadPool
import json_tricks
from .common import init_logger, multi_thread_enabled
from .recoverable import Recoverable
from .protocol import CommandType, receive
......@@ -34,6 +38,7 @@ class MsgDispatcherBase(Recoverable):
def __init__(self):
if multi_thread_enabled():
self.pool = ThreadPool()
self.thread_results = []
def run(self):
"""Run the tuner.
......@@ -49,7 +54,11 @@ class MsgDispatcherBase(Recoverable):
if command is None or command is CommandType.Terminate:
break
if multi_thread_enabled():
self.pool.map_async(self.handle_request, [(command, data)])
result = self.pool.map_async(self.handle_request_thread, [(command, data)])
self.thread_results.append(result)
if any([thread_result.ready() and not thread_result.successful() for thread_result in self.thread_results]):
_logger.debug('Caught thread exception')
break
else:
self.handle_request((command, data))
......@@ -59,6 +68,16 @@ class MsgDispatcherBase(Recoverable):
_logger.info('Terminated by NNI manager')
def handle_request_thread(self, request):
if multi_thread_enabled():
try:
self.handle_request(request)
except Exception as e:
_logger.exception(str(e))
raise
else:
pass
def handle_request(self, request):
command, data = request
......
......@@ -48,6 +48,7 @@ class Tuner(Recoverable):
result = []
for parameter_id in parameter_id_list:
try:
_logger.debug("generating param for {}".format(parameter_id))
res = self.generate_parameters(parameter_id)
except nni.NoMoreTrialError:
return result
......
authorName: default
experimentName: example_weight_sharing
trialConcurrency: 3
maxExecDuration: 1h
maxTrialNum: 10
#choice: local, remote, pai
trainingServicePlatform: remote
#choice: true, false
useAnnotation: false
multiThread: true
tuner:
codeDir: .
classFileName: simple_tuner.py
className: SimpleTuner
trial:
command: python3 main.py
codeDir: .
gpuNum: 0
machineList:
- ip: 10.10.10.10
username: bob
passwd: bob123
- ip: 10.10.10.11
username: bob
passwd: bob123
"""
Test code for weight sharing
Requires NFS to be set up and mounted at `/mnt/nfs/nni`.
"""
import hashlib
import os
import random
import time
import nni
def generate_rand_file(fl_name):
    """
    Generate a file of random bytes (1 KiB - 100 KiB) and write it to `fl_name`.

    The parent directory is created if it does not already exist. `fl_name`
    may also be a bare filename with no directory component.
    """
    fl_size = random.randint(1024, 102400)
    fl_dir = os.path.split(fl_name)[0]
    # Only create the directory when there is one: makedirs('') raises.
    # exist_ok=True avoids both the redundant exists() check and the race
    # where another trial creates the same directory between check and call.
    if fl_dir:
        os.makedirs(fl_dir, exist_ok=True)
    with open(fl_name, 'wb') as fout:
        fout.write(os.urandom(fl_size))
def check_sum(fl_name, tid=None):
    """
    Return the MD5 hex digest of the file at `fl_name`.

    When `tid` is given, its string form is appended to the digest so that
    checksums from different trials are distinguishable.
    """
    digest = hashlib.md5()
    with open(fl_name, 'rb') as fin:
        while True:
            block = fin.read(4096)
            if not block:
                break
            digest.update(block)
    checksum = digest.hexdigest()
    return checksum if tid is None else checksum + str(tid)
if __name__ == '__main__':
    nfs_path = '/mnt/nfs/nni/test'
    params = nni.get_next_parameter()
    print(params)
    if params['id'] == 0:
        # Father trial: write a random weight file onto NFS, then report
        # its checksum and path so child trials can locate and verify it.
        model_file = os.path.join(nfs_path, str(params['id']), 'model.dat')
        generate_rand_file(model_file)
        time.sleep(10)
        result = {
            'checksum': check_sum(model_file, tid=params['id']),
            'path': model_file,
        }
    else:
        # Child trial: read the father's file from NFS and report its
        # checksum for the tuner to compare against the father's.
        model_file = params['prev_path']
        time.sleep(10)
        result = {
            'checksum': check_sum(model_file, tid=params['prev_id']),
        }
    nni.report_final_result(result)
"""
SimpleTuner for Weight Sharing
"""
import logging
from threading import Event, Lock
from nni.tuner import Tuner
_logger = logging.getLogger('WeightSharingTuner')
class SimpleTuner(Tuner):
    """
    Toy tuner used to exercise weight sharing between trials.

    The first parameter request is elected the "father" trial, which
    produces a weight file; every later request blocks until the father
    has reported its result, then receives the father's file path so the
    child trial can verify the shared weights.
    """

    def __init__(self):
        super(SimpleTuner, self).__init__()
        self.trial_meta = {}        # parameter_id -> metadata dict for that trial
        self.f_id = None            # parameter id of the father trial
        self.sig_event = Event()    # set once the father's result has arrived
        self.thread_lock = Lock()   # guards trial_meta and f_id

    def generate_parameters(self, parameter_id):
        """
        Return the parameters for trial `parameter_id`.

        The first caller becomes the father; all other callers wait for the
        father's result before getting its weight-file path.
        """
        # The check and assignment of f_id must both happen under the lock;
        # the original checked `f_id is None` before locking, so two
        # concurrent requests could each elect themselves father.
        with self.thread_lock:
            if self.f_id is None:
                self.f_id = parameter_id
                self.trial_meta[parameter_id] = {
                    'prev_id': 0,
                    'id': parameter_id,
                    'checksum': None,
                    'path': '',
                }
                _logger.info('generate parameter for father trial %s',
                             parameter_id)
                return {
                    'prev_id': 0,
                    'id': parameter_id,
                }
        # Child trial: wait for the father WITHOUT holding the lock,
        # otherwise receive_trial_result could never record the result.
        self.sig_event.wait()
        with self.thread_lock:
            self.trial_meta[parameter_id] = {
                'id': parameter_id,
                'prev_id': self.f_id,
                'prev_path': self.trial_meta[self.f_id]['path'],
            }
            return self.trial_meta[parameter_id]

    def receive_trial_result(self, parameter_id, parameters, reward):
        """
        Record the father's result, or verify a child's checksum against it.

        Raises:
            ValueError: if a child trial's checksum differs from the father's.
        """
        # `with` guarantees the lock is released even on the ValueError path;
        # the original called release() after raise, so a checksum mismatch
        # left the lock held forever and deadlocked every later call.
        with self.thread_lock:
            if parameter_id == self.f_id:
                self.trial_meta[parameter_id]['checksum'] = reward['checksum']
                self.trial_meta[parameter_id]['path'] = reward['path']
                # Wake all waiting child trials now that the path is known.
                self.sig_event.set()
            elif reward['checksum'] != self.trial_meta[self.f_id]['checksum']:
                raise ValueError("Inconsistency in weight sharing: {} != {}".format(
                    reward['checksum'], self.trial_meta[self.f_id]['checksum']))

    def update_search_space(self, search_space):
        """This test tuner does not use a search space."""
        pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment