Unverified Commit 58cd8c76 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Increase IPC message length to 10^14 (#2425)

parent 6f19d3ca
......@@ -23,11 +23,7 @@ const ipcIncomingFd: number = 4;
*/
function encodeCommand(commandType: string, content: string): Buffer {
const contentBuffer: Buffer = Buffer.from(content);
if (contentBuffer.length >= 1_000_000) {
throw new RangeError('Command too long');
}
const contentLengthBuffer: Buffer = Buffer.from(contentBuffer.length.toString().padStart(6, '0'));
const contentLengthBuffer: Buffer = Buffer.from(contentBuffer.length.toString().padStart(14, '0'));
return Buffer.concat([Buffer.from(commandType), contentLengthBuffer, contentBuffer]);
}
......@@ -43,12 +39,12 @@ function decodeCommand(data: Buffer): [boolean, string, string, Buffer] {
return [false, '', '', data];
}
const commandType: string = data.slice(0, 2).toString();
const contentLength: number = parseInt(data.slice(2, 8).toString(), 10);
if (data.length < contentLength + 8) {
const contentLength: number = parseInt(data.slice(2, 16).toString(), 10);
if (data.length < contentLength + 16) {
return [false, '', '', data];
}
const content: string = data.slice(8, contentLength + 8).toString();
const remain: Buffer = data.slice(contentLength + 8);
const content: string = data.slice(16, contentLength + 16).toString();
const remain: Buffer = data.slice(contentLength + 16);
return [true, commandType, content, remain];
}
......
......@@ -8,13 +8,13 @@ _out_file = open(4, 'wb')
def send(command, data):
command = command.encode('utf8')
data = data.encode('utf8')
msg = b'%b%06d%b' % (command, len(data), data)
msg = b'%b%14d%b' % (command, len(data), data)
_out_file.write(msg)
_out_file.flush()
def receive():
header = _in_file.read(8)
header = _in_file.read(16)
l = int(header[2:])
command = header[:2].decode('utf8')
data = _in_file.read(l).decode('utf8')
......
......@@ -14,7 +14,6 @@ import { NNIError } from '../../common/errors';
let sentCommands: { [key: string]: string }[] = [];
const receivedCommands: { [key: string]: string }[] = [];
let commandTooLong: Error | undefined;
let rejectCommandType: Error | undefined;
function runProcess(): Promise<Error | null> {
......@@ -54,14 +53,7 @@ function runProcess(): Promise<Error | null> {
// Command #2: ok
dispatcher.sendCommand('ME', '123');
// Command #3: too long
try {
dispatcher.sendCommand('ME', 'x'.repeat(1_000_000));
} catch (error) {
commandTooLong = error;
}
// Command #4: FE is not tuner/assessor command, test the exception type of send non-valid command
// Command #3: FE is not tuner/assessor command, test the exception type of send non-valid command
try {
dispatcher.sendCommand('FE', '1');
} catch (error) {
......@@ -88,21 +80,11 @@ describe('core/protocol', (): void => {
});
it('sendCommand() should work without content', (): void => {
assert.equal(sentCommands[0], '(\'IN\', \'\')');
assert.equal(sentCommands[0], "('IN', '')");
});
it('sendCommand() should work with content', (): void => {
assert.equal(sentCommands[1], '(\'ME\', \'123\')');
});
it('sendCommand() should throw on too long command', (): void => {
if (commandTooLong === undefined) {
assert.fail('Should throw error')
} else {
const err: Error | undefined = (<NNIError>commandTooLong).cause;
assert(err && err.name === 'RangeError');
assert(err && err.message === 'Command too long');
}
assert.equal(sentCommands[1], "('ME', '123')");
});
it('sendCommand() should throw on wrong command type', (): void => {
......
......@@ -43,8 +43,7 @@ def send(command, data):
try:
_lock.acquire()
data = data.encode('utf8')
assert len(data) < 1000000, 'Command too long'
msg = b'%b%06d%b' % (command.value, len(data), data)
msg = b'%b%014d%b' % (command.value, len(data), data)
logging.getLogger(__name__).debug('Sending command, data: [%s]', msg)
_out_file.write(msg)
_out_file.flush()
......@@ -56,9 +55,9 @@ def receive():
"""Receive a command from Training Service.
Returns a tuple of command (CommandType) and payload (str)
"""
header = _in_file.read(8)
header = _in_file.read(16)
logging.getLogger(__name__).debug('Received command, header: [%s]', header)
if header is None or len(header) < 8:
if header is None or len(header) < 16:
# Pipe EOF encountered
logging.getLogger(__name__).debug('Pipe EOF encountered')
return None, None
......
......@@ -20,30 +20,21 @@ class ProtocolTestCase(TestCase):
def test_send_en(self):
out_file = _prepare_send()
send(CommandType.NewTrialJob, 'CONTENT')
self.assertEqual(out_file.getvalue(), b'TR000007CONTENT')
self.assertEqual(out_file.getvalue(), b'TR00000000000007CONTENT')
def test_send_zh(self):
out_file = _prepare_send()
send(CommandType.NewTrialJob, '你好')
self.assertEqual(out_file.getvalue(), 'TR000006你好'.encode('utf8'))
def test_send_too_large(self):
_prepare_send()
exception = None
try:
send(CommandType.NewTrialJob, ' ' * 1000000)
except AssertionError as e:
exception = e
self.assertIsNotNone(exception)
self.assertEqual(out_file.getvalue(), 'TR00000000000006你好'.encode('utf8'))
def test_receive_en(self):
_prepare_receive(b'IN000005hello')
_prepare_receive(b'IN00000000000005hello')
command, data = receive()
self.assertIs(command, CommandType.Initialize)
self.assertEqual(data, 'hello')
def test_receive_zh(self):
_prepare_receive('IN000006世界'.encode('utf8'))
_prepare_receive('IN00000000000006世界'.encode('utf8'))
command, data = receive()
self.assertIs(command, CommandType.Initialize)
self.assertEqual(data, '世界')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment