Unverified Commit 7d40d190 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

fix turbomind stream canceling (#686)

* fix

* instance for each forward
parent 4eb8dd83
# Copyright (c) OpenMMLab. All rights reserved.
from .api_server_backend import run_api_server
from .triton_server_backend import run_triton_server
from .turbomind_coupled import run_local
__all__ = ['run_api_server', 'run_triton_server', 'run_local']
...@@ -185,3 +185,8 @@ def run_local(model_path: str, ...@@ -185,3 +185,8 @@ def run_local(model_path: str,
server_port=server_port, server_port=server_port,
server_name=server_name, server_name=server_name,
) )
# CLI entry point: when this module is executed directly, expose ``run_local``
# through the ``fire`` package so its parameters become command-line flags.
# ``fire`` is imported lazily here so importing this module has no extra cost.
if __name__ == '__main__':
    import fire
    fire.Fire(run_local)
...@@ -178,8 +178,6 @@ class TurboMindInstance: ...@@ -178,8 +178,6 @@ class TurboMindInstance:
self.session_len = tm_model.session_len self.session_len = tm_model.session_len
self.nccl_params = tm_model.nccl_params self.nccl_params = tm_model.nccl_params
self.instance_comm = tm_model.model_comm.create_instance_comm(
self.gpu_count)
# create model instances # create model instances
model_insts = [None] * self.gpu_count model_insts = [None] * self.gpu_count
...@@ -207,16 +205,20 @@ class TurboMindInstance: ...@@ -207,16 +205,20 @@ class TurboMindInstance:
self.que.put((False, result)) self.que.put((False, result))
def _forward_thread(self, inputs):
    """Launch the forward pass on every GPU in background threads.

    A fresh ``instance_comm`` is created for each forward call (rather
    than being reused across calls) — per the commit, this is what lets
    a streamed request be cancelled cleanly. Only the device-0 worker
    enqueues its result on ``self.que``; the threads are daemonic so a
    cancelled/exiting process is not blocked waiting on them.

    Args:
        inputs: the model inputs forwarded verbatim to each
            ``model_insts[device_id].forward`` call.
    """
    instance_comm = self.tm_model.model_comm.create_instance_comm(
        self.gpu_count)

    def _worker(device_id, enque_output):
        # Bind the CUDA context of this worker's device before running.
        with cuda_ctx(device_id):
            output = self.model_insts[device_id].forward(
                inputs, instance_comm)
            if enque_output:
                self.que.put((True, output))

    for device_id in range(self.gpu_count):
        worker = Thread(target=_worker,
                        args=(device_id, device_id == 0),
                        daemon=True)
        worker.start()
        self.threads[device_id] = worker
...@@ -264,7 +266,7 @@ class TurboMindInstance: ...@@ -264,7 +266,7 @@ class TurboMindInstance:
random_seed (int): seed used by sampling random_seed (int): seed used by sampling
stream_output (bool): indicator for stream output stream_output (bool): indicator for stream output
""" """
if stream_output: if stream_output and not stop:
self.model_insts[0].register_callback(self._forward_callback) self.model_insts[0].register_callback(self._forward_callback)
if len(input_ids) == 0: if len(input_ids) == 0:
...@@ -372,7 +374,7 @@ class TurboMindInstance: ...@@ -372,7 +374,7 @@ class TurboMindInstance:
self.que.get() self.que.get()
break break
if stream_output: if stream_output and not stop:
self.model_insts[0].unregister_callback() self.model_insts[0].unregister_callback()
def decode(self, input_ids): def decode(self, input_ids):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment