Unverified Commit 65d735ba authored by Lyu Han, committed by GitHub

Fix wrong eos_id and bos_id obtained through grpc api (#644)

* Fix wrong eos_id and bos_id obtained through grpc api

* fix according to review comments

* update
parent 07640a3a
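In brief: the gRPC client learns the model's bos/eos ids by sending the literal strings '<BOS>' and '<EOS>' through the Triton preprocessing model. Previously those strings were tokenized as ordinary text, so the client got back wrong ids; the preprocessing model now intercepts them and answers with its configured start_id/end_id (or -1 when undefined). A minimal client-side sketch of the convention, assuming a hypothetical `preprocess(text)` helper that performs the gRPC round trip shown in the last hunk and returns (token_ids, lengths):

# Hypothetical helpers -- `preprocess` stands in for the gRPC round trip to
# the Triton preprocessing model; only the '<BOS>'/'<EOS>' convention is
# established by this commit.
def get_bos_id(preprocess):
    token_ids, _ = preprocess('<BOS>')  # intercepted server-side
    return int(token_ids[0][0])         # start_id, or -1 if the model has none

def get_eos_id(preprocess):
    token_ids, _ = preprocess('<EOS>')  # intercepted server-side
    return int(token_ids[0][0])         # end_id, or -1 if the model has none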
@@ -465,7 +465,7 @@ class Chatbot:
             input_lengths = input_lengths - 1
         # will crash if last_token_id == eos_id and send empty input_ids
         if sequence_end and request_output_len == 0:
-            input_ids = np.array([[self.bos_id]], dtype=np.uint32)
+            input_ids = np.array([[1]], dtype=np.uint32)
             input_lengths = np.array([[1]], dtype=np.uint32)
         input_tokens = input_lengths.squeeze()
         if self.profile_generation:
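A note on the hunk above, as we read it (the rationale is not stated in the commit): with this change bos_id can legitimately come back as -1 for a model with no start token, and the dummy input is built as uint32, so a hardcoded valid token id avoids an unsigned wraparound; since request_output_len == 0, the dummy token never influences generation. A quick demonstration of the hazard:

import numpy as np

# -1 cast to uint32 wraps to 4294967295 -- far outside any vocabulary.
print(np.array([[-1]]).astype(np.uint32))  # [[4294967295]]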
@@ -42,9 +42,7 @@ class TritonPythonModel:
         self.model_config = model_config = json.loads(args['model_config'])
 
         # Parse model output configs and convert Triton types to numpy types
-        input_names = [
-            'INPUT_ID', 'REQUEST_INPUT_LEN', 'BAD_WORDS_IDS', 'STOP_WORDS_IDS'
-        ]
+        input_names = ['INPUT_ID', 'REQUEST_INPUT_LEN']
         for input_name in input_names:
             setattr(
                 self,
@@ -89,8 +87,6 @@ class TritonPythonModel:
             # Get input tensors
             query = pb_utils.get_input_tensor_by_name(request,
                                                       'QUERY').as_numpy()
-            request_output_len = pb_utils.get_input_tensor_by_name(
-                request, 'REQUEST_OUTPUT_LEN').as_numpy()
 
             # Preprocessing input data.
             input_id, request_input_len = self._create_request(query)
@@ -104,8 +100,6 @@ class TritonPythonModel:
                 'REQUEST_INPUT_LEN',
                 np.array(request_input_len).astype(
                     self.request_input_len_dtype))
-            request_output_len_tensor = pb_utils.Tensor(
-                'REQUEST_OUTPUT_LEN', request_output_len)
 
             # Create InferenceResponse. You can set an error here in case
             # there was a problem with handling this inference request.
@@ -114,10 +108,8 @@ class TritonPythonModel:
            #
            # pb_utils.InferenceResponse(
            #    output_tensors=..., TritonError("An error occurred"))
-            inference_response = pb_utils.InferenceResponse(output_tensors=[
-                input_id_tensor, request_input_len_tensor,
-                request_output_len_tensor
-            ])
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[input_id_tensor, request_input_len_tensor])
             responses.append(inference_response)
 
         # You should return a list of pb_utils.InferenceResponse. Length
@@ -140,10 +132,18 @@ class TritonPythonModel:
         Returns:
             tuple: token ids and their length
         """
-        start_ids = [
-            torch.IntTensor(self.tokenizer.encode(s[0].decode()))
-            for s in query
-        ]
+        start_ids = []
+        for s in query:
+            _s = s[0].decode()
+            if _s == '<BOS>':
+                start_id = [self.start_id
+                            ] if self.start_id is not None else [-1]
+            elif _s == '<EOS>':
+                start_id = [self.end_id] if self.end_id is not None else [-1]
+            else:
+                start_id = self.tokenizer.encode(_s)
+            start_ids.append(torch.IntTensor(start_id))
         start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids])
+
         start_ids = pad_sequence(start_ids,
                                  batch_first=True,
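The last hunk above is the heart of the fix: before it, preprocessing '<BOS>' tokenized the literal text and returned whatever ids that spelled. A quick illustration of the original bug with a HuggingFace tokenizer ('gpt2' is used purely as an example checkpoint; the served model's tokenizer differs):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')
print(tok.encode('<BOS>'))   # several ids for the literal text '<BOS>'
print(tok.bos_token_id)      # 50256 -- the single id a client actually wants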
@@ -7,53 +7,16 @@ input [
     name: "QUERY"
     data_type: TYPE_STRING
     dims: [ -1 ]
-  },
-  {
-    name: "BAD_WORDS_DICT"
-    data_type: TYPE_STRING
-    dims: [ -1 ]
-    optional: true
-  },
-  {
-    name: "STOP_WORDS_DICT"
-    data_type: TYPE_STRING
-    dims: [ -1 ]
-    optional: true
-  },
-  {
-    name: "REQUEST_OUTPUT_LEN"
-    data_type: TYPE_UINT32
-    dims: [ -1 ]
   }
 ]
 output [
   {
     name: "INPUT_ID"
-    data_type: TYPE_UINT32
-    dims: [ -1 ]
-  },
-  {
-    name: "REQUEST_INPUT_LEN"
-    data_type: TYPE_UINT32
-    dims: [ 1 ]
-  },
-  {
-    name: "BAD_WORDS_IDS"
     data_type: TYPE_INT32
-    dims: [ 2, -1 ]
-  },
-  {
-    name: "STOP_WORDS_IDS"
-    data_type: TYPE_INT32
-    dims: [ 2, -1 ]
-  },
-  {
-    name: "REQUEST_OUTPUT_LEN"
-    data_type: TYPE_UINT32
     dims: [ -1 ]
   },
   {
-    name: "PROMPT_LEARNING_TASK_NAME_IDS"
+    name: "REQUEST_INPUT_LEN"
     data_type: TYPE_UINT32
     dims: [ 1 ]
   }
@@ -48,11 +48,7 @@ class Preprocessor:
             f'{type(prompts)}'
 
         input0_data = np.array(input0).astype(object)
-        output0_len = np.ones_like(input0).astype(np.uint32)
-        inputs = [
-            prepare_tensor('QUERY', input0_data),
-            prepare_tensor('REQUEST_OUTPUT_LEN', output0_len)
-        ]
+        inputs = [prepare_tensor('QUERY', input0_data)]
 
         with grpcclient.InferenceServerClient(self.tritonserver_addr) as \
                 client:
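For reference, the trimmed request the Preprocessor now builds can be reproduced with a raw Triton client. A hedged sketch: the server address and the model name 'preprocessing' are assumptions from context, and `prepare_tensor` here mirrors the helper used in the diff rather than quoting it.

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

def prepare_tensor(name, data):
    # Wrap a numpy array as a Triton InferInput (BYTES for object arrays).
    t = grpcclient.InferInput(name, list(data.shape),
                              np_to_triton_dtype(data.dtype))
    t.set_data_from_numpy(data)
    return t

input0_data = np.array([['<BOS>']], dtype=object)
inputs = [prepare_tensor('QUERY', input0_data)]  # only QUERY after this change
with grpcclient.InferenceServerClient('localhost:33337') as client:
    result = client.infer('preprocessing', inputs)
    print(result.as_numpy('INPUT_ID'))  # the real bos id, or -1 if undefined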