"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "df7db85cef7f9c30a5b821007754b96eb1f977b6"
Unverified commit ae5d9e11 authored by Azure, committed by GitHub

Merge pull request #227 from hrz6976/main

Add a lock to server inference()
parents a456e25a 2c3dcd97
@@ -1,4 +1,5 @@
 import torch
+import asyncio
 from transformers import AutoTokenizer, AutoConfig, GenerationConfig
 from ktransformers.server.backend.interfaces.transformers import (
     TransformersInterface,
@@ -70,6 +71,8 @@ class KTransformersInterface(TransformersInterface):
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
 
+        self._infer_lock = asyncio.Lock()
+
     def decode_one_tokens(self):
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
@@ -171,4 +174,9 @@ class KTransformersInterface(TransformersInterface):
     @property
     def active_cache_position(self):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
-        return torch.tensor([self.seq_length - 1], device=device)
\ No newline at end of file
+        return torch.tensor([self.seq_length - 1], device=device)
+
+    async def inference(self, local_messages, thread_id: str):
+        async with self._infer_lock:
+            async for v in super().inference(local_messages, thread_id):
+                yield v
\ No newline at end of file
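The change serializes concurrent calls to inference(): the asyncio.Lock is acquired before delegating to the parent class's generator and held until the stream is exhausted, so only one request generates tokens on the shared model at a time. Below is a minimal, self-contained sketch of the same pattern; the Worker class, _generate(), and all other names are hypothetical stand-ins for illustration, not ktransformers APIs.

```python
import asyncio

class Worker:
    def __init__(self):
        # Same idea as the PR: one lock shared by all requests.
        self._infer_lock = asyncio.Lock()

    async def _generate(self, prompt: str):
        # Stand-in for the real token loop; yields "tokens" with small delays.
        for tok in prompt.split():
            await asyncio.sleep(0.01)
            yield tok

    async def inference(self, prompt: str, thread_id: str):
        # The lock is held for the full lifetime of this async generator,
        # so concurrent requests queue instead of interleaving generation.
        async with self._infer_lock:
            async for tok in self._generate(prompt):
                yield f"[{thread_id}] {tok}"

async def main():
    w = Worker()

    async def consume(prompt, tid):
        return [tok async for tok in w.inference(prompt, tid)]

    # Launched concurrently, but the lock forces one full generation at a
    # time: t1's tokens never interleave with t2's.
    r1, r2 = await asyncio.gather(
        consume("hello from first request", "t1"),
        consume("hello from second request", "t2"),
    )
    print(r1)
    print(r2)

asyncio.run(main())
```

One trade-off worth noting: because the lock wraps the entire async for, it is held while the response streams back to the client, so a slow consumer delays every queued request. That is the price of guaranteeing that generations never interleave on the shared model state.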