"test/gemm/gemm_dl_int8.cpp" did not exist on "7e9a9d32c7a9259a1bd57b0b461c36d089d26fe8"
Unverified Commit 98c1a77f authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Support multiple HPO experiments in one process (#4855)

parent 5dc80762
......@@ -5,14 +5,12 @@ import json
from io import BytesIO
from unittest import TestCase, main
from nni.runtime import protocol
from nni.runtime import msg_dispatcher_base
from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.runtime.protocol import CommandType, send, receive
from nni.runtime.tuner_command_channel.legacy import *
from nni.tuner import Tuner
from nni.utils import extract_scalar_reward
class NaiveTuner(Tuner):
def __init__(self):
self.param = 0
......@@ -45,15 +43,15 @@ _out_buf = BytesIO()
def _reverse_io():
_in_buf.seek(0)
_out_buf.seek(0)
protocol._set_out_file(_in_buf)
protocol._set_in_file(_out_buf)
_set_out_file(_in_buf)
_set_in_file(_out_buf)
def _restore_io():
_in_buf.seek(0)
_out_buf.seek(0)
protocol._set_in_file(_in_buf)
protocol._set_out_file(_out_buf)
_set_in_file(_in_buf)
_set_out_file(_out_buf)
class MsgDispatcherTestCase(TestCase):
......@@ -68,7 +66,8 @@ class MsgDispatcherTestCase(TestCase):
_restore_io()
tuner = NaiveTuner()
dispatcher = MsgDispatcher(tuner)
dispatcher = MsgDispatcher('ws://_placeholder_', tuner)
dispatcher._channel = LegacyCommandChannel()
msg_dispatcher_base._worker_fast_exit_on_terminate = False
dispatcher.run()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.runtime import protocol
from nni.runtime.protocol import CommandType, send, receive
from io import BytesIO
from unittest import TestCase, main
def _prepare_send():
out_file = BytesIO()
protocol._set_out_file(out_file)
return out_file
def _prepare_receive(data):
protocol._set_in_file(BytesIO(data))
class ProtocolTestCase(TestCase):
def test_send_en(self):
out_file = _prepare_send()
send(CommandType.NewTrialJob, 'CONTENT')
self.assertEqual(out_file.getvalue(), b'TR00000000000007CONTENT')
def test_send_zh(self):
out_file = _prepare_send()
send(CommandType.NewTrialJob, '你好')
self.assertEqual(out_file.getvalue(), 'TR00000000000006你好'.encode('utf8'))
def test_receive_en(self):
_prepare_receive(b'IN00000000000005hello')
command, data = receive()
self.assertIs(command, CommandType.Initialize)
self.assertEqual(data, 'hello')
def test_receive_zh(self):
_prepare_receive('IN00000000000006世界'.encode('utf8'))
command, data = receive()
self.assertIs(command, CommandType.Initialize)
self.assertEqual(data, '世界')
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment