Commit 4ae2e81c (unverified), authored by ZiWei Yuan, committed by GitHub
Browse files

Merge pull request #152 from kvcache-ai/server_support

Server support
parents f30c6482 4385e850
...@@ -160,7 +160,7 @@ def local_chat( ...@@ -160,7 +160,7 @@ def local_chat(
messages, add_generation_prompt=True, return_tensors="pt" messages, add_generation_prompt=True, return_tensors="pt"
) )
if force_think: if force_think:
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)]) token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
input_tensor = torch.cat( input_tensor = torch.cat(
[input_tensor, token_thinks], dim=1 [input_tensor, token_thinks], dim=1
) )
......
...@@ -90,6 +90,7 @@ class ArgumentParser: ...@@ -90,6 +90,7 @@ class ArgumentParser:
# user config # user config
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key) parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm) parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
# web config # web config
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain) parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
...@@ -121,4 +122,5 @@ class ArgumentParser: ...@@ -121,4 +122,5 @@ class ArgumentParser:
self.cfg.server_ip = args.host self.cfg.server_ip = args.host
self.cfg.server_port = args.port self.cfg.server_port = args.port
self.cfg.backend_type = args.type self.cfg.backend_type = args.type
self.cfg.user_force_think = args.force_think
return args return args
...@@ -10,6 +10,7 @@ from transformers import ( ...@@ -10,6 +10,7 @@ from transformers import (
BitsAndBytesConfig, BitsAndBytesConfig,
) )
from ktransformers.server.config.config import Config
from ktransformers.server.schemas.base import ObjectID from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.multi_timer import Profiler from ktransformers.server.utils.multi_timer import Profiler
import torch import torch
...@@ -323,10 +324,19 @@ class TransformersInterface(BackendInterfaceBase): ...@@ -323,10 +324,19 @@ class TransformersInterface(BackendInterfaceBase):
#input_ids = torch.tensor([[6366]], device=input_ids.device) #input_ids = torch.tensor([[6366]], device=input_ids.device)
else: else:
raise ValueError("local_messages should be List or str") raise ValueError("local_messages should be List or str")
if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_ids.device)
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
self.profiler.pause_timer("tokenize") self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer("prefill") self.profiler.create_and_start_timer("prefill")
if Config().user_force_think:
t = "<think>\n"
print(t,end="",flush=True)
yield t
for t in self.prefill(input_ids, self.check_is_new(thread_id)): for t in self.prefill(input_ids, self.check_is_new(thread_id)):
if t is not None: if t is not None:
print(t, end="",flush=True) print(t, end="",flush=True)
...@@ -337,7 +347,7 @@ class TransformersInterface(BackendInterfaceBase): ...@@ -337,7 +347,7 @@ class TransformersInterface(BackendInterfaceBase):
for t in self.generate(): for t in self.generate():
if t is not None: if t is not None:
print(t, end="",flush=True) print(t, end="",flush=True)
yield t yield t
print("") print("")
self.profiler.pause_timer("decode") self.profiler.pause_timer("decode")
self.report_last_time_performance() self.report_last_time_performance()
...@@ -83,6 +83,7 @@ class Config(metaclass=Singleton): ...@@ -83,6 +83,7 @@ class Config(metaclass=Singleton):
self.user_config: dict = cfg.get("user", {}) self.user_config: dict = cfg.get("user", {})
self.user_secret_key = self.user_config.get("secret_key", "") self.user_secret_key = self.user_config.get("secret_key", "")
self.user_algorithm = self.user_config.get("algorithm", "") self.user_algorithm = self.user_config.get("algorithm", "")
self.user_force_think = self.user_config.get("force_think", False)
# model config # model config
self.model: dict = cfg.get("model", {}) self.model: dict = cfg.get("model", {})
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment