OpenDAS / ktransformers / Commits / c009512a

Commit c009512a, authored Mar 13, 2025 by Azure-Tang

    Merge branch 'main' into hip

Parents: c1f13a69, 4f22d726
Changes: 121 files in the commit; this page shows 20 changed files with 772 additions and 179 deletions (+772 -179).
Changed files shown on this page:

    ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml (+157 -0)
    ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml (+13 -2)
    ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml (+12 -1)
    ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml (+13 -0)
    ktransformers/optimize/optimize_rules/Mixtral.yaml (+10 -0)
    ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml (+86 -0)
    ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml (+11 -1)
    ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml (+10 -0)
    ktransformers/server/api/ollama/completions.py (+104 -39)
    ktransformers/server/api/openai/endpoints/chat.py (+83 -15)
    ktransformers/server/api/openai/legacy/completions.py (+14 -6)
    ktransformers/server/args.py (+4 -2)
    ktransformers/server/backend/args.py (+1 -1)
    ktransformers/server/backend/base.py (+11 -6)
    ktransformers/server/backend/interfaces/ktransformers.py (+124 -55)
    ktransformers/server/backend/interfaces/transformers.py (+113 -49)
    ktransformers/server/config/config.py (+3 -1)
    ktransformers/server/main.py (+1 -0)
    ktransformers/server/requirements.txt (+1 -0)
    ktransformers/server/schemas/assistants/streaming.py (+1 -1)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml (new file, 0 → 100644, view file @ c009512a)

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^model\\.layers\\.([3456][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:0"
  recursive: False  # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:1"
  recursive: False  # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).

- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
      transfer_map:
        30: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^lm_head"
    class: torch.nn.Linear
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
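The file above is a new injection-rule sheet for running DeepSeek-V3 across two GPUs with FP8 linear kernels on GPU and GGML (CPU) experts: each rule matches modules by name regex and/or class and declares the replacement operator and its devices. As a quick orientation, and not part of the commit, the sketch below shows roughly how such a rule file is consumed via optimize_and_load_gguf, the entry point that also appears later in this diff (ktransformers/server/backend/interfaces/ktransformers.py). Model paths are placeholders and the import path of optimize_and_load_gguf is an assumption.

    # Minimal sketch (assumptions noted inline): apply a rule file to a model skeleton
    # and then load quantized GGUF weights, mirroring what the server backend does.
    import torch
    from transformers import AutoConfig
    from ktransformers.local_chat import custom_models
    from ktransformers.optimize.optimize import optimize_and_load_gguf  # import path assumed

    model_dir = "/models/DeepSeek-V3"       # placeholder
    gguf_path = "/models/DeepSeek-V3-GGUF"  # placeholder
    rule_path = "ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml"

    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    with torch.device("meta"):
        model = custom_models[config.architectures[0]](config)  # build the skeleton without allocating weights

    # Walk the module tree, swap modules according to the matching rules, then load weights.
    optimize_and_load_gguf(model, rule_path, gguf_path, config)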
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml (view file @ c009512a)

@@ -153,9 +153,20 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
   replace:
-    class: "default"
+    class: ktransformers.operators.linear.KTransformersLinear
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml (view file @ c009512a)

@@ -135,7 +135,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml (view file @ c009512a)

@@ -5,6 +5,18 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+- match:
+    name: "^lm_head$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously

@@ -48,6 +60,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).
 - match:
     name: "^model$"
   replace:
ktransformers/optimize/optimize_rules/Mixtral.yaml (view file @ c009512a)

@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.block_sparse_moe$"
     class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml (new file, 0 → 100644, view file @ c009512a)

- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module

# if want to use more VRAM, use experts Marlin and disable CUDA Graph(disable CUDA Graph may cause low performance)
#- match:
#    name: "^model\\.layers\\..*\\.mlp\\.experts$"
#  replace:
#    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
#    kwargs:
#      prefill_device: "cuda"
#      prefill_op: "KExpertsTorch"
#      generate_device: "cuda"
#      generate_op: "KExpertsMarlin"
#  recursive: False # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
\ No newline at end of file
ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml (view file @ c009512a)

@@ -77,9 +77,19 @@
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
-    name: "(^model.norm)|(^lm_head)"
+    name: "(^model.norm)"
   replace:
     class: "default"
     kwargs:
ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml (view file @ c009512a)

@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
ktransformers/server/api/ollama/completions.py (view file @ c009512a)

@@ -12,8 +12,10 @@ from ktransformers.server.config.config import Config
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 router = APIRouter(prefix='/api')
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
 class OllamaGenerateCompletionRequest(BaseModel):

@@ -40,61 +42,129 @@ class OllamaGenerateCompletionRequest(BaseModel):
     keep_alive: Optional[str] = Field(
         "5m", description="Controls how long the model will stay loaded into memory following the request.")
 
 class OllamaGenerationStreamResponse(BaseModel):
     model: str
     created_at: str
     response: str
     done: bool = Field(...)
 
 class OllamaGenerationResponse(BaseModel):
     pass
 
 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     id = str(uuid4())
     interface: BackendInterfaceBase = get_interface()
     print(f'COMPLETION INPUT:----\n{input.prompt}\n----')
     config = Config()
 
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt, id):
-                d = OllamaGenerationStreamResponse(model=config.model_name, created_at=str(datetime.now()), response=token, done=False)
-                yield d.model_dump_json() + '\n'
-                # d = {'model':config.model_name,'created_at':"", 'response':token,'done':False}
-                # yield f"{json.dumps(d)}\n"
-            # d = {'model':config.model_name,'created_at':"", 'response':'','done':True}
-            # yield f"{json.dumps(d)}\n"
-            d = OllamaGenerationStreamResponse(model=config.model_name, created_at=str(datetime.now()), response='', done=True)
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaGenerationStreamResponse(
+                        model=config.model_name,
+                        created_at=str(datetime.now()),
+                        response=token,
+                        done=False
+                    )
+                    yield d.model_dump_json() + '\n'
+            d = OllamaGenerationStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                response='',
+                done=True
+            )
             yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
         raise NotImplementedError
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
+class OllamaChatCompletionMessage(BaseModel):
+    role: str
+    content: str
+
 class OllamaChatCompletionRequest(BaseModel):
-    pass
+    model: str = Field(..., description="The model name, which is required.")
+    messages: List[OllamaChatCompletionMessage] = Field(..., description="A list of messages to generate a response for.")
+    stream: bool = Field(True, description="If true, the response will be streamed.")
 
 class OllamaChatCompletionStreamResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    message: dict
+    done: bool = Field(...)
+    total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
+    load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
+    prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
+    prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
+    eval_count: Optional[int] = Field(None, description="Number of tokens generated")
+    eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
 
 class OllamaChatCompletionResponse(BaseModel):
     pass
 
 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
-    raise NotImplementedError
+    id = str(uuid4())
+    interface: BackendInterfaceBase = get_interface()
+    config = Config()
+
+    # Convert the messages into a single prompt string.
+    prompt = ""
+    for msg in input.messages:
+        prompt += f"{msg.role}: {msg.content}\n"
+    prompt += "assistant:"
+
+    if input.stream:
+        async def inner():
+            start_time = time()  # record the start time (seconds)
+            eval_count = 0       # count of generated tokens
+            tokens = []
+
+            async for res in interface.inference(prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaChatCompletionStreamResponse(
+                        model=config.model_name,
+                        created_at=str(datetime.now()),
+                        message={"role": "assistant", "content": token},
+                        done=False
+                    )
+                    yield d.model_dump_json() + '\n'
+            # compute the performance statistics
+            end_time = time()
+            total_duration = int((end_time - start_time) * 1_000_000_000)  # convert to nanoseconds
+            prompt_eval_count = len(prompt.split())  # rough estimate of the prompt token count
+            eval_duration = total_duration           # assume all time was spent generating (simplification)
+            prompt_eval_duration = 0                 # assume no separate prompt-evaluation time
+            load_duration = 0                        # assume the load time is unknown
+
+            d = OllamaChatCompletionStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                message={},
+                done=True,
+                total_duration=total_duration,
+                load_duration=load_duration,
+                prompt_eval_count=prompt_eval_count,
+                prompt_eval_duration=prompt_eval_duration,
+                eval_count=eval_count,
+                eval_duration=eval_duration
+            )
+            yield d.model_dump_json() + '\n'
+        return check_link_response(request, inner())
+    else:
+        raise NotImplementedError("Non-streaming chat is not implemented.")
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):

@@ -103,9 +173,8 @@ class OllamaModel(BaseModel):
     size: int
     # TODO: fill the rest correctly
 
 # mock ollama
 @router.get("/tags", tags=['ollama'])
 async def tags():
     config = Config()
     # TODO: fill this correctly, although it does not effect Tabby

@@ -138,25 +207,21 @@ class OllamaShowResponse(BaseModel):
     class Config:
         protected_namespaces = ()
 
 @router.post("/show", tags=['ollama'])
 async def show(request: Request, input: OllamaShowRequest):
     config = Config()
     # TODO: Add more info in config to return, although it does not effect Tabby
     return OllamaShowResponse(
         modelfile="# Modelfile generated by ...",
         parameters=" ",
         template=" ",
         details=OllamaShowDetial(
             parent_model=" ",
             format="gguf",
             family=" ",
             families=[" "],
             parameter_size=" ",
             quantization_level=" "
         ),
         model_info=OllamaModelInfo()
     )
\ No newline at end of file
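The new /api/chat handler above streams one newline-delimited JSON object per token in Ollama's format and closes with a done=True record carrying timing fields. A minimal client sketch follows; it is not part of the commit, and the host, port, and model name are placeholder assumptions.

    # Sketch of a client for the Ollama-style /api/chat endpoint added above.
    # Server address and model name are assumptions, not values from the commit.
    import json
    import requests

    payload = {
        "model": "DeepSeek-V3",  # placeholder; the server reports its configured model_name
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    }
    with requests.post("http://localhost:10002/api/chat", json=payload, stream=True) as r:
        for line in r.iter_lines():
            if not line:
                continue
            chunk = json.loads(line)  # one OllamaChatCompletionStreamResponse per line
            if chunk.get("done"):
                print("\ntotal_duration (ns):", chunk.get("total_duration"))
            else:
                print(chunk["message"]["content"], end="", flush=True)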
ktransformers/server/api/openai/endpoints/chat.py (view file @ c009512a)

@@ -5,18 +5,21 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate, ChatCompletionChunk, ChatCompletionObject
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.backend.base import BackendInterfaceBase
+from ktransformers.server.config.config import Config
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage
 
 router = APIRouter()
-models = [
-    {"id": "0", "name": "ktranformers-model"},
-]
 
 @router.get('/models', tags=['openai'])
 async def list_models():
-    return models
+    return {"data": [{"id": Config().model_name, "name": Config().model_name}], "object": "list"}
 
 @router.post('/chat/completions', tags=['openai'])

@@ -28,15 +31,80 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     input_message = [json.loads(m.model_dump_json()) for m in create.messages]
+    if Config().api_key != '':
+        assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
         async def inner():
-            chunk = ChatCompletionChunk(id=id, object='chat.completion.chunk', created=int(time()))
-            async for token in interface.inference(input_message, id):
-                chunk.set_token(token)
-                yield chunk
+            chunk = ChatCompletionChunk(
+                id=id,
+                choices=[],
+                object='chat.completion.chunk',
+                created=int(time()),
+                model=Config().model_name,
+            )
+            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() will return the usage of inference
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens=raw_usage.prefill_count,
+                        completion_tokens=raw_usage.decode_count,
+                        total_tokens=raw_usage.prefill_count + raw_usage.decode_count
+                    )
+                    yield chunk
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index=0,
+                        delta=ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason=finish_reason,
+                        logprobs=None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
         return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id=id, object='chat.completion.chunk', created=int(time()))
-        async for token in interface.inference(input_message, id):
-            comp.append_token(token)
-        return comp
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens=raw_usage.prefill_count,
+                    completion_tokens=raw_usage.decode_count,
+                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+                finish_reason = finish_reason
+
+        choice = Choice(
+            index=0,
+            finish_reason=finish_reason,
+            message=ChatCompletionMessage(content=content, role="assistant"))
+
+        chat_completion = ChatCompletion(
+            id=id, choices=[choice], created=int(time()),
+            model=Config().model_name, object='chat.completion', usage=usage)
+
+        return chat_completion
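With this change the OpenAI-compatible endpoint streams standard chat-completion chunks, honors per-request temperature and top_p, and emits a final chunk whose usage field is built from the RawUsage record. A hedged client sketch using the official openai package follows; the base URL, port, API key, and model name are assumptions, not values taken from the commit.

    # Sketch of a client for the updated OpenAI-compatible endpoint (placeholder connection values).
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:10002/v1", api_key="not-needed")  # assumed mount point

    stream = client.chat.completions.create(
        model="ktransformers",  # placeholder; the server answers with its configured model_name
        messages=[{"role": "user", "content": "Hi"}],
        temperature=0.6,
        top_p=0.95,
        stream=True,
    )
    for chunk in stream:
        if chunk.choices:          # normal token chunks
            print(chunk.choices[0].delta.content or "", end="", flush=True)
        elif chunk.usage:          # final chunk carries prompt/completion token counts
            print("\nusage:", chunk.usage)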
ktransformers/server/api/openai/legacy/completions.py (view file @ c009512a)

@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate, CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 router = APIRouter()

@@ -17,17 +18,24 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
 
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt, id):
-                d = {'choices':[{'delta':{'content':token}}]}
-                yield f"data: {json.dumps(d)}\n\n"
+            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = {'choices':[{'delta':{'content':token}}]}
+                    yield f"data: {json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
             yield f"data: {json.dumps(d)}\n\n"
         return stream_response(request, inner())
     else:
         comp = CompletionObject(id=id, object='text_completion', created=int(time()))
-        async for token in interface.inference(create.prompt, id):
-            comp.append_token(token)
+        async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                comp.append_token(token)
         return comp
ktransformers/server/args.py (view file @ c009512a)

@@ -10,6 +10,7 @@ class ArgumentParser:
         parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
         parser.add_argument("--host", type=str, default=self.cfg.server_ip)
         parser.add_argument("--port", type=int, default=self.cfg.server_port)
+        parser.add_argument("--api_key", type=str, default=self.cfg.api_key)
         parser.add_argument("--ssl_keyfile", type=str)
         parser.add_argument("--ssl_certfile", type=str)
         parser.add_argument("--web", type=bool, default=self.cfg.mount_web)

@@ -23,13 +24,13 @@ class ArgumentParser:
         parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
         parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
         parser.add_argument("--type", type=str, default=self.cfg.backend_type)
+        parser.add_argument("--chunk_prefill_size", type=int, default=8192)
 
         # model configs
         # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
         parser.add_argument("--paged", type=bool, default=self.cfg.paged)
         parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
         parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
-        parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
         parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
         parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
         parser.add_argument("--healing", type=bool, default=self.cfg.healing)

@@ -90,7 +91,8 @@ class ArgumentParser:
         # user config
         parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
         parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
-        parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)
 
         # web config
         parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
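The switch from type=bool to argparse.BooleanOptionalAction above is what makes --force_think / --no-force_think and --use_cuda_graph / --no-use_cuda_graph behave as real on/off switches; with type=bool alone, any non-empty string argument, including "False", parses as True. A small self-contained illustration, independent of this commit:

    # Illustration of argparse.BooleanOptionalAction (Python 3.9+), unrelated to ktransformers itself.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, default=True)

    print(parser.parse_args([]).use_cuda_graph)                       # True (default)
    print(parser.parse_args(["--use_cuda_graph"]).use_cuda_graph)     # True
    print(parser.parse_args(["--no-use_cuda_graph"]).use_cuda_graph)  # False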
ktransformers/server/backend/args.py (view file @ c009512a)

@@ -23,7 +23,7 @@ class ConfigArgs(BaseModel):
     max_batch_size: int = Field(
         None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
     )
-    max_chunk_size: int = Field(
+    chunk_prefill_size: int = Field(
         None,
         description=(
             "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
ktransformers/server/backend/base.py (view file @ c009512a)

@@ -15,6 +15,7 @@ from ktransformers.server.schemas.assistants.assistants import AssistantObject
 from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
 from ktransformers.server.schemas.assistants.runs import RunObject
 from ktransformers.server.schemas.assistants.threads import ThreadObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.schemas.base import ObjectID, Order
 from ktransformers.server.utils.multi_timer import Profiler

@@ -142,12 +143,16 @@ class ThreadContext:
         yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
         yield self.run.stream_response_with_event(RunObject.Status.in_progress)
 
-        async for token in self.interface.inference(local_messages, self.thread.id):
-            if self.run.status == RunObject.Status.cancelling:
-                logger.warn(f'Run {self.run.id} cancelling')
-                break
-            yield reply_message.append_message_delta(token)
-            response_str_count += 1
+        async for res in self.interface.inference(local_messages, self.thread.id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                if self.run.status == RunObject.Status.cancelling:
+                    logger.warn(f'Run {self.run.id} cancelling')
+                    break
+                yield reply_message.append_message_delta(token)
+                response_str_count += 1
 
         if self.run.status == RunObject.Status.cancelling:
             yield self.run.stream_response_with_event(RunObject.Status.cancelled)
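Across all of these endpoints the interface.inference() async generator now yields two kinds of items: (token, finish_reason) tuples while decoding, and a single RawUsage record at the end. A condensed consumer sketch of that pattern, with names taken from the diff (the interface object is assumed to come from get_interface() as in the handlers above):

    # Sketch of the consumption pattern this commit introduces: the inference stream
    # yields (token, finish_reason) tuples and finishes with one RawUsage record.
    from ktransformers.server.schemas.endpoints.chat import RawUsage

    async def collect(interface, prompt, request_id):
        text, usage = "", None
        async for res in interface.inference(prompt, request_id):
            if isinstance(res, RawUsage):
                usage = res  # prefill/decode counts and timings
            else:
                token, finish_reason = res
                text += token
        return text, usage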
ktransformers/server/backend/interfaces/ktransformers.py (view file @ c009512a)

 import torch
+import asyncio
 from transformers import AutoTokenizer, AutoConfig, GenerationConfig
 from ktransformers.server.backend.interfaces.transformers import (
     TransformersInterface,

@@ -13,7 +14,11 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from typing import Optional
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
+warm_uped = False
 
 class KTransformersThreadContext(TransformersThreadContext):
     pass

@@ -22,19 +27,29 @@ class KTransformersThreadContext(TransformersThreadContext):
 class KTransformersInterface(TransformersInterface):
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
-        torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        try:
+            generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            generation_config = GenerationConfig(
+                max_length=args.max_new_tokens,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                do_sample=True
+            )
+        torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
 
         with torch.device("meta"):
             self.model = custom_models[config.architectures[0]](config)
         if default_args.optimize_config_path is None:
-            optimize_rule_path = default_optimize_rules[config.architectures[0]]
+            optimize_config_path = default_optimize_rules[config.architectures[0]]
         else:
-            optimize_rule_path = args.optimize_config_path
+            optimize_config_path = args.optimize_config_path
 
         # print(optimize_config)

@@ -44,8 +59,8 @@ class KTransformersInterface(TransformersInterface):
                 "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                 " belong to current model):"
             )
-        optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
+        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
+        self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(

@@ -56,25 +71,21 @@ class KTransformersInterface(TransformersInterface):
             dtype=self.model.dtype,
         )
         # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
-        try:
-            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
-        except:
-            gen_config = GenerationConfig(
-                max_length=128,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True
-            )
-            self.model.generation_config = gen_config
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
+        self._infer_lock = asyncio.Lock()
 
     def decode_one_tokens(self):
+        global warm_uped
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        if self.args.use_cuda_graph:
+        torch.cuda.set_device(torch_device)
+        if warm_uped and self.args.use_cuda_graph:
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(

@@ -96,14 +107,15 @@ class KTransformersInterface(TransformersInterface):
             torch.cuda.synchronize()
             logits = logits[0, -1, :]
             return self.logits_to_token(logits)
+        if self.args.use_cuda_graph:
+            warm_uped = True
 
         if self.use_static_cache:
+            mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
+                attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]

@@ -116,59 +128,116 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
+        if (input_ids_length >= self.args.cache_lens):
+            logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
+            self.seq_length = input_ids_length
+            return
         logger.debug(f"input_ids: {input_ids.shape}")
 
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
+        device = "cuda:0" if device == "cuda" else device
 
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-                )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        else:
+            logger.warning(f"seq_length bigger than cache_lens, killed")
+            exit(0)
 
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
 
-        mask = torch.ones((1, self.seq_length)).to(device)
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
-        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
-        if self.use_static_cache:
-            logits = self.model(
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                past_key_values=self.cache,
-                return_dict=False,
-                use_cache=True,
-                attention_mask=mask,
-            )[0]
-        else:
-            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+
+        def chunk_prefill(input_ids, cache_position):
+            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+            torch.cuda.set_device(device)
+            if flashinfer_enabled:
+                MLAWrapperSingleton.need_plan_all()
+            if self.use_static_cache:
+                logits = self.model(
+                    inputs_embeds=inputs_embeds,
+                    cache_position=cache_position,
+                    past_key_values=self.cache,
+                    return_dict=False,
+                    use_cache=True,
+                )[0]
+            else:
+                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+            return logits
+
+        chunk_start = 0
+        while chunk_start < input_ids_length:
+            chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
+            if self.cache != None:
+                self.cache.cur_idx = cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
+            chunk_start += self.args.chunk_prefill_size
+
+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
     @property
     def active_cache_position(self):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
-\ No newline at end of file
+
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        async with self._infer_lock:
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
+                yield v
+            # return this inference raw usage
+            yield RawUsage(
+                tokenize_time=self.profiler.get_timer_sec('tokenize'),
+                prefill_time=self.profiler.get_timer_sec('prefill'),
+                decode_time=self.profiler.get_timer_sec('decode'),
+                prefill_count=self.profiler.get_counter('prefill'),
+                decode_count=self.profiler.get_counter('decode'),
+            )
\ No newline at end of file
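The rewritten prefill() above no longer resets the KV cache on every new request; it compares the incoming token ids against the previously cached ones, keeps the longest common prefix, and only prefills the remaining suffix in chunks of chunk_prefill_size. The core prefix comparison, isolated as a standalone sketch with placeholder tensors (not code from the commit):

    # Standalone sketch of the prefix-reuse idea used by the new prefill(): count how many
    # leading token ids match the previously cached sequence, so only the tail is recomputed.
    import torch

    def shared_prefix_len(prev_ids: torch.Tensor, new_ids: torch.Tensor, cached_len: int) -> int:
        prev_flat, new_flat = prev_ids.flatten(), new_ids.flatten()
        same = 0
        for i in range(min(cached_len, new_flat.shape[0]) - 1):
            if new_flat[i] == prev_flat[i]:
                same += 1
            else:
                break
        return same

    prev = torch.tensor([[1, 2, 3, 4, 5, 0, 0]])  # previously cached ids (padded)
    new = torch.tensor([[1, 2, 3, 9, 9, 9]])      # new request shares a 3-token prefix
    print(shared_prefix_len(prev, new, cached_len=5))  # -> 3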
ktransformers/server/backend/interfaces/transformers.py
View file @
c009512a
...
@@ -13,12 +13,13 @@ from transformers import (
...
@@ -13,12 +13,13 @@ from transformers import (
from
ktransformers.server.config.config
import
Config
from
ktransformers.server.config.config
import
Config
from
ktransformers.server.schemas.base
import
ObjectID
from
ktransformers.server.schemas.base
import
ObjectID
from
ktransformers.server.utils.multi_timer
import
Profiler
from
ktransformers.server.utils.multi_timer
import
Profiler
from
torch.nn.attention
import
SDPBackend
import
torch
import
torch
import
sys
,
os
import
sys
,
os
from
..base
import
ThreadContext
,
BackendInterfaceBase
from
..base
import
ThreadContext
,
BackendInterfaceBase
from
ktransformers.server.config.log
import
logger
from
ktransformers.server.config.log
import
logger
from
..args
import
ConfigArgs
,
default_args
from
..args
import
ConfigArgs
,
default_args
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
,
MLAWrapperSingleton
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
class
TextStreamer
:
class
TextStreamer
:
...
@@ -170,7 +171,7 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -170,7 +171,7 @@ class TransformersInterface(BackendInterfaceBase):
for
m
in
messages
[
1
:]:
for
m
in
messages
[
1
:]:
if
m
[
"role"
]
==
"user"
and
new_messages
[
-
1
][
"role"
]
==
"user"
:
if
m
[
"role"
]
==
"user"
and
new_messages
[
-
1
][
"role"
]
==
"user"
:
logger
.
warning
(
"merge two adjacent user messages"
)
logger
.
warning
(
"merge two adjacent user messages"
)
new_messages
[
-
1
][
"content"
]
+=
m
[
"content"
]
new_messages
[
-
1
][
"content"
]
+=
'
\n
'
+
m
[
"content"
]
else
:
else
:
new_messages
.
append
(
m
)
new_messages
.
append
(
m
)
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
...
@@ -179,7 +180,11 @@ class TransformersInterface(BackendInterfaceBase):
        # input_ids = self.tokenizer.apply_chat_template(
        #     new_messages, return_tensors="pt", add_generation_prompt=True
        # ).to(self.args.device)
-       input_ids = self.tokenizer.apply_chat_template(new_messages, return_tensors='pt', add_generation_prompt=True).to(self.args.device)
+       input_str: str = self.tokenizer.apply_chat_template(new_messages, tokenize=False, add_generation_prompt=True)
+       # drop <think> token in chat template
+       if input_str.endswith('<think>\n'):
+           input_str = input_str[:-len('<think>\n')]
+       input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
        if (self.last_request_id is not None) and self.last_request_id == thread_id:
            x = self.generated_ids[:, :self.seq_length]
            y = input_ids[:, :self.seq_length]
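The new tokenization path renders the chat template to a string first so a trailing think marker can be stripped before encoding. A self-contained sketch of just the string handling (the template text below is a stand-in, not the real DeepSeek template):

    rendered = "<|User|>hello<|Assistant|><think>\n"   # stand-in for apply_chat_template(..., tokenize=False)
    suffix = "<think>\n"
    if rendered.endswith(suffix):
        rendered = rendered[:-len(suffix)]             # drop the think opener appended by the template
    print(repr(rendered))                              # '<|User|>hello<|Assistant|>'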
...
@@ -198,14 +203,31 @@ class TransformersInterface(BackendInterfaceBase):
        self.seq_length += 1
        return self.streamer.put(new_tokens)

-   def logits_to_token(self, logits: torch.Tensor):
-       logits = logits / self.args.temperature if self.args.temperature != 0 else logits
-       for token_idx in self.ever_generated_ids:
-           if logits[token_idx] < 0:
-               logits[token_idx] *= self.args.repetition_penalty
-           else:
-               logits[token_idx] /= self.args.repetition_penalty
+   def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+       if temperature is None or temperature == 0:
+           temperature = self.model.generation_config.temperature
+       if top_p is None:
+           top_p = self.model.generation_config.top_p
+       generation_config, model_kwargs = self.model._prepare_generation_config(
+           None, max_length=self.args.max_new_tokens,
+           do_sample=True,
+           top_k=self.args.top_k,
+           top_p=top_p,
+           temperature=temperature,
+           repetition_penalty=self.args.repetition_penalty  # change this to modify generate config
+       )
+       self.inputs = inputs
+       try:  # transformers==4.43
+           self.logits_warper = (
+               self.model._get_logits_warper(generation_config, device=device)
+           )
+       except:
+           self.logits_warper = (
+               self.model._get_logits_warper(generation_config)
+           )
+
+   def logits_to_token(self, logits: torch.Tensor):
+       logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
        probs = torch.nn.functional.softmax(logits, dim=-1)
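The prepare_logits_wrapper / logits_to_token split builds the sampling pipeline once per request and then only warps and samples per token. A rough, self-contained equivalent using the public transformers warpers instead of the private _get_logits_warper helper (this mirrors the idea; it is not the library's internal code path):

    import torch
    from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

    def sample_next_token(logits: torch.Tensor, generated: torch.Tensor,
                          temperature: float = 0.6, top_p: float = 0.95) -> int:
        warper = LogitsProcessorList([
            TemperatureLogitsWarper(temperature),   # divide logits by temperature
            TopPLogitsWarper(top_p),                # keep the smallest nucleus covering top_p mass
        ])
        warped = warper(generated.view(1, -1), logits.view(1, -1))
        probs = torch.nn.functional.softmax(warped, dim=-1)
        return int(torch.multinomial(probs, num_samples=1))

    # toy call over a random 8-token vocabulary
    print(sample_next_token(torch.randn(8), torch.tensor([1, 2, 3])))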
...
@@ -221,12 +243,10 @@ class TransformersInterface(BackendInterfaceBase):
    def decode_one_tokens(self):
        if self.use_static_cache:
-           mask = torch.ones((1, self.seq_length)).to(self.args.device)
            logits = self.model(
                self.current_ids,
                cache_position=self.active_cache_position,
                past_key_values=self.cache,
-               attention_mask=mask,
                return_dict=False,
                use_cache=True,
            )[0]
...
@@ -237,38 +257,57 @@ class TransformersInterface(BackendInterfaceBase):
        return self.logits_to_token(logits)

    @torch.no_grad
-   def prefill(self, input_ids: torch.Tensor, is_new: bool):
+   def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
        input_ids_length = input_ids.shape[-1]
-       self.profiler.set_counter("prefill", input_ids_length)
        logger.debug(f"input_ids: {input_ids.shape}")
        if is_new:
-           self.cache.reset()
            self.ever_generated_ids.clear()
-           former_seq_length = 0
-           self.seq_length = input_ids_length
-           self.generated_ids = torch.zeros(
-               self.args.batch_size,
-               self.seq_length + self.args.max_new_tokens + 1,
-               dtype=torch.int,
-               device=self.args.device,
-           )
-       else:
-           logger.debug(f"generate_ids: {self.generated_ids.shape}")
-           former_seq_length = self.seq_length
-           self.seq_length += input_ids_length
-           expected_length = self.seq_length + self.args.max_new_tokens + 1
-           delta_length = expected_length - self.generated_ids.shape[-1]
-           if delta_length > 0:
-               new_generate_ids = torch.zeros(
-                   self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-               )
-               self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+           same_prefix = 0
+           flat_input_ids = input_ids.flatten()
+           if getattr(self, 'generated_ids', None) is None:
+               self.generated_ids = torch.zeros(
+                   self.args.batch_size,
+                   input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                   dtype=torch.int,
+                   device=self.args.device,
+               )
+               self.seq_length = 1
+           flat_prev_ids = self.generated_ids.flatten()
+           for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+               if flat_input_ids[i] == flat_prev_ids[i]:
+                   same_prefix += 1
+               else:
+                   break
+           logger.debug(f"same prefix len: {same_prefix}")
+           self.cache.remove_suffix(same_prefix)
+           self.seq_length = same_prefix
+           self.generated_ids = self.generated_ids[..., :same_prefix]
+           input_ids = input_ids[..., same_prefix:]
+           input_ids_length = input_ids.shape[-1]
+       self.ever_generated_ids.clear()
+       self.profiler.set_counter("prefill", input_ids_length)
+       logger.debug(f"input_ids: {input_ids.shape}")
+       logger.debug(f"generate_ids: {self.generated_ids.shape}")
+       former_seq_length = self.seq_length
+       self.seq_length += input_ids_length
+       expected_length = self.seq_length + self.args.max_new_tokens + 1
+       delta_length = expected_length - self.generated_ids.shape[-1]
+       if delta_length > 0:
+           new_generate_ids = torch.zeros(
+               self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+           )
+           self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
        cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
-       mask = torch.ones((1, self.seq_length)).to(self.args.device)
        device = input_ids.device
        if not (type(self) is TransformersInterface):
            input_ids = input_ids.to("cpu")
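The reworked prefill keeps whatever prefix of the KV cache still matches the incoming prompt and only prefills the tail. A self-contained sketch of the prefix comparison (mirroring the same_prefix loop above; the "- 1" bound leaves at least one token to prefill):

    import torch

    def reusable_prefix_len(prev_ids: torch.Tensor, new_ids: torch.Tensor) -> int:
        same = 0
        for i in range(min(prev_ids.numel(), new_ids.numel()) - 1):
            if prev_ids[i] != new_ids[i]:
                break
            same += 1
        return same

    prev = torch.tensor([1, 2, 3, 4, 5])        # ids already sitting in the cache
    new = torch.tensor([1, 2, 3, 9, 9, 9])      # ids of the next request
    print(reusable_prefix_len(prev, new))       # 3 -> cache entries for tokens 1, 2, 3 are kept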
...
@@ -280,26 +319,46 @@ class TransformersInterface(BackendInterfaceBase):
                past_key_values=self.cache,
                return_dict=False,
                use_cache=True,
-               attention_mask=mask,
            )[0]
        else:
            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+       self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
        next_token = self.logits_to_token(logits[0, -1, :])
        yield self.append_new_tokens(next_token)

    @torch.no_grad
    def generate(self):
+       self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
+       logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
+       if (self.max_new_tokens <= 0):
+           logger.warning("max_new_tokens is less than 0")
+           yield self.streamer.end(), "length"
+           return
+       logger.info(f"max_new_tokens: {self.max_new_tokens}")
        self.profiler.set_counter("decode", 0)
-       for _ in range(1, self.args.max_new_tokens):
-           with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+       for i in range(1, self.max_new_tokens):
+           with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+               if flashinfer_enabled:
+                   MLAWrapperSingleton.plan_all(
+                       None, None, None, self.active_cache_position.to(torch.int32) + 1,
+                       num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
+                       head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
+                       sm_scale=self.model.model.layers[0].self_attn.softmax_scale,
+                       q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
                next_token = self.decode_one_tokens()
                self.profiler.inc("decode")
-               if next_token == self.tokenizer.eos_token_id:
+               if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                   yield self.streamer.end(), None
+                   yield "", "stop"
                    assert self.args.batch_size == 1
                    break
-               yield self.append_new_tokens(next_token)
-       yield self.streamer.end()
+               yield self.append_new_tokens(next_token), None
+       else:  # for's else, if output get max new tokens
+           yield self.streamer.end(), None
+           yield "", "length"

    def check_is_new(self, thread_id: str):
        if not self.use_static_cache:
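The decode loop now selects attention backends through torch.nn.attention.sdpa_kernel, which supersedes the deprecated torch.backends.cuda.sdp_kernel context manager (available in recent PyTorch; 2.3+ is an assumption here). Minimal standalone usage:

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(1, 8, 16, 64)
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print(out.shape)   # torch.Size([1, 8, 16, 64])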
...
@@ -314,7 +373,8 @@ class TransformersInterface(BackendInterfaceBase):
            self.last_request_id = thread_id
            return True

-   async def inference(self, local_messages, thread_id: str):
+   async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+       self.streamer.reset()
        self.profiler.create_and_start_timer("tokenize")
        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
...
@@ -324,8 +384,9 @@ class TransformersInterface(BackendInterfaceBase):
            #input_ids = torch.tensor([[6366]], device=input_ids.device)
        else:
            raise ValueError("local_messages should be List or str")
        if Config().user_force_think:
-           token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_ids.device)
+           token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
            input_ids = torch.cat(
                [input_ids, token_thinks], dim=1
            )
...
@@ -333,21 +394,24 @@ class TransformersInterface(BackendInterfaceBase):
        self.profiler.pause_timer("tokenize")
        self.profiler.create_and_start_timer("prefill")
        if Config().user_force_think:
-           t = "<think>\n"
-           print(t, end="", flush=True)
-           yield t
+           think = '<think>\n'
+           print(think, end="", flush=True)
+           yield think, None
-       for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+       for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
+           # output think token after prefill done
            if t is not None:
                print(t, end="", flush=True)
-               yield t
+               yield t, None
        self.profiler.pause_timer("prefill")
        self.profiler.create_and_start_timer("decode")
-       for t in self.generate():
+       for t, finish_reason in self.generate():
            if t is not None:
                print(t, end="", flush=True)
-               yield t
+               yield t, finish_reason
        print("")
        self.profiler.pause_timer("decode")
        self.report_last_time_performance()
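When user_force_think is enabled, the encoded "<think>\n" marker is appended after the chat-template prompt so decoding starts inside a think block. A tiny sketch with made-up token ids:

    import torch

    input_ids = torch.tensor([[10, 11, 12]])    # stand-in for the tokenized prompt
    token_thinks = torch.tensor([[5, 6]])       # stand-in for tokenizer.encode("<think>\n", add_special_tokens=False)
    input_ids = torch.cat([input_ids, token_thinks], dim=1)
    print(input_ids)                            # tensor([[10, 11, 12,  5,  6]])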
ktransformers/server/config/config.py (View file @ c009512a)
...
@@ -69,6 +69,7 @@ class Config(metaclass=Singleton):
        self.server: dict = cfg.get("server", {})
        self.server_ip = self.server.get("ip", "0.0.0.0")
        self.server_port = self.server.get("port", 9016)
+       self.api_key = self.server.get("api_key", "")
        # db configs
        self.db_configs: dict = cfg.get("db", {})
...
@@ -104,7 +105,8 @@ class Config(metaclass=Singleton):
        self.total_context = self.model.get("total_context", 2**18)
        self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
-       self.max_chunk_size = self.model.get("max_chunk_size", 2048)
+       self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
        self.max_new_tokens = self.model.get("max_new_tokens", 2000)
        self.json_mode = self.model.get("json_mode", False)
        self.healing = self.model.get("healing", False)
...
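Both new settings are plain dictionary lookups with defaults. A sketch of the same lookups against a hand-written config dict (the on-disk YAML layout is not shown in this hunk, so treat the structure as an assumption):

    cfg = {
        "server": {"ip": "0.0.0.0", "port": 9016, "api_key": "my-secret-key"},
        "model": {"chunk_prefill_size": 8192},
    }
    api_key = cfg.get("server", {}).get("api_key", "")                        # empty string = no key required
    chunk_prefill_size = cfg.get("model", {}).get("chunk_prefill_size", 8192)
    print(api_key, chunk_prefill_size)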
ktransformers/server/main.py (View file @ c009512a)
...
@@ -105,6 +105,7 @@ def custom_openapi(app):
def main():
    cfg = Config()
    arg_parser = ArgumentParser(cfg)
    # initialization message
...
ktransformers/server/requirements.txt (View file @ c009512a)
...
@@ -5,6 +5,7 @@ langchain >= 0.2.0
blessed >= 1.20.0
accelerate >= 0.31.0
sentencepiece >= 0.1.97
+openai
setuptools
build
ninja
...
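With openai now listed in the server requirements, a natural way to exercise the OpenAI-compatible endpoint is the official SDK. A hedged sketch; the base URL path, port (9016 is the default read in config.py above), API key, and model id are all deployment-dependent assumptions:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:9016/v1", api_key="my-secret-key")
    stream = client.chat.completions.create(
        model="DeepSeek-V3",                 # placeholder model id
        messages=[{"role": "user", "content": "hello"}],
        temperature=0.6,
        top_p=0.95,
        stream=True,
    )
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)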
ktransformers/server/schemas/assistants/streaming.py (View file @ c009512a)
...
@@ -73,7 +73,7 @@ class RunStepDelta(Object):
class Done():
    def to_stream_reply(self):
-       return f"event: done\ndata: [DONE]\n\n"
+       return f"data: [DONE]\n\n"

async def check_client_link(request: Request, async_events: AsyncIterable):
...
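Dropping the custom "event: done" line makes the terminator match the OpenAI streaming convention: a bare "data: [DONE]" record. A minimal client-side sketch of handling it:

    def iter_sse_payloads(lines):
        for line in lines:
            if not line.startswith("data: "):
                continue                      # skip comments and blank keep-alives
            payload = line[len("data: "):].strip()
            if payload == "[DONE]":
                break                         # end-of-stream marker, not JSON
            yield payload

    sample = ['data: {"token": "Hel"}', 'data: {"token": "lo"}', "data: [DONE]"]
    print(list(iter_sse_payloads(sample)))    # ['{"token": "Hel"}', '{"token": "lo"}']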