Commit c009512a authored by Azure-Tang

Merge branch 'main' into hip

parents c1f13a69 4f22d726
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda:1"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
absorb_for_prefill: False # change this to True to enable long context (prefill may be slower)
- match:
name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
absorb_for_prefill: False # change this to True to enable long context (prefill may be slower)
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
30: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
@@ -153,9 +153,20 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
   replace:
-    class: "default"
+    class: ktransformers.operators.linear.KTransformersLinear
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
@@ -135,7 +135,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -5,6 +5,18 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+- match:
+    name: "^lm_head$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously
@@ -48,6 +60,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+      absorb_for_prefill: False  # change this to True to enable long context (prefill may be slower)
 - match:
     name: "^model$"
   replace:
...
@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.block_sparse_moe$"
     class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock
...
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
# If you want to use more VRAM, use Marlin experts and disable CUDA Graph (disabling CUDA Graph may reduce performance)
#- match:
# name: "^model\\.layers\\..*\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
# kwargs:
# prefill_device: "cuda"
# prefill_op: "KExpertsTorch"
# generate_device: "cuda"
# generate_op: "KExpertsMarlin"
# recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
\ No newline at end of file
@@ -77,9 +77,19 @@
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
-    name: "(^model.norm)|(^lm_head)"
+    name: "(^model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
...
@@ -12,8 +12,10 @@ from ktransformers.server.config.config import Config
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
-router = APIRouter(prefix='/api')
+from ktransformers.server.schemas.endpoints.chat import RawUsage
+
+router = APIRouter(prefix='/api')

 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
 class OllamaGenerateCompletionRequest(BaseModel):
@@ -40,61 +42,129 @@ class OllamaGenerateCompletionRequest(BaseModel):
     keep_alive: Optional[str] = Field(
         "5m", description="Controls how long the model will stay loaded into memory following the request.")

 class OllamaGenerationStreamResponse(BaseModel):
     model: str
     created_at: str
     response: str
     done: bool = Field(...)

 class OllamaGenerationResponse(BaseModel):
     pass

 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     id = str(uuid4())
     interface: BackendInterfaceBase = get_interface()
     print(f'COMPLETION INPUT:----\n{input.prompt}\n----')
     config = Config()

     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt,id):
-                d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response=token,done=False)
-                yield d.model_dump_json()+'\n'
-                # d = {'model':config.model_name,'created_at':"", 'response':token,'done':False}
-                # yield f"{json.dumps(d)}\n"
-            # d = {'model':config.model_name,'created_at':"", 'response':'','done':True}
-            # yield f"{json.dumps(d)}\n"
-            d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response='',done=True)
-            yield d.model_dump_json()+'\n'
-        return check_link_response(request,inner())
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaGenerationStreamResponse(
+                        model=config.model_name,
+                        created_at=str(datetime.now()),
+                        response=token,
+                        done=False
+                    )
+                    yield d.model_dump_json() + '\n'
+            d = OllamaGenerationStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                response='',
+                done=True
+            )
+            yield d.model_dump_json() + '\n'
+        return check_link_response(request, inner())
     else:
         raise NotImplementedError

 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
+class OllamaChatCompletionMessage(BaseModel):
+    role: str
+    content: str
+
 class OllamaChatCompletionRequest(BaseModel):
-    pass
+    model: str = Field(..., description="The model name, which is required.")
+    messages: List[OllamaChatCompletionMessage] = Field(
+        ..., description="A list of messages to generate a response for.")
+    stream: bool = Field(True, description="If true, the response will be streamed.")

 class OllamaChatCompletionStreamResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    message: dict
+    done: bool = Field(...)
+    total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
+    load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
+    prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
+    prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
+    eval_count: Optional[int] = Field(None, description="Number of tokens generated")
+    eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")

 class OllamaChatCompletionResponse(BaseModel):
     pass

 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
-    raise NotImplementedError
+    id = str(uuid4())
+    interface: BackendInterfaceBase = get_interface()
+    config = Config()
+
+    # Convert the messages into a single prompt string
+    prompt = ""
+    for msg in input.messages:
+        prompt += f"{msg.role}: {msg.content}\n"
+    prompt += "assistant:"
+
+    if input.stream:
+        async def inner():
+            start_time = time()  # record the start time (in seconds)
+            eval_count = 0  # count of generated tokens
+            tokens = []
+
+            async for res in interface.inference(prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaChatCompletionStreamResponse(
+                        model=config.model_name,
+                        created_at=str(datetime.now()),
+                        message={"role": "assistant", "content": token},
+                        done=False
+                    )
+                    yield d.model_dump_json() + '\n'
+            # compute performance statistics
+            end_time = time()
+            total_duration = int((end_time - start_time) * 1_000_000_000)  # convert to nanoseconds
+            prompt_eval_count = len(prompt.split())  # rough estimate of the prompt token count
+            eval_duration = total_duration  # assume all time was spent on generation (simplification)
+            prompt_eval_duration = 0  # assume no separate prompt-evaluation time
+            load_duration = 0  # assume the load time is unknown
+
+            d = OllamaChatCompletionStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                message={},
+                done=True,
+                total_duration=total_duration,
+                load_duration=load_duration,
+                prompt_eval_count=prompt_eval_count,
+                prompt_eval_duration=prompt_eval_duration,
+                eval_count=eval_count,
+                eval_duration=eval_duration
+            )
+            yield d.model_dump_json() + '\n'
+        return check_link_response(request, inner())
+    else:
+        raise NotImplementedError("Non-streaming chat is not implemented.")

 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):
@@ -103,9 +173,8 @@ class OllamaModel(BaseModel):
     size: int
     # TODO: fill the rest correctly

 # mock ollama
-@router.get("/tags",tags=['ollama'])
+@router.get("/tags", tags=['ollama'])
 async def tags():
     config = Config()
     # TODO: fill this correctly, although it does not effect Tabby
@@ -138,25 +207,21 @@ class OllamaShowResponse(BaseModel):
     class Config:
         protected_namespaces = ()

 @router.post("/show", tags=['ollama'])
 async def show(request: Request, input: OllamaShowRequest):
     config = Config()
     # TODO: Add more info in config to return, although it does not effect Tabby
     return OllamaShowResponse(
-        modelfile = "# Modelfile generated by ...",
-        parameters = " ",
-        template = " ",
-        details = OllamaShowDetial(
-            parent_model = " ",
-            format = "gguf",
-            family = " ",
-            families = [
-                " "
-            ],
-            parameter_size = " ",
-            quantization_level = " "
+        modelfile="# Modelfile generated by ...",
+        parameters=" ",
+        template=" ",
+        details=OllamaShowDetial(
+            parent_model=" ",
+            format="gguf",
+            family=" ",
+            families=[" "],
+            parameter_size=" ",
+            quantization_level=" "
         ),
-        model_info = OllamaModelInfo()
+        model_info=OllamaModelInfo()
     )
\ No newline at end of file
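For reference, the new streaming /api/chat route above emits one JSON object per line, with a final done:true object carrying the timing counters. The client sketch below is not part of this commit; it assumes the server is reachable on the default port 9016 and uses the third-party requests package.

# Hypothetical client sketch for the /api/chat route added above.
import json
import requests

resp = requests.post(
    "http://localhost:9016/api/chat",
    json={
        "model": "ktranformers-model",  # required field; the reply reports the server's configured model_name
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)
    if chunk["done"]:
        break  # the final chunk carries total_duration, eval_count, etc.
    print(chunk["message"]["content"], end="", flush=True)
print()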
@@ -5,18 +5,21 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.backend.base import BackendInterfaceBase
+from ktransformers.server.config.config import Config
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage

-router = APIRouter()
-
-models = [
-    {"id": "0", "name": "ktranformers-model"},
-]
+router = APIRouter()

 @router.get('/models', tags=['openai'])
 async def list_models():
-    return models
+    return {"data": [{"id": Config().model_name, "name": Config().model_name}], "object": "list"}

 @router.post('/chat/completions', tags=['openai'])
@@ -28,15 +31,80 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     input_message = [json.loads(m.model_dump_json()) for m in create.messages]

+    if Config().api_key != '':
+        assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
+
     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
         async def inner():
-            chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id):
-                chunk.set_token(token)
-                yield chunk
-        return chat_stream_response(request,inner())
+            chunk = ChatCompletionChunk(
+                id = id,
+                choices = [],
+                object = 'chat.completion.chunk',
+                created = int(time()),
+                model = Config().model_name,
+            )
+
+            async for res in interface.inference(input_message,id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() will return the usage of inference
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens = raw_usage.prefill_count,
+                        completion_tokens = raw_usage.decode_count,
+                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    )
+                    yield chunk
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index = 0,
+                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason = finish_reason,
+                        logprobs = None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
+
+        return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id=id,object='chat.completion.chunk',created=int(time()))
-        async for token in interface.inference(input_message,id):
-            comp.append_token(token)
-        return comp
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens = raw_usage.prefill_count,
+                    completion_tokens = raw_usage.decode_count,
+                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+                finish_reason = finish_reason
+
+        choice = Choice(
+            index = 0,
+            finish_reason = finish_reason,
+            message = ChatCompletionMessage(
+                content=content,
+                role="assistant"
+            ))
+
+        chat_completion = ChatCompletion(
+            id = id,
+            choices = [choice],
+            created = int(time()),
+            model = Config().model_name,
+            object = 'chat.completion',
+            usage = usage
+        )
+
+        return chat_completion
@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage

 router = APIRouter()
@@ -17,17 +18,24 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')

     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id):
-                d = {'choices':[{'delta':{'content':token}}]}
-                yield f"data:{json.dumps(d)}\n\n"
+            async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = {'choices':[{'delta':{'content':token}}]}
+                    yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
             yield f"data:{json.dumps(d)}\n\n"
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id):
-            comp.append_token(token)
+        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                comp.append_token(token)
         return comp
@@ -10,6 +10,7 @@ class ArgumentParser:
         parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
         parser.add_argument("--host", type=str, default=self.cfg.server_ip)
         parser.add_argument("--port", type=int, default=self.cfg.server_port)
+        parser.add_argument("--api_key", type=str, default=self.cfg.api_key)
        parser.add_argument("--ssl_keyfile", type=str)
         parser.add_argument("--ssl_certfile", type=str)
         parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
@@ -23,13 +24,13 @@ class ArgumentParser:
         parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
         parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
         parser.add_argument("--type", type=str, default=self.cfg.backend_type)
+        parser.add_argument("--chunk_prefill_size", type=int, default=8192)

         # model configs
         # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
         parser.add_argument("--paged", type=bool, default=self.cfg.paged)
         parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
         parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
-        parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
         parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
         parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
         parser.add_argument("--healing", type=bool, default=self.cfg.healing)
@@ -90,7 +91,8 @@ class ArgumentParser:
         # user config
         parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
         parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
-        parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)

         # web config
         parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
...
@@ -23,7 +23,7 @@ class ConfigArgs(BaseModel):
     max_batch_size: int = Field(
         None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
     )
-    max_chunk_size: int = Field(
+    chunk_prefill_size: int = Field(
         None,
         description=(
             "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
...
@@ -15,6 +15,7 @@ from ktransformers.server.schemas.assistants.assistants import AssistantObject
 from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
 from ktransformers.server.schemas.assistants.runs import RunObject
 from ktransformers.server.schemas.assistants.threads import ThreadObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.schemas.base import ObjectID, Order
 from ktransformers.server.utils.multi_timer import Profiler
@@ -142,12 +143,16 @@ class ThreadContext:
         yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
         yield self.run.stream_response_with_event(RunObject.Status.in_progress)

-        async for token in self.interface.inference(local_messages,self.thread.id):
-            if self.run.status == RunObject.Status.cancelling:
-                logger.warn(f'Run {self.run.id} cancelling')
-                break
-            yield reply_message.append_message_delta(token)
-            response_str_count+=1
+        async for res in self.interface.inference(local_messages,self.thread.id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                if self.run.status == RunObject.Status.cancelling:
+                    logger.warn(f'Run {self.run.id} cancelling')
+                    break
+                yield reply_message.append_message_delta(token)
+                response_str_count+=1

         if self.run.status == RunObject.Status.cancelling:
             yield self.run.stream_response_with_event(RunObject.Status.cancelled)
...
 import torch
+import asyncio
 from transformers import AutoTokenizer, AutoConfig, GenerationConfig
 from ktransformers.server.backend.interfaces.transformers import (
     TransformersInterface,
@@ -13,7 +14,11 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from typing import Optional
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage

+warm_uped = False

 class KTransformersThreadContext(TransformersThreadContext):
     pass
@@ -22,19 +27,29 @@ class KTransformersThreadContext(TransformersThreadContext):
 class KTransformersInterface(TransformersInterface):
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
-        torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        try:
+            generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            generation_config = GenerationConfig(
+                max_length=args.max_new_tokens,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                do_sample=True
+            )
+        torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
         with torch.device("meta"):
             self.model = custom_models[config.architectures[0]](config)
         if default_args.optimize_config_path is None:
-            optimize_rule_path = default_optimize_rules[config.architectures[0]]
+            optimize_config_path = default_optimize_rules[config.architectures[0]]
         else:
-            optimize_rule_path = args.optimize_config_path
+            optimize_config_path = args.optimize_config_path
         # print(optimize_config)
@@ -44,8 +59,8 @@ class KTransformersInterface(TransformersInterface):
                 "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                 " belong to current model):"
             )
-        optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
+        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
+        self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
@@ -56,25 +71,21 @@ class KTransformersInterface(TransformersInterface):
             dtype=self.model.dtype,
         )
         # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
-        try:
-            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
-        except:
-            gen_config = GenerationConfig(
-                max_length=128,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True
-            )
-            self.model.generation_config = gen_config
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
+        self._infer_lock = asyncio.Lock()

     def decode_one_tokens(self):
+        global warm_uped
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        if self.args.use_cuda_graph:
+        torch.cuda.set_device(torch_device)
+        if warm_uped and self.args.use_cuda_graph:
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
@@ -96,14 +107,15 @@ class KTransformersInterface(TransformersInterface):
             torch.cuda.synchronize()
             logits = logits[0, -1, :]
             return self.logits_to_token(logits)
+        if self.args.use_cuda_graph:
+            warm_uped = True
         if self.use_static_cache:
-            mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
-                attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]
@@ -116,59 +128,116 @@ class KTransformersInterface(TransformersInterface):

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
+        if(input_ids_length >= self.args.cache_lens):
+            logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
+            self.seq_length = input_ids_length
+            return
         logger.debug(f"input_ids: {input_ids.shape}")
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
+        device = "cuda:0" if device == "cuda" else device

         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-                )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        else:
+            logger.warning(f"seq_length bigger than cache_lens, killed")
+            exit(0)

         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
-        mask = torch.ones((1, self.seq_length)).to(device)
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
-        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
-        if self.use_static_cache:
-            logits = self.model(
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                past_key_values=self.cache,
-                return_dict=False,
-                use_cache=True,
-                attention_mask=mask,
-            )[0]
-        else:
-            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+
+        def chunk_prefill(input_ids, cache_position):
+            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+            torch.cuda.set_device(device)
+            if flashinfer_enabled:
+                MLAWrapperSingleton.need_plan_all()
+            if self.use_static_cache:
+                logits = self.model(
+                    inputs_embeds=inputs_embeds,
+                    cache_position=cache_position,
+                    past_key_values=self.cache,
+                    return_dict=False,
+                    use_cache=True,
+                )[0]
+            else:
+                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+            return logits
+
+        chunk_start = 0
+        while chunk_start < input_ids_length:
+            chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
+            if self.cache != None:
+                self.cache.cur_idx=cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
+            chunk_start += self.args.chunk_prefill_size
+
+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)

     @property
     def active_cache_position(self):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
\ No newline at end of file
+
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        async with self._infer_lock:
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
+                yield v
+
+            # return this inference raw usage
+            yield RawUsage(
+                tokenize_time = self.profiler.get_timer_sec('tokenize'),
+                prefill_time = self.profiler.get_timer_sec('prefill'),
+                decode_time = self.profiler.get_timer_sec('decode'),
+                prefill_count = self.profiler.get_counter('prefill'),
+                decode_count = self.profiler.get_counter('decode'),
+            )
\ No newline at end of file
@@ -13,12 +13,13 @@ from transformers import (
 from ktransformers.server.config.config import Config
 from ktransformers.server.schemas.base import ObjectID
 from ktransformers.server.utils.multi_timer import Profiler
+from torch.nn.attention import SDPBackend
 import torch
 import sys, os
 from ..base import ThreadContext, BackendInterfaceBase
 from ktransformers.server.config.log import logger
 from ..args import ConfigArgs, default_args
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton

 # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
 class TextStreamer:
@@ -170,7 +171,7 @@ class TransformersInterface(BackendInterfaceBase):
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
                 logger.warning("merge two adjacent user messages")
-                new_messages[-1]["content"] += m["content"]
+                new_messages[-1]["content"] += '\n' + m["content"]
             else:
                 new_messages.append(m)
         # if (self.last_request_id is not None) and self.last_request_id == thread_id:
@@ -179,7 +180,11 @@ class TransformersInterface(BackendInterfaceBase):
         # input_ids = self.tokenizer.apply_chat_template(
         #     new_messages, return_tensors="pt", add_generation_prompt=True
         # ).to(self.args.device)
-        input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
+        input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
+        # drop <think> token in chat template
+        if input_str.endswith('<think>\n'):
+            input_str = input_str[:-len('<think>\n')]
+        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
             x = self.generated_ids[:,:self.seq_length]
             y = input_ids[:,:self.seq_length]
@@ -198,14 +203,31 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)

-    def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature if self.args.temperature!=0 else logits
-
-        for token_idx in self.ever_generated_ids:
-            if logits[token_idx] < 0:
-                logits[token_idx] *= self.args.repetition_penalty
-            else:
-                logits[token_idx] /= self.args.repetition_penalty
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        if temperature is None or temperature == 0:
+            temperature = self.model.generation_config.temperature
+        if top_p is None:
+            top_p = self.model.generation_config.top_p
+        generation_config, model_kwargs = self.model._prepare_generation_config(
+            None, max_length=self.args.max_new_tokens,
+            do_sample=True,
+            top_k=self.args.top_k,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=self.args.repetition_penalty # change this to modify generate config
+        )
+        self.inputs = inputs
+        try: # transformers==4.43
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config, device=device)
+            )
+        except:
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config)
+            )
+
+    def logits_to_token(self, logits: torch.Tensor):
+        logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))

         probs = torch.nn.functional.softmax(logits, dim=-1)
@@ -221,12 +243,10 @@ class TransformersInterface(BackendInterfaceBase):
     def decode_one_tokens(self):
         if self.use_static_cache:
-            mask = torch.ones((1, self.seq_length)).to(self.args.device)
             logits = self.model(
                 self.current_ids,
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
-                attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]
@@ -237,38 +257,57 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")

         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-                )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)

         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
-        mask = torch.ones((1, self.seq_length)).to(self.args.device)
         device = input_ids.device
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
@@ -280,26 +319,46 @@ class TransformersInterface(BackendInterfaceBase):
                 past_key_values=self.cache,
                 return_dict=False,
                 use_cache=True,
-                attention_mask=mask,
             )[0]
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)

     @torch.no_grad
     def generate(self):
+        self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
+        logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
+        if(self.max_new_tokens <= 0):
+            logger.warning("max_new_tokens is less than 0")
+            yield self.streamer.end(), "length"
+            return
+        logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
-        for _ in range(1, self.args.max_new_tokens):
-            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+
+        for i in range(1, self.max_new_tokens):
+            with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+                if flashinfer_enabled:
+                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
+                        num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
+                        head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
+                        sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
                 next_token = self.decode_one_tokens()
                 self.profiler.inc("decode")
-                if next_token == self.tokenizer.eos_token_id:
+                if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                    yield self.streamer.end(), None
+                    yield "", "stop"
                     assert self.args.batch_size == 1
                     break
-                yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+                yield self.append_new_tokens(next_token), None
+        else:  # for's else, if output get max new tokens
+            yield self.streamer.end(), None
+            yield "", "length"

     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:
@@ -314,7 +373,8 @@ class TransformersInterface(BackendInterfaceBase):
             self.last_request_id = thread_id
             return True

-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
@@ -324,8 +384,9 @@ class TransformersInterface(BackendInterfaceBase):
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
+
         if Config().user_force_think:
-            token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_ids.device)
+            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
             input_ids = torch.cat(
                 [input_ids, token_thinks], dim=1
             )
@@ -333,21 +394,24 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.pause_timer("tokenize")

         self.profiler.create_and_start_timer("prefill")
+
         if Config().user_force_think:
-            t = "<think>\n"
-            print(t,end="",flush=True)
-            yield t
-        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+            think = '<think>\n'
+            print(think, end="",flush=True)
+            yield think, None
+
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
+            # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
@@ -69,6 +69,7 @@ class Config(metaclass=Singleton):
         self.server: dict = cfg.get("server", {})
         self.server_ip = self.server.get("ip", "0.0.0.0")
         self.server_port = self.server.get("port", 9016)
+        self.api_key = self.server.get("api_key", "")

         # db configs
         self.db_configs: dict = cfg.get("db", {})
@@ -104,7 +105,8 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
-        self.max_chunk_size = self.model.get("max_chunk_size", 2048)
+        self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
         self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
...
@@ -105,6 +105,7 @@ def custom_openapi(app):
 def main():
     cfg = Config()

     arg_parser = ArgumentParser(cfg)

     # initialize messages
...
@@ -5,6 +5,7 @@ langchain >= 0.2.0
 blessed >= 1.20.0
 accelerate >= 0.31.0
 sentencepiece >= 0.1.97
+openai
 setuptools
 build
 ninja
...
@@ -73,7 +73,7 @@ class RunStepDelta(Object):
 class Done():
     def to_stream_reply(self):
-        return f"event: done\ndata: [DONE]\n\n"
+        return f"data: [DONE]\n\n"

 async def check_client_link(request: Request, async_events: AsyncIterable):
...