Unverified commit ca1dc1e7, authored by Atream, committed by GitHub

Merge branch 'main' into main

parents d3b45d57 505f4e2c
@@ -15,6 +15,18 @@
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
......
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
\ No newline at end of file
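Editor's note: the following is an illustrative sketch, not part of the commit. It shows how the name patterns in the rules above select modules; in particular, the negative lookahead keeps self_attn.kv_b_proj on its original path while every other nn.Linear under model.layers is replaced by KTransformersLinear, and lm_head needs its own rule because it is not under model.layers. Module names are hypothetical examples.

import re

pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")
names = [
    "model.layers.0.self_attn.q_a_proj",   # matched -> replaced (KLinearFP8 for generate)
    "model.layers.0.self_attn.kv_b_proj",  # excluded by the lookahead -> left as-is
    "model.layers.10.mlp.gate_proj",       # matched -> replaced
    "lm_head",                             # not under model.layers -> covered by the ^lm_head rule
]
for n in names:
    print(n, "->", bool(pattern.match(n)))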
@@ -182,6 +182,53 @@
generate_device: "cuda:3"
prefill_device: "cuda:3"
# === MLP Experts Replacement ===
# Replace with Marlin experts. Uncomment and adjust the layer numbers as needed.
# Each layer of Marlin experts takes about 6 GB of GPU memory.
# !!! Remember to disable CUDA Graph if you are using Marlin experts. !!!
# !!! KExpertsTorch is untested; we don't have enough VRAM. !!!
# GPU 0: layers 3–4
# - match:
# name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:0"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 1: layers 15–17
# - match:
# name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:1"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 2: layers 30–32
# - match:
# name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:2"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 3: layers 45–46
# - match:
# name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:3"
# generate_op: "KExpertsMarlin"
# recursive: False
# === MLP Experts Replacement ===
# GPU 0: layers 0–14
@@ -246,6 +293,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
absorb_for_prefill: False
# GPU 1: layers 15–29
- match:
@@ -255,6 +303,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
absorb_for_prefill: False
# GPU 2: layers 30–44
- match:
@@ -264,6 +313,7 @@
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
absorb_for_prefill: False
# GPU 3: layers 45–60
- match:
@@ -273,6 +323,7 @@
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
absorb_for_prefill: False
# === Overall Model Replacement with Transfer Map ===
@@ -316,9 +367,20 @@
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
# For final modules (model.norm), ensure they are on GPU 3 (as in your original config)
- match:
name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
replace:
class: "default"
kwargs:
......
@@ -713,9 +713,20 @@
generate_device: "cuda:7"
prefill_device: "cuda:7"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:7"
prefill_device: "cuda:7"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
# For final modules (model.norm), ensure they are on GPU 7 (as in your original config)
- match:
name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
replace:
class: "default"
kwargs:
......
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate # gate module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda:1"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
30: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
@@ -153,9 +153,20 @@
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
@@ -135,7 +135,18 @@
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
......
@@ -5,6 +5,18 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
@@ -48,6 +60,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
......
@@ -15,6 +15,16 @@
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.block_sparse_moe$"
class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock
......
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
# To use more VRAM, use Marlin experts and disable CUDA Graph (disabling CUDA Graph may reduce performance).
#- match:
# name: "^model\\.layers\\..*\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
# kwargs:
# prefill_device: "cuda"
# prefill_op: "KExpertsTorch"
# generate_device: "cuda"
# generate_op: "KExpertsMarlin"
# recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
\ No newline at end of file
@@ -77,9 +77,19 @@
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model.norm)"
replace:
class: "default"
kwargs:
......
@@ -15,6 +15,16 @@
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
......
@@ -12,8 +12,8 @@ from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import check_link_response
from ktransformers.server.backend.base import BackendInterfaceBase

router = APIRouter(prefix='/api')

# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
class OllamaGenerateCompletionRequest(BaseModel):
@@ -40,61 +40,121 @@ class OllamaGenerateCompletionRequest(BaseModel):
keep_alive: Optional[str] = Field(
"5m", description="Controls how long the model will stay loaded into memory following the request.")

class OllamaGenerationStreamResponse(BaseModel):
model: str
created_at: str
response: str
done: bool = Field(...)

class OllamaGenerationResponse(BaseModel):
pass

@router.post("/generate", tags=['ollama'])
async def generate(request: Request, input: OllamaGenerateCompletionRequest):
id = str(uuid4())
interface: BackendInterfaceBase = get_interface()
print(f'COMPLETION INPUT:----\n{input.prompt}\n----')
config = Config()
if input.stream:
async def inner():
async for token in interface.inference(input.prompt, id):
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
response=token,
done=False
)
yield d.model_dump_json() + '\n'
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
response='',
done=True
)
yield d.model_dump_json() + '\n'
return check_link_response(request, inner())
else:
raise NotImplementedError

# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
class OllamaChatCompletionMessage(BaseModel):
role: str
content: str
class OllamaChatCompletionRequest(BaseModel):
model: str = Field(..., description="The model name, which is required.")
messages: List[OllamaChatCompletionMessage] = Field(
..., description="A list of messages to generate a response for.")
stream: bool = Field(True, description="If true, the response will be streamed.")
class OllamaChatCompletionStreamResponse(BaseModel):
model: str
created_at: str
message: dict
done: bool = Field(...)
total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
eval_count: Optional[int] = Field(None, description="Number of tokens generated")
eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
class OllamaChatCompletionResponse(BaseModel):
pass

@router.post("/chat", tags=['ollama'])
async def chat(request: Request, input: OllamaChatCompletionRequest):
id = str(uuid4())
interface: BackendInterfaceBase = get_interface()
config = Config()
# Flatten the chat messages into a single prompt string
prompt = ""
for msg in input.messages:
prompt += f"{msg.role}: {msg.content}\n"
prompt += "assistant:"
if input.stream:
async def inner():
start_time = time()  # record the start time in seconds
eval_count = 0  # number of generated tokens
async for token in interface.inference(prompt, id):
eval_count += 1
d = OllamaChatCompletionStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
message={"role": "assistant", "content": token},
done=False
)
yield d.model_dump_json() + '\n'
# Compute performance statistics
end_time = time()
total_duration = int((end_time - start_time) * 1_000_000_000)  # seconds -> nanoseconds
prompt_eval_count = len(prompt.split())  # rough estimate of the number of prompt tokens
eval_duration = total_duration  # simplification: assume all time was spent generating
prompt_eval_duration = 0  # no separate prompt-evaluation timing
load_duration = 0  # load time unknown
d = OllamaChatCompletionStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
message={},
done=True,
total_duration=total_duration,
load_duration=load_duration,
prompt_eval_count=prompt_eval_count,
prompt_eval_duration=prompt_eval_duration,
eval_count=eval_count,
eval_duration=eval_duration
)
yield d.model_dump_json() + '\n'
return check_link_response(request, inner())
else:
raise NotImplementedError("Non-streaming chat is not implemented.")
# https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
class OllamaModel(BaseModel):
@@ -103,9 +163,8 @@ class OllamaModel(BaseModel):
size: int
# TODO: fill the rest correctly

# mock ollama
@router.get("/tags", tags=['ollama'])
async def tags():
config = Config()
# TODO: fill this correctly, although it does not affect Tabby
@@ -138,25 +197,21 @@ class OllamaShowResponse(BaseModel):
class Config:
protected_namespaces = ()

@router.post("/show", tags=['ollama'])
async def show(request: Request, input: OllamaShowRequest):
config = Config()
# TODO: Add more info in config to return, although it does not affect Tabby
return OllamaShowResponse(
modelfile="# Modelfile generated by ...",
parameters=" ",
template=" ",
details=OllamaShowDetial(
parent_model=" ",
format="gguf",
family=" ",
families=[" "],
parameter_size=" ",
quantization_level=" "
),
model_info=OllamaModelInfo()
)
\ No newline at end of file
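Editor's note: an illustrative sketch, not part of the commit, showing how the new streaming /api/chat endpoint above can be exercised with a plain HTTP client. The host, port (9016 is the default server_port in the config later in this diff), and model name are assumptions; adjust them to your deployment.

import json
import requests

resp = requests.post(
    "http://localhost:9016/api/chat",
    json={
        "model": "DeepSeek-V3",  # any name; the server replies with its configured model_name
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)
    if chunk["done"]:
        print("\n[eval_count]", chunk.get("eval_count"))
    else:
        print(chunk["message"]["content"], end="", flush=True)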
@@ -5,18 +5,15 @@ from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate, ChatCompletionChunk, ChatCompletionObject, Usage
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config
router = APIRouter()

@router.get('/models', tags=['openai'])
async def list_models():
return [{"id": Config().model_name, "name": Config().model_name}]

@router.post('/chat/completions', tags=['openai'])
@@ -28,15 +25,19 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
input_message = [json.loads(m.model_dump_json()) for m in create.messages]
if Config().api_key != '':
assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
if create.stream:
async def inner():
chunk = ChatCompletionChunk(id=id, object='chat.completion.chunk', created=int(time()))
async for token in interface.inference(input_message, id, create.temperature, create.top_p):
chunk.set_token(token)
yield chunk
return chat_stream_response(request, inner())
else:
comp = ChatCompletionObject(id=id, object='chat.completion', created=int(time()))
comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
async for token in interface.inference(input_message, id, create.temperature, create.top_p):
comp.append_token(token)
return comp
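Editor's note: an illustrative sketch, not part of the commit. When --api_key is set, the chat_completion endpoint above compares the last whitespace-separated token of the Authorization header against the configured key, so a standard "Bearer <key>" header works. The /v1 prefix, port, and model name are assumptions about how the router is mounted; adjust to your deployment.

import requests

resp = requests.post(
    "http://localhost:9016/v1/chat/completions",
    headers={"Authorization": "Bearer my-secret-key"},  # must end with the configured api_key
    json={
        "model": "ktransformers-model",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
        "temperature": 0.6,
        "top_p": 0.95,
    },
)
print(resp.json())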
@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
if create.stream:
async def inner():
async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
d = {'choices':[{'delta':{'content':token}}]}
yield f"data:{json.dumps(d)}\n\n"
d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
return stream_response(request, inner())
else:
comp = CompletionObject(id=id, object='text_completion', created=int(time()))
async for token in interface.inference(create.prompt, id, create.temperature, create.top_p):
comp.append_token(token)
return comp
@@ -10,6 +10,7 @@ class ArgumentParser:
parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
parser.add_argument("--host", type=str, default=self.cfg.server_ip)
parser.add_argument("--port", type=int, default=self.cfg.server_port)
parser.add_argument("--api_key", type=str, default=self.cfg.api_key)
parser.add_argument("--ssl_keyfile", type=str) parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str) parser.add_argument("--ssl_certfile", type=str)
parser.add_argument("--web", type=bool, default=self.cfg.mount_web) parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
...@@ -23,13 +24,13 @@ class ArgumentParser: ...@@ -23,13 +24,13 @@ class ArgumentParser:
parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False) parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer) parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
parser.add_argument("--type", type=str, default=self.cfg.backend_type) parser.add_argument("--type", type=str, default=self.cfg.backend_type)
parser.add_argument("--chunk_prefill_size", type=int, default=8192)
# model configs
# parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
parser.add_argument("--paged", type=bool, default=self.cfg.paged)
parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
parser.add_argument("--healing", type=bool, default=self.cfg.healing)
@@ -90,7 +91,8 @@ class ArgumentParser:
# user config
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)
# web config
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
......
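Editor's note: an illustrative sketch, not part of the commit. argparse.BooleanOptionalAction (used above for --force_think and --use_cuda_graph) automatically generates a paired --no-<flag> option, so CUDA Graph can now be turned off from the command line with --no-use_cuda_graph. The default value here is an assumption for demonstration.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=True)

print(parser.parse_args([]))                       # Namespace(use_cuda_graph=True)
print(parser.parse_args(["--use_cuda_graph"]))     # Namespace(use_cuda_graph=True)
print(parser.parse_args(["--no-use_cuda_graph"]))  # Namespace(use_cuda_graph=False)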
@@ -23,7 +23,7 @@ class ConfigArgs(BaseModel):
max_batch_size: int = Field(
None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
)
chunk_prefill_size: int = Field(
None,
description=(
"Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
......
@@ -69,6 +69,7 @@ class Config(metaclass=Singleton):
self.server: dict = cfg.get("server", {})
self.server_ip = self.server.get("ip", "0.0.0.0")
self.server_port = self.server.get("port", 9016)
self.api_key = self.server.get("api_key", "")
# db configs
self.db_configs: dict = cfg.get("db", {})
@@ -104,7 +105,8 @@ class Config(metaclass=Singleton):
self.total_context = self.model.get("total_context", 2**18)
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
self.max_new_tokens = self.model.get("max_new_tokens", 2000)
self.json_mode = self.model.get("json_mode", False)
self.healing = self.model.get("healing", False)
......