OpenDAS / ktransformers / Commits / c009512a

Commit c009512a authored Mar 13, 2025 by Azure-Tang

Merge branch 'main' into hip

Parents: c1f13a69, 4f22d726
Changes: 121

Showing 20 changed files with 772 additions and 179 deletions (+772, -179)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml  (+157, -0)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml  (+13, -2)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml  (+12, -1)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml  (+13, -0)
ktransformers/optimize/optimize_rules/Mixtral.yaml  (+10, -0)
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml  (+86, -0)
ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml  (+11, -1)
ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml  (+10, -0)
ktransformers/server/api/ollama/completions.py  (+104, -39)
ktransformers/server/api/openai/endpoints/chat.py  (+83, -15)
ktransformers/server/api/openai/legacy/completions.py  (+14, -6)
ktransformers/server/args.py  (+4, -2)
ktransformers/server/backend/args.py  (+1, -1)
ktransformers/server/backend/base.py  (+11, -6)
ktransformers/server/backend/interfaces/ktransformers.py  (+124, -55)
ktransformers/server/backend/interfaces/transformers.py  (+113, -49)
ktransformers/server/config/config.py  (+3, -1)
ktransformers/server/main.py  (+1, -0)
ktransformers/server/requirements.txt  (+1, -0)
ktransformers/server/schemas/assistants/streaming.py  (+1, -1)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml  (new file, 0 → 100644)

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^model\\.layers\\.([3456][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:0"
  recursive: False  # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:1"
  recursive: False  # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).

- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
      transfer_map:
        30: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^lm_head"
    class: torch.nn.Linear
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml

@@ -153,9 +153,20 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
   replace:
-    class: "default"
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
+  replace:
+    class: "default"
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml

@@ -135,7 +135,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml

@@ -5,6 +5,18 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+- match:
+    name: "^lm_head$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously
...

@@ -48,6 +60,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).
 - match:
     name: "^model$"
   replace:
...
ktransformers/optimize/optimize_rules/Mixtral.yaml

@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.block_sparse_moe$"
     class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock
...
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml  (new file, 0 → 100644)

- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module

# if want to use more VRAM, use experts Marlin and disable CUDA Graph(disable CUDA Graph may cause low performance)
#- match:
#    name: "^model\\.layers\\..*\\.mlp\\.experts$"
#  replace:
#    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
#    kwargs:
#      prefill_device: "cuda"
#      prefill_op: "KExpertsTorch"
#      generate_device: "cuda"
#      generate_op: "KExpertsMarlin"
#  recursive: False # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
\ No newline at end of file
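Rule files like the ones above are consumed when the model is built: each `match` entry selects modules by name and/or class, and `replace` swaps in the named ktransformers operator with the given kwargs. A minimal sketch of how such a YAML is applied, based on the loader calls that appear in the ktransformers.py diff further down; the `optimize_and_load_gguf` import path and both filesystem paths are assumptions/placeholders:

```python
# Hedged sketch: applying an optimize-rule YAML before loading GGUF weights.
import torch
from transformers import AutoConfig
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.optimize.optimize import optimize_and_load_gguf  # assumed module path

model_dir = "/path/to/Moonlight-16B-A3B"   # placeholder
gguf_path = "/path/to/gguf_weights"        # placeholder

config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
with torch.device("meta"):
    # instantiate the architecture without allocating real weights
    model = custom_models[config.architectures[0]](config)

# Use the rule file shipped for this architecture, or point at a custom YAML like the one above.
rule_path = default_optimize_rules[config.architectures[0]]
optimize_and_load_gguf(model, rule_path, gguf_path, config)
```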
ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml

@@ -77,9 +77,19 @@
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
-    name: "(^model.norm)|(^lm_head)"
+    name: "(^model.norm)"
   replace:
     class: "default"
     kwargs:
...
ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml

@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
...
ktransformers/server/api/ollama/completions.py

@@ -12,8 +12,10 @@ from ktransformers.server.config.config import Config
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
-router = APIRouter(prefix='/api')
+from ktransformers.server.schemas.endpoints.chat import RawUsage
+
+router = APIRouter(prefix='/api')
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
 class OllamaGenerateCompletionRequest(BaseModel):
...
@@ -40,61 +42,129 @@ class OllamaGenerateCompletionRequest(BaseModel):
     keep_alive: Optional[str] = Field("5m", description="Controls how long the model will stay loaded into memory following the request.")
 
 
 class OllamaGenerationStreamResponse(BaseModel):
     model: str
     created_at: str
     response: str
     done: bool = Field(...)
 
 
 class OllamaGenerationResponse(BaseModel):
     pass
 
 
 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     id = str(uuid4())
 
     interface: BackendInterfaceBase = get_interface()
     print(f'COMPLETION INPUT:----\n{input.prompt}\n----')
 
     config = Config()
 
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt, id):
-                d = OllamaGenerationStreamResponse(model=config.model_name, created_at=str(datetime.now()), response=token, done=False)
-                yield d.model_dump_json() + '\n'
-                # d = {'model':config.model_name,'created_at':"", 'response':token,'done':False}
-                # yield f"{json.dumps(d)}\n"
-            # d = {'model':config.model_name,'created_at':"", 'response':'','done':True}
-            # yield f"{json.dumps(d)}\n"
-            d = OllamaGenerationStreamResponse(model=config.model_name, created_at=str(datetime.now()), response='', done=True)
-            yield d.model_dump_json() + '\n'
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaGenerationStreamResponse(model=config.model_name, created_at=str(datetime.now()), response=token, done=False)
+                    yield d.model_dump_json() + '\n'
+            d = OllamaGenerationStreamResponse(model=config.model_name, created_at=str(datetime.now()), response='', done=True)
+            yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
         raise NotImplementedError
 
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
 class OllamaChatCompletionMessage(BaseModel):
     role: str
     content: str
 
 class OllamaChatCompletionRequest(BaseModel):
-    pass
+    model: str = Field(..., description="The model name, which is required.")
+    messages: List[OllamaChatCompletionMessage] = Field(..., description="A list of messages to generate a response for.")
+    stream: bool = Field(True, description="If true, the response will be streamed.")
 
 class OllamaChatCompletionStreamResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    message: dict
+    done: bool = Field(...)
+    total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
+    load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
+    prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
+    prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
+    eval_count: Optional[int] = Field(None, description="Number of tokens generated")
+    eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
 
 class OllamaChatCompletionResponse(BaseModel):
     pass
 
 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
-    raise NotImplementedError
+    id = str(uuid4())
+    interface: BackendInterfaceBase = get_interface()
+    config = Config()
+
+    # convert the chat messages into a single prompt string
+    prompt = ""
+    for msg in input.messages:
+        prompt += f"{msg.role}: {msg.content}\n"
+    prompt += "assistant:"
+
+    if input.stream:
+        async def inner():
+            start_time = time()  # record the start time (seconds)
+            eval_count = 0       # count of generated tokens
+            tokens = []
+
+            async for res in interface.inference(prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaChatCompletionStreamResponse(
+                        model=config.model_name,
+                        created_at=str(datetime.now()),
+                        message={"role": "assistant", "content": token},
+                        done=False
+                    )
+                    yield d.model_dump_json() + '\n'
+
+            # compute the performance statistics
+            end_time = time()
+            total_duration = int((end_time - start_time) * 1_000_000_000)  # convert to nanoseconds
+            prompt_eval_count = len(prompt.split())  # rough estimate of the number of prompt tokens
+            eval_duration = total_duration  # assume all time was spent generating (simplification)
+            prompt_eval_duration = 0  # assume no separate prompt-evaluation time
+            load_duration = 0  # assume the load time is unknown
+
+            d = OllamaChatCompletionStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                message={},
+                done=True,
+                total_duration=total_duration,
+                load_duration=load_duration,
+                prompt_eval_count=prompt_eval_count,
+                prompt_eval_duration=prompt_eval_duration,
+                eval_count=eval_count,
+                eval_duration=eval_duration
+            )
+            yield d.model_dump_json() + '\n'
+        return check_link_response(request, inner())
+    else:
+        raise NotImplementedError("Non-streaming chat is not implemented.")
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):
...
@@ -103,9 +173,8 @@ class OllamaModel(BaseModel):
     size: int
 
 
 # TODO: fill the rest correctly
-# mock ollama
 @router.get("/tags", tags=['ollama'])
 async def tags():
     config = Config()
     # TODO: fill this correctly, although it does not effect Tabby
...
@@ -138,25 +207,21 @@ class OllamaShowResponse(BaseModel):
     class Config:
         protected_namespaces = ()
 
 
 @router.post("/show", tags=['ollama'])
 async def show(request: Request, input: OllamaShowRequest):
     config = Config()
     # TODO: Add more info in config to return, although it does not effect Tabby
     return OllamaShowResponse(
         modelfile="# Modelfile generated by ...",
         parameters=" ",
         template=" ",
         details=OllamaShowDetial(
             parent_model=" ",
             format="gguf",
             family=" ",
             families=[" "],
             parameter_size=" ",
             quantization_level=" "
         ),
         model_info=OllamaModelInfo()
     )
\ No newline at end of file
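For reference, the streaming contract of the `/api/chat` route above is one JSON object per line, ending with a `done: true` object that carries the timing fields. A minimal client sketch follows; the host, port and model name are placeholders (9016 is the default port in config.py):

```python
# Hedged sketch of a client for the Ollama-compatible streaming /api/chat route above.
import json
import requests

payload = {
    "model": "ktransformers-model",                          # placeholder
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
}

with requests.post("http://localhost:9016/api/chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        if chunk["done"]:
            # final object: message is empty, performance counters are filled in
            print("\neval_count:", chunk.get("eval_count"),
                  "total_duration(ns):", chunk.get("total_duration"))
        else:
            print(chunk["message"]["content"], end="", flush=True)
```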
ktransformers/server/api/openai/endpoints/chat.py

@@ -5,18 +5,21 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate, ChatCompletionChunk, ChatCompletionObject
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage
 
-router = APIRouter()
-models = [
-    {"id": "0", "name": "ktranformers-model"},
-]
+router = APIRouter()
 
 @router.get('/models', tags=['openai'])
 async def list_models():
-    return models
+    return {"data": [{"id": Config().model_name, "name": Config().model_name}], "object": "list"}
 
 @router.post('/chat/completions', tags=['openai'])
...
@@ -28,15 +31,80 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     input_message = [json.loads(m.model_dump_json()) for m in create.messages]
 
+    if Config().api_key != '':
+        assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
+
     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
         async def inner():
-            chunk = ChatCompletionChunk(id=id, object='chat.completion.chunk', created=int(time()))
-            async for token in interface.inference(input_message, id):
-                chunk.set_token(token)
-                yield chunk
-        return chat_stream_response(request, inner())
+            chunk = ChatCompletionChunk(
+                id=id,
+                choices=[],
+                object='chat.completion.chunk',
+                created=int(time()),
+                model=Config().model_name,
+            )
+            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() will return the usage of inference
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens=raw_usage.prefill_count,
+                        completion_tokens=raw_usage.decode_count,
+                        total_tokens=raw_usage.prefill_count + raw_usage.decode_count
+                    )
+                    yield chunk
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index=0,
+                        delta=ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason=finish_reason,
+                        logprobs=None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
+        return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id=id, object='chat.completion.chunk', created=int(time()))
-        async for token in interface.inference(input_message, id):
-            comp.append_token(token)
-        return comp
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens=raw_usage.prefill_count,
+                    completion_tokens=raw_usage.decode_count,
+                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+                finish_reason = finish_reason
+
+        choice = Choice(
+            index=0,
+            finish_reason=finish_reason,
+            message=ChatCompletionMessage(content=content, role="assistant")
+        )
+
+        chat_completion = ChatCompletion(
+            id=id,
+            choices=[choice],
+            created=int(time()),
+            model=Config().model_name,
+            object='chat.completion',
+            usage=usage
+        )
+
+        return chat_completion
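Because the final streamed chunk now carries `usage` built from `RawUsage`, an OpenAI-compatible client can read token counts straight off the stream. A small consumer sketch; the base URL (including whether a /v1 prefix is mounted, which this diff does not show), the dummy API key and the model name are all assumptions:

```python
# Hedged sketch of consuming the streaming /chat/completions route above with the openai client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9016/v1", api_key="dummy")  # assumed base URL

stream = client.chat.completions.create(
    model="ktransformers-model",                 # placeholder
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.6,
    top_p=0.95,
    stream=True,
)

for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
    if chunk.usage is not None:
        # last chunk: usage is filled from RawUsage (prefill_count / decode_count)
        print("\nprompt:", chunk.usage.prompt_tokens,
              "completion:", chunk.usage.completion_tokens,
              "total:", chunk.usage.total_tokens)
```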
ktransformers/server/api/openai/legacy/completions.py

@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate, CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 router = APIRouter()
...
@@ -17,17 +18,24 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
 
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt, id):
-                d = {'choices':[{'delta':{'content':token}}]}
-                yield f"data:{json.dumps(d)}\n\n"
+            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = {'choices':[{'delta':{'content':token}}]}
+                    yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
             yield f"data:{json.dumps(d)}\n\n"
         return stream_response(request, inner())
     else:
         comp = CompletionObject(id=id, object='text_completion', created=int(time()))
-        async for token in interface.inference(create.prompt, id):
-            comp.append_token(token)
+        async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                comp.append_token(token)
         return comp
ktransformers/server/args.py

@@ -10,6 +10,7 @@ class ArgumentParser:
         parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
         parser.add_argument("--host", type=str, default=self.cfg.server_ip)
         parser.add_argument("--port", type=int, default=self.cfg.server_port)
+        parser.add_argument("--api_key", type=str, default=self.cfg.api_key)
         parser.add_argument("--ssl_keyfile", type=str)
         parser.add_argument("--ssl_certfile", type=str)
         parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
...
@@ -23,13 +24,13 @@ class ArgumentParser:
         parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
         parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
         parser.add_argument("--type", type=str, default=self.cfg.backend_type)
+        parser.add_argument("--chunk_prefill_size", type=int, default=8192)
 
         # model configs
         # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
         parser.add_argument("--paged", type=bool, default=self.cfg.paged)
         parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
         parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
         parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
         parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
         parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
         parser.add_argument("--healing", type=bool, default=self.cfg.healing)
...
@@ -90,7 +91,8 @@ class ArgumentParser:
         # user config
         parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
         parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
-        parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)
 
         # web config
         parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
...
ktransformers/server/backend/args.py

@@ -23,7 +23,7 @@ class ConfigArgs(BaseModel):
     max_batch_size: int = Field(None, description="Max number of batches to run at once, assuming the sequences will fit within total_context")
-    max_chunk_size: int = Field(
+    chunk_prefill_size: int = Field(
         None,
         description=(
             "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
...
ktransformers/server/backend/base.py

@@ -15,6 +15,7 @@ from ktransformers.server.schemas.assistants.assistants import AssistantObject
 from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
 from ktransformers.server.schemas.assistants.runs import RunObject
 from ktransformers.server.schemas.assistants.threads import ThreadObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.schemas.base import ObjectID, Order
 from ktransformers.server.utils.multi_timer import Profiler
...
@@ -142,12 +143,16 @@ class ThreadContext:
         yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
         yield self.run.stream_response_with_event(RunObject.Status.in_progress)
 
-        async for token in self.interface.inference(local_messages, self.thread.id):
-            if self.run.status == RunObject.Status.cancelling:
-                logger.warn(f'Run {self.run.id} cancelling')
-                break
-            yield reply_message.append_message_delta(token)
-            response_str_count += 1
+        async for res in self.interface.inference(local_messages, self.thread.id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                if self.run.status == RunObject.Status.cancelling:
+                    logger.warn(f'Run {self.run.id} cancelling')
+                    break
+                yield reply_message.append_message_delta(token)
+                response_str_count += 1
 
         if self.run.status == RunObject.Status.cancelling:
             yield self.run.stream_response_with_event(RunObject.Status.cancelled)
...
ktransformers/server/backend/interfaces/ktransformers.py

 import torch
+import asyncio
 from transformers import AutoTokenizer, AutoConfig, GenerationConfig
 from ktransformers.server.backend.interfaces.transformers import (
     TransformersInterface,
...

@@ -13,7 +14,11 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from typing import Optional
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
+warm_uped = False
 
 class KTransformersThreadContext(TransformersThreadContext):
     pass
...
@@ -22,19 +27,29 @@ class KTransformersThreadContext(TransformersThreadContext):
 class KTransformersInterface(TransformersInterface):
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
         torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        try:
+            generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            generation_config = GenerationConfig(
+                max_length=args.max_new_tokens,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                do_sample=True
+            )
         torch.set_default_dtype(config.torch_dtype)
 
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
 
         with torch.device("meta"):
             self.model = custom_models[config.architectures[0]](config)
         if default_args.optimize_config_path is None:
-            optimize_rule_path = default_optimize_rules[config.architectures[0]]
+            optimize_config_path = default_optimize_rules[config.architectures[0]]
         else:
-            optimize_rule_path = args.optimize_config_path
+            optimize_config_path = args.optimize_config_path
 
         # print(optimize_config)
...
@@ -44,8 +59,8 @@ class KTransformersInterface(TransformersInterface):
                 "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                 " belong to current model):"
             )
-        optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
+        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
+        self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
...
@@ -56,25 +71,21 @@ class KTransformersInterface(TransformersInterface):
             dtype=self.model.dtype,
         )
         # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
-        try:
-            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
-        except:
-            gen_config = GenerationConfig(
-                max_length=128,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True
-            )
-            self.model.generation_config = gen_config
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
+        self._infer_lock = asyncio.Lock()
 
     def decode_one_tokens(self):
+        global warm_uped
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        if self.args.use_cuda_graph:
         torch.cuda.set_device(torch_device)
+        if warm_uped and self.args.use_cuda_graph:
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
...
@@ -96,14 +107,15 @@ class KTransformersInterface(TransformersInterface):
             torch.cuda.synchronize()
             logits = logits[0, -1, :]
             return self.logits_to_token(logits)
 
+        if self.args.use_cuda_graph:
+            warm_uped = True
+
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]
...
@@ -116,59 +128,116 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
+        if (input_ids_length >= self.args.cache_lens):
+            logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
+            self.seq_length = input_ids_length
+            return
         logger.debug(f"input_ids: {input_ids.shape}")
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         device = "cuda:0" if device == "cuda" else device
 
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-                )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        else:
+            logger.warning(f"seq_length bigger than cache_lens, killed")
+            exit(0)
+
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
 
-        mask = torch.ones((1, self.seq_length)).to(device)
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
-        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
-        if self.use_static_cache:
-            logits = self.model(
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                past_key_values=self.cache,
-                return_dict=False,
-                use_cache=True,
-                attention_mask=mask,
-            )[0]
-        else:
-            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+
+        def chunk_prefill(input_ids, cache_position):
+            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+            torch.cuda.set_device(device)
+            if flashinfer_enabled:
+                MLAWrapperSingleton.need_plan_all()
+            if self.use_static_cache:
+                logits = self.model(
+                    inputs_embeds=inputs_embeds,
+                    cache_position=cache_position,
+                    past_key_values=self.cache,
+                    return_dict=False,
+                    use_cache=True,
+                )[0]
+            else:
+                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+            return logits
+
+        chunk_start = 0
+        while chunk_start < input_ids_length:
+            chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
+            if self.cache != None:
+                self.cache.cur_idx = cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
+            chunk_start += self.args.chunk_prefill_size
 
+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
     @property
     def active_cache_position(self):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
-\ No newline at end of file
+
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        async with self._infer_lock:
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
+                yield v
+
+            # return this inference raw usage
+            yield RawUsage(
+                tokenize_time=self.profiler.get_timer_sec('tokenize'),
+                prefill_time=self.profiler.get_timer_sec('prefill'),
+                decode_time=self.profiler.get_timer_sec('decode'),
+                prefill_count=self.profiler.get_counter('prefill'),
+                decode_count=self.profiler.get_counter('decode'),
+            )
\ No newline at end of file
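The interface change above also changes what `inference()` yields: each item is now a `(token, finish_reason)` tuple, followed by one final `RawUsage` object with the profiler counters, and concurrent calls are serialized by `self._infer_lock`. A small consumer sketch of that contract; the prompt, thread id and sampling values are placeholders:

```python
# Hedged sketch of consuming the new KTransformersInterface.inference() contract:
# it yields (token, finish_reason) tuples and finishes with a single RawUsage object.
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.create_interface import get_interface

async def run_once(prompt: str, thread_id: str = "demo-thread"):  # placeholders
    interface = get_interface()
    async for res in interface.inference(prompt, thread_id, temperature=0.6, top_p=0.95):
        if isinstance(res, RawUsage):
            # final item: token counters collected by the profiler
            print("\nprefill tokens:", res.prefill_count, "decode tokens:", res.decode_count)
        else:
            token, finish_reason = res
            print(token, end="", flush=True)
```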
ktransformers/server/backend/interfaces/transformers.py

@@ -13,12 +13,13 @@ from transformers import (
 from ktransformers.server.config.config import Config
 from ktransformers.server.schemas.base import ObjectID
 from ktransformers.server.utils.multi_timer import Profiler
+from torch.nn.attention import SDPBackend
 import torch
 import sys, os
 from ..base import ThreadContext, BackendInterfaceBase
 from ktransformers.server.config.log import logger
 from ..args import ConfigArgs, default_args
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
 
 # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
 class TextStreamer:
...
@@ -170,7 +171,7 @@ class TransformersInterface(BackendInterfaceBase):
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
                 logger.warning("merge two adjacent user messages")
-                new_messages[-1]["content"] += m["content"]
+                new_messages[-1]["content"] += '\n' + m["content"]
             else:
                 new_messages.append(m)
         # if (self.last_request_id is not None) and self.last_request_id == thread_id:
...
@@ -179,7 +180,11 @@ class TransformersInterface(BackendInterfaceBase):
         # input_ids = self.tokenizer.apply_chat_template(
         #     new_messages, return_tensors="pt", add_generation_prompt=True
         # ).to(self.args.device)
-        input_ids = self.tokenizer.apply_chat_template(new_messages, return_tensors='pt', add_generation_prompt=True).to(self.args.device)
+        input_str: str = self.tokenizer.apply_chat_template(new_messages, tokenize=False, add_generation_prompt=True)
+        # drop <think> token in chat template
+        if input_str.endswith('<think>\n'):
+            input_str = input_str[:-len('<think>\n')]
+        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
             x = self.generated_ids[:,:self.seq_length]
             y = input_ids[:,:self.seq_length]
...
@@ -198,14 +203,31 @@ class TransformersInterface(BackendInterfaceBase):
             self.seq_length += 1
         return self.streamer.put(new_tokens)
 
-    def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature if self.args.temperature != 0 else logits
-
-        for token_idx in self.ever_generated_ids:
-            if logits[token_idx] < 0:
-                logits[token_idx] *= self.args.repetition_penalty
-            else:
-                logits[token_idx] /= self.args.repetition_penalty
-
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        if temperature is None or temperature == 0:
+            temperature = self.model.generation_config.temperature
+        if top_p is None:
+            top_p = self.model.generation_config.top_p
+        generation_config, model_kwargs = self.model._prepare_generation_config(
+            None, max_length=self.args.max_new_tokens,
+            do_sample=True,
+            top_k=self.args.top_k,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=self.args.repetition_penalty  # change this to modify generate config
+        )
+        self.inputs = inputs
+        try:  # transformers==4.43
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config, device=device)
+            )
+        except:
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config)
+            )
+
+    def logits_to_token(self, logits: torch.Tensor):
+        logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
         probs = torch.nn.functional.softmax(logits, dim=-1)
...
@@ -221,12 +243,10 @@ class TransformersInterface(BackendInterfaceBase):
     def decode_one_tokens(self):
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(self.args.device)
             logits = self.model(
                 self.current_ids,
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]
...
@@ -237,38 +257,57 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)
 
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
 
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-                )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
 
         mask = torch.ones((1, self.seq_length)).to(self.args.device)
+        device = input_ids.device
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
...
@@ -280,26 +319,46 @@ class TransformersInterface(BackendInterfaceBase):
                 past_key_values=self.cache,
                 return_dict=False,
                 use_cache=True,
                 attention_mask=mask,
             )[0]
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
     @torch.no_grad
     def generate(self):
+        self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
+        logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
+        if (self.max_new_tokens <= 0):
+            logger.warning("max_new_tokens is less than 0")
+            yield self.streamer.end(), "length"
+            return
+        logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
 
-        for _ in range(1, self.args.max_new_tokens):
-            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+        for i in range(1, self.max_new_tokens):
+            with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+                if flashinfer_enabled:
+                    MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32) + 1,
+                                                 num_heads=self.model.config.num_attention_heads,
+                                                 head_dim_ckv=self.model.config.kv_lora_rank,
+                                                 head_dim_kpe=self.model.config.qk_rope_head_dim,
+                                                 page_size=self.cache.page_size,
+                                                 sm_scale=self.model.model.layers[0].self_attn.softmax_scale,
+                                                 q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
                 next_token = self.decode_one_tokens()
                 self.profiler.inc("decode")
-                if next_token == self.tokenizer.eos_token_id:
+                if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                    yield self.streamer.end(), None
+                    yield "", "stop"
                     assert self.args.batch_size == 1
                     break
-                yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+                yield self.append_new_tokens(next_token), None
+        else:  # for's else, if output get max new tokens
+            yield self.streamer.end(), None
+            yield "", "length"
 
     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:
...
@@ -314,7 +373,8 @@ class TransformersInterface(BackendInterfaceBase):
             self.last_request_id = thread_id
             return True
 
-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
...
@@ -324,8 +384,9 @@ class TransformersInterface(BackendInterfaceBase):
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
+
         if Config().user_force_think:
-            token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_ids.device)
+            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
             input_ids = torch.cat([input_ids, token_thinks], dim=1)
...
@@ -333,21 +394,24 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.pause_timer("tokenize")
 
         self.profiler.create_and_start_timer("prefill")
+
         if Config().user_force_think:
-            t = "<think>\n"
-            print(t, end="", flush=True)
-            yield t
-        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+            think = '<think>\n'
+            print(think, end="", flush=True)
+            yield think, None
+
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
+            # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")
 
         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="", flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
ktransformers/server/config/config.py

@@ -69,6 +69,7 @@ class Config(metaclass=Singleton):
         self.server: dict = cfg.get("server", {})
         self.server_ip = self.server.get("ip", "0.0.0.0")
         self.server_port = self.server.get("port", 9016)
+        self.api_key = self.server.get("api_key", "")
 
         # db configs
         self.db_configs: dict = cfg.get("db", {})
...
@@ -104,7 +105,8 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
         self.max_chunk_size = self.model.get("max_chunk_size", 2048)
+        self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
         self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
...
ktransformers/server/main.py

@@ -105,6 +105,7 @@ def custom_openapi(app):
 def main():
     cfg = Config()
 
     arg_parser = ArgumentParser(cfg)
+    # initialize messages
...
ktransformers/server/requirements.txt

@@ -5,6 +5,7 @@ langchain >= 0.2.0
 blessed >= 1.20.0
 accelerate >= 0.31.0
 sentencepiece >= 0.1.97
+openai
 setuptools
 build
 ninja
...
ktransformers/server/schemas/assistants/streaming.py

@@ -73,7 +73,7 @@ class RunStepDelta(Object):
 class Done():
     def to_stream_reply(self):
-        return f"event: done\ndata: [DONE]\n\n"
+        return f"data: [DONE]\n\n"
 
 async def check_client_link(request: Request, async_events: AsyncIterable):
...