Unverified Commit 58cd8c76 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Increase IPC message length to 10^14 (#2425)

parent 6f19d3ca
...@@ -23,11 +23,7 @@ const ipcIncomingFd: number = 4; ...@@ -23,11 +23,7 @@ const ipcIncomingFd: number = 4;
*/ */
function encodeCommand(commandType: string, content: string): Buffer { function encodeCommand(commandType: string, content: string): Buffer {
const contentBuffer: Buffer = Buffer.from(content); const contentBuffer: Buffer = Buffer.from(content);
if (contentBuffer.length >= 1_000_000) { const contentLengthBuffer: Buffer = Buffer.from(contentBuffer.length.toString().padStart(14, '0'));
throw new RangeError('Command too long');
}
const contentLengthBuffer: Buffer = Buffer.from(contentBuffer.length.toString().padStart(6, '0'));
return Buffer.concat([Buffer.from(commandType), contentLengthBuffer, contentBuffer]); return Buffer.concat([Buffer.from(commandType), contentLengthBuffer, contentBuffer]);
} }
...@@ -43,12 +39,12 @@ function decodeCommand(data: Buffer): [boolean, string, string, Buffer] { ...@@ -43,12 +39,12 @@ function decodeCommand(data: Buffer): [boolean, string, string, Buffer] {
return [false, '', '', data]; return [false, '', '', data];
} }
const commandType: string = data.slice(0, 2).toString(); const commandType: string = data.slice(0, 2).toString();
const contentLength: number = parseInt(data.slice(2, 8).toString(), 10); const contentLength: number = parseInt(data.slice(2, 16).toString(), 10);
if (data.length < contentLength + 8) { if (data.length < contentLength + 16) {
return [false, '', '', data]; return [false, '', '', data];
} }
const content: string = data.slice(8, contentLength + 8).toString(); const content: string = data.slice(16, contentLength + 16).toString();
const remain: Buffer = data.slice(contentLength + 8); const remain: Buffer = data.slice(contentLength + 16);
return [true, commandType, content, remain]; return [true, commandType, content, remain];
} }
......
...@@ -8,13 +8,13 @@ _out_file = open(4, 'wb') ...@@ -8,13 +8,13 @@ _out_file = open(4, 'wb')
def send(command, data): def send(command, data):
command = command.encode('utf8') command = command.encode('utf8')
data = data.encode('utf8') data = data.encode('utf8')
msg = b'%b%06d%b' % (command, len(data), data) msg = b'%b%14d%b' % (command, len(data), data)
_out_file.write(msg) _out_file.write(msg)
_out_file.flush() _out_file.flush()
def receive(): def receive():
header = _in_file.read(8) header = _in_file.read(16)
l = int(header[2:]) l = int(header[2:])
command = header[:2].decode('utf8') command = header[:2].decode('utf8')
data = _in_file.read(l).decode('utf8') data = _in_file.read(l).decode('utf8')
......
...@@ -14,7 +14,6 @@ import { NNIError } from '../../common/errors'; ...@@ -14,7 +14,6 @@ import { NNIError } from '../../common/errors';
let sentCommands: { [key: string]: string }[] = []; let sentCommands: { [key: string]: string }[] = [];
const receivedCommands: { [key: string]: string }[] = []; const receivedCommands: { [key: string]: string }[] = [];
let commandTooLong: Error | undefined;
let rejectCommandType: Error | undefined; let rejectCommandType: Error | undefined;
function runProcess(): Promise<Error | null> { function runProcess(): Promise<Error | null> {
...@@ -54,14 +53,7 @@ function runProcess(): Promise<Error | null> { ...@@ -54,14 +53,7 @@ function runProcess(): Promise<Error | null> {
// Command #2: ok // Command #2: ok
dispatcher.sendCommand('ME', '123'); dispatcher.sendCommand('ME', '123');
// Command #3: too long // Command #3: FE is not tuner/assessor command, test the exception type of send non-valid command
try {
dispatcher.sendCommand('ME', 'x'.repeat(1_000_000));
} catch (error) {
commandTooLong = error;
}
// Command #4: FE is not tuner/assessor command, test the exception type of send non-valid command
try { try {
dispatcher.sendCommand('FE', '1'); dispatcher.sendCommand('FE', '1');
} catch (error) { } catch (error) {
...@@ -88,21 +80,11 @@ describe('core/protocol', (): void => { ...@@ -88,21 +80,11 @@ describe('core/protocol', (): void => {
}); });
it('sendCommand() should work without content', (): void => { it('sendCommand() should work without content', (): void => {
assert.equal(sentCommands[0], '(\'IN\', \'\')'); assert.equal(sentCommands[0], "('IN', '')");
}); });
it('sendCommand() should work with content', (): void => { it('sendCommand() should work with content', (): void => {
assert.equal(sentCommands[1], '(\'ME\', \'123\')'); assert.equal(sentCommands[1], "('ME', '123')");
});
it('sendCommand() should throw on too long command', (): void => {
if (commandTooLong === undefined) {
assert.fail('Should throw error')
} else {
const err: Error | undefined = (<NNIError>commandTooLong).cause;
assert(err && err.name === 'RangeError');
assert(err && err.message === 'Command too long');
}
}); });
it('sendCommand() should throw on wrong command type', (): void => { it('sendCommand() should throw on wrong command type', (): void => {
......
...@@ -43,8 +43,7 @@ def send(command, data): ...@@ -43,8 +43,7 @@ def send(command, data):
try: try:
_lock.acquire() _lock.acquire()
data = data.encode('utf8') data = data.encode('utf8')
assert len(data) < 1000000, 'Command too long' msg = b'%b%014d%b' % (command.value, len(data), data)
msg = b'%b%06d%b' % (command.value, len(data), data)
logging.getLogger(__name__).debug('Sending command, data: [%s]', msg) logging.getLogger(__name__).debug('Sending command, data: [%s]', msg)
_out_file.write(msg) _out_file.write(msg)
_out_file.flush() _out_file.flush()
...@@ -56,9 +55,9 @@ def receive(): ...@@ -56,9 +55,9 @@ def receive():
"""Receive a command from Training Service. """Receive a command from Training Service.
Returns a tuple of command (CommandType) and payload (str) Returns a tuple of command (CommandType) and payload (str)
""" """
header = _in_file.read(8) header = _in_file.read(16)
logging.getLogger(__name__).debug('Received command, header: [%s]', header) logging.getLogger(__name__).debug('Received command, header: [%s]', header)
if header is None or len(header) < 8: if header is None or len(header) < 16:
# Pipe EOF encountered # Pipe EOF encountered
logging.getLogger(__name__).debug('Pipe EOF encountered') logging.getLogger(__name__).debug('Pipe EOF encountered')
return None, None return None, None
......
...@@ -20,30 +20,21 @@ class ProtocolTestCase(TestCase): ...@@ -20,30 +20,21 @@ class ProtocolTestCase(TestCase):
def test_send_en(self): def test_send_en(self):
out_file = _prepare_send() out_file = _prepare_send()
send(CommandType.NewTrialJob, 'CONTENT') send(CommandType.NewTrialJob, 'CONTENT')
self.assertEqual(out_file.getvalue(), b'TR000007CONTENT') self.assertEqual(out_file.getvalue(), b'TR00000000000007CONTENT')
def test_send_zh(self): def test_send_zh(self):
out_file = _prepare_send() out_file = _prepare_send()
send(CommandType.NewTrialJob, '你好') send(CommandType.NewTrialJob, '你好')
self.assertEqual(out_file.getvalue(), 'TR000006你好'.encode('utf8')) self.assertEqual(out_file.getvalue(), 'TR00000000000006你好'.encode('utf8'))
def test_send_too_large(self):
_prepare_send()
exception = None
try:
send(CommandType.NewTrialJob, ' ' * 1000000)
except AssertionError as e:
exception = e
self.assertIsNotNone(exception)
def test_receive_en(self): def test_receive_en(self):
_prepare_receive(b'IN000005hello') _prepare_receive(b'IN00000000000005hello')
command, data = receive() command, data = receive()
self.assertIs(command, CommandType.Initialize) self.assertIs(command, CommandType.Initialize)
self.assertEqual(data, 'hello') self.assertEqual(data, 'hello')
def test_receive_zh(self): def test_receive_zh(self):
_prepare_receive('IN000006世界'.encode('utf8')) _prepare_receive('IN00000000000006世界'.encode('utf8'))
command, data = receive() command, data = receive()
self.assertIs(command, CommandType.Initialize) self.assertIs(command, CommandType.Initialize)
self.assertEqual(data, '世界') self.assertEqual(data, '世界')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment