Commit ee24eb8d authored by ceerrep's avatar ceerrep
Browse files

fix: fix server for triton kernel

parent bb1cadff
...@@ -16,6 +16,8 @@ from ktransformers.local_chat import custom_models, default_optimize_rules ...@@ -16,6 +16,8 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device from ktransformers.util.utils import get_device
warm_uped = False
class KTransformersThreadContext(TransformersThreadContext): class KTransformersThreadContext(TransformersThreadContext):
pass pass
...@@ -74,10 +76,13 @@ class KTransformersInterface(TransformersInterface): ...@@ -74,10 +76,13 @@ class KTransformersInterface(TransformersInterface):
self._infer_lock = asyncio.Lock() self._infer_lock = asyncio.Lock()
def decode_one_tokens(self): def decode_one_tokens(self):
global warm_uped
device_map = self.model.gguf_loader.tensor_device_map device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device("blk.0.self_attn", device_map) torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device torch_device = "cuda:0" if torch_device == "cuda" else torch_device
if self.args.use_cuda_graph: torch.cuda.set_device(torch_device)
if warm_uped and self.args.use_cuda_graph:
if not hasattr(self, "cuda_graph_runner"): if not hasattr(self, "cuda_graph_runner"):
self.cuda_graph_runner = CUDAGraphRunner() self.cuda_graph_runner = CUDAGraphRunner()
self.cuda_graph_runner.capture( self.cuda_graph_runner.capture(
...@@ -113,6 +118,7 @@ class KTransformersInterface(TransformersInterface): ...@@ -113,6 +118,7 @@ class KTransformersInterface(TransformersInterface):
else: else:
logits = self.model(self.current_ids, return_dict=False)[0] logits = self.model(self.current_ids, return_dict=False)[0]
logits = logits[0, -1, :] logits = logits[0, -1, :]
warm_uped = True
return self.logits_to_token(logits) return self.logits_to_token(logits)
...@@ -176,6 +182,7 @@ class KTransformersInterface(TransformersInterface): ...@@ -176,6 +182,7 @@ class KTransformersInterface(TransformersInterface):
if not (type(self) is TransformersInterface): if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu") input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device) inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
torch.cuda.set_device(device)
if self.use_static_cache: if self.use_static_cache:
logits = self.model( logits = self.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
......
...@@ -106,9 +106,6 @@ def custom_openapi(app): ...@@ -106,9 +106,6 @@ def custom_openapi(app):
def main(): def main():
cfg = Config() cfg = Config()
# Temporarily disable cuda graph by default because of a bug in the prefix cache.
cfg.use_cuda_graph = False
arg_parser = ArgumentParser(cfg) arg_parser = ArgumentParser(cfg)
# 初始化消息 # 初始化消息
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment