"vscode:/vscode.git/clone" did not exist on "8456222852c97ecd5b0f76a39af2be401106056f"
Commit 5aee6c04 authored by Azure's avatar Azure
Browse files

Merge branch 'main' into develop-0.2.3

parents 216a63b8 48b98007
-FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server
+FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel as compile_server
WORKDIR /workspace
ENV CUDA_HOME /usr/local/cuda
RUN <<EOF
@@ -10,6 +10,7 @@ apt update -y && apt install -y --no-install-recommends \
    g++ \
    cmake &&
rm -rf /var/lib/apt/lists/* &&
+pip install --upgrade pip &&
pip install ninja pyproject numpy cpufeature &&
pip install flash-attn &&
cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/
...
name: 🐞 Bug report
description: Create a report to help us reproduce and fix the bug
title: "[Bug] "
labels: ['Bug']
body:
- type: checkboxes
attributes:
label: Checklist
options:
- label: 1. I have searched related issues but cannot get the expected help.
- label: 2. The bug has not been fixed in the latest version.
- label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- label: 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
- label: 5. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-Chinese/English content without translation may be closed.
- type: textarea
attributes:
label: Describe the bug
description: A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
attributes:
label: Reproduction
description: |
What command or script did you run? Which **model** are you using?
placeholder: |
A placeholder for the command.
validations:
required: true
- type: textarea
attributes:
label: Environment
description: |
Please provide the necessary environment information here (e.g. OS/GPU/CPU). Otherwise the issue will be closed.
placeholder: Environment here.
validations:
required: true
\ No newline at end of file
name: 🐞 BUG报告
description: 创建报告以帮助我们复现并修复BUG
title: "[Bug] "
labels: ['Bug']
body:
- type: checkboxes
attributes:
label: 检查清单
options:
- label: 1. 我已经搜索过相关问题,但未能获得预期的帮助
- label: 2. 该问题在最新版本中尚未修复
- label: 3. 请注意,如果您提交的BUG相关 issue 缺少对应环境信息和最小可复现示例,我们将难以复现和定位问题,降低获得反馈的可能性
- label: 4. 如果您提出的不是bug而是问题,请在讨论区发起讨论 https://github.com/kvcache-ai/ktransformers/discussions。否则该 issue 将被关闭
- label: 5. 为方便社区交流,我将使用中文/英文或附上中文/英文翻译(如使用其他语言)。未附带翻译的非中文/英语内容可能会被关闭
- type: textarea
attributes:
label: 问题描述
description: 清晰简洁地描述BUG是什么
validations:
required: true
- type: textarea
attributes:
label: 复现步骤
description: |
你运行了什么命令或脚本?使用的是哪个**模型**?
placeholder: |
在此处填写命令
validations:
required: true
- type: textarea
attributes:
label: 环境信息
description: |
请提供必要的环境信息(如操作系统/GPU/CPU),否则该 issue 将被关闭
placeholder: 在此处填写环境信息
validations:
required: true
\ No newline at end of file
name: 🚀 Feature request
description: Suggest an idea for this project
title: "[Feature] "
body:
- type: checkboxes
attributes:
label: Checklist
options:
- label: 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
- label: 2. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-English/Chinese content without translation may be closed.
- type: textarea
attributes:
label: Motivation
description: |
A clear and concise description of the motivation of the feature.
validations:
required: true
- type: textarea
attributes:
label: Related resources
description: |
If there are official code releases or third-party implementations, please also provide that information here; it would be very helpful.
\ No newline at end of file
name: 🚀 新功能请求
description: 为项目提出新功能建议
title: "[Feature] "
body:
- type: checkboxes
attributes:
label: 检查清单
options:
- label: 1. 如果您提出的不是新功能而是问题,请在讨论区发起讨论 https://github.com/kvcache-ai/ktransformers/discussions。否则该 issue 将被关闭
- label: 2. 为方便社区交流,我将使用中文/英文或附上英文/中文翻译(如使用其他语言)。未附带翻译的非英文/中文内容可能会被关闭
- type: textarea
attributes:
label: 需求背景
description: |
清晰简洁地描述该功能的背景需求
validations:
required: true
- type: textarea
attributes:
label: 相关资源
description: |
如果有官方代码实现或第三方实现,请在此提供相关信息,这将非常有帮助
\ No newline at end of file
@@ -3,6 +3,14 @@ name: DockerHub CI
on:
  release:
    types: [published]
+  workflow_dispatch:
+    inputs:
+      choose:
+        description: 'Will you push the image to DockerHub? 0 for No, 1 for Yes'
+        required: true
+        default: '0'
+        type: string
# push:
#   branches:
#     - main
...
@@ -25,3 +25,4 @@ book
ktransformers/tests/chat_txt.txt
mmlu_result*
ktransformers/ktransformers_ext/cuda_musa/
+test_prompt.txt
@@ -10,7 +10,7 @@ EOF
-FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server
+FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel as compile_server
ARG CPU_INSTRUCT=NATIVE
WORKDIR /workspace
ENV CUDA_HOME /usr/local/cuda
@@ -27,6 +27,7 @@ rm -rf /var/lib/apt/lists/* &&
cd ktransformers &&
git submodule init &&
git submodule update &&
+pip install --upgrade pip &&
pip install ninja pyproject numpy cpufeature &&
pip install flash-attn &&
CPU_INSTRUCT=${CPU_INSTRUCT} KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose &&
...
@@ -160,6 +160,7 @@ is speed up which is inspiring. So our showcase makes use of this finding*
### V0.2.2 longer context & FP8 kernel
#### longer context
To use this feature, [install flashinfer](https://github.com/flashinfer-ai/flashinfer) first.
Note: the latest MLA kernel in FlashInfer still has a few minor issues; they are being fixed continuously on the main branch. If you use FlashInfer, please install it from source (main branch).
If you want to use a long context (longer than 20K tokens) for prefill, enable matrix-absorption MLA during the prefill phase, which significantly reduces the size of the KV cache. Modify the yaml file like this:
@@ -173,6 +174,8 @@ If you want to use long context(longer than 20K) for prefill, enable the matrix
      prefill_device: "cuda"
      absorb_for_prefill: True # change this to True to enable long context (prefill may be slower).
```
+If the VRAM is still insufficient, try reducing the `chunk_prefill_size` parameter (default 8192) to further shrink the intermediate buffers produced during chunked prefill.
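As a rough illustration of how this knob is used (a minimal sketch only: the import path is assumed from this repository's layout, and the `model_path`/`gguf_path` values are placeholders; only `chunk_prefill_size` and its 8192 default come from this change set), lowering `chunk_prefill_size` trades some prefill throughput for a smaller peak of intermediate activations:

```python
# Minimal sketch, assuming ktransformers.local_chat exposes local_chat() as in the
# local_chat hunk further down; chunk_prefill_size (default 8192) is from this diff.
from ktransformers.local_chat import local_chat  # module path assumed

local_chat(
    model_path="deepseek-ai/DeepSeek-V3",      # placeholder HF config path
    gguf_path="/path/to/DeepSeek-V3-GGUF",     # placeholder GGUF weights directory
    chunk_prefill_size=4096,                   # halve the default to cut peak VRAM during chunked prefill
)
```

Because `local_chat` is dispatched through `fire.Fire`, the same setting should also be reachable from the command line as `--chunk_prefill_size 4096`.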
#### FP8 kernel
The DeepSeek-AI team provides FP8 safetensors for DeepSeek-R1/V3 models. We achieve performance optimization through the following works:
...
@@ -209,6 +209,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
if (WIN32)
    include_directories("$ENV{CUDA_PATH}/include")
+   add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX)
    if (KTRANSFORMERS_USE_CUDA)
        find_package(CUDA REQUIRED)
...
@@ -10,6 +10,8 @@
#include "kvcache.h"
+#include <chrono>
void KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,
                                float *attn_lse, int batch_size,
                                Backend *backend) {
...
@@ -9,6 +9,9 @@
 **/
#include "kvcache.h"
+#include <chrono>
void KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) {
    // Timer start
    auto start = std::chrono::high_resolution_clock::now();
...
@@ -10,6 +10,8 @@
#include "kvcache.h"
+#include <chrono>
void KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,
                                   int block_idx, Backend *backend) {
    // Timer start
...
@@ -10,6 +10,8 @@
#include "kvcache.h"
+#include <chrono>
std::string ggml_type_to_string(ggml_type type) {
    switch (type) {
    case GGML_TYPE_F32:
...
@@ -62,6 +62,7 @@ def local_chat(
    prompt_file : str | None = None,
    mode: str = "normal",
    force_think: bool = False,
+   chunk_prefill_size: int = 8192
):
    torch.set_grad_enabled(False)
@@ -110,15 +111,15 @@ def local_chat(
    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
    try:
        model.generation_config = GenerationConfig.from_pretrained(model_path)
-   except:
-       gen_config = GenerationConfig(
-           max_length=128,
-           temperature=0.7,
-           top_p=0.9,
+   except Exception as e:
+       print(f"generation config can't auto create, make default. Message: {e}")
+       gen_config = GenerationConfig(
+           temperature=0.6,
+           top_p=0.95,
            do_sample=True
        )
        model.generation_config = gen_config
    # model.generation_config = GenerationConfig.from_pretrained(model_path)
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
@@ -168,16 +169,16 @@ def local_chat(
    assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
        "please change max_seq_len in ~/.ktransformers/config.yaml"
-   if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+   if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
        generated = prefill_and_generate(
-           model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
+           model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
            use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
        )
    else:
        generated = prefill_and_generate(
-           model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
+           model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
        )

if __name__ == "__main__":
    fire.Fire(local_chat)
\ No newline at end of file
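The condition fix above deserves a note: in the old code, `config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM"` is always truthy, because the right-hand operand of `or` is a bare non-empty string rather than a comparison, so the FlashInfer MLA path was taken for every architecture. A minimal demonstration of the pitfall (plain Python semantics, nothing repo-specific):

```python
arch = "LlamaForCausalLM"

# Buggy form: `or` returns its right operand ("DeepseekV3ForCausalLM", truthy)
# whenever the left comparison is False, so this is truthy for any arch.
buggy = arch == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM"

# Fixed form, as in the hunk above: compare against each name explicitly.
fixed = arch == "DeepseekV2ForCausalLM" or arch == "DeepseekV3ForCausalLM"

# Equivalent and arguably clearer alternative.
fixed_alt = arch in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM")

print(bool(buggy), fixed, fixed_alt)  # True False False
```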
@@ -122,7 +122,7 @@ class MLAWrapper():
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf
        self.wrapper.plan(
            qo_indptr,
            kv_indptr,
@@ -139,6 +139,11 @@ class MLAWrapper():
        )

    def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
+       #print("run")
+       #print(self.wrapper._qo_indptr_buf)
+       #print(self.wrapper._kv_indptr_buf)
+       #print(self.wrapper._kv_indices_buf)
+       #print(self.wrapper._kv_len_arr_buf)
        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)

class MLAWrapperSingleton():
@@ -200,12 +205,14 @@ class MLAWrapperSingleton():

if __name__ == "__main__":
+   torch.set_default_dtype(torch.bfloat16)
    max_batch_size = 1
-   max_pages = 1
+   max_pages = 128
    page_size = 64
    num_heads = 128
-   q_len = 10
+   kv_len = 4023
+   q_len = 1
    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
@@ -218,7 +225,7 @@ if __name__ == "__main__":
        max_pages,
    )
-   kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+   kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
    wrapper.plan(
        qo_indptr,
@@ -236,6 +243,29 @@ if __name__ == "__main__":
    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
    print(attn_output.shape)
+   graph = torch.cuda.CUDAGraph()
+   with torch.cuda.graph(graph):
+       attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+
+   kv_len = 6789
+   kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
+   qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
+   wrapper.plan(
+       qo_indptr,
+       None,
+       None,
+       kv_len_arr,
+       128,
+       512,
+       64,
+       page_size,
+       192 ** (-0.5),
+       torch.bfloat16,
+       torch.bfloat16,
+   )
+   graph.replay()
    k = (
        torch.cat([ckv, k_pe], dim=-1)
@@ -244,15 +274,15 @@ if __name__ == "__main__":
    )
    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
-   print(k[:10].shape)
-   print(v[:10].shape)
+   print(k[:kv_len].shape)
+   print(v[:kv_len].shape)
    attn_ref, lse_ref = attention_ref(
        max_batch_size,
        torch.cat([q_nope, q_pe], dim=-1),
-       k[:10],
-       v[:10],
-       False,
+       k[:kv_len],
+       v[:kv_len],
+       True,
        192 ** (-0.5)
    )
    print(attn_ref.shape)
...
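For readers without FlashInfer installed who want to sanity-check what the test above compares against: conceptually, `attention_ref` is plain multi-head attention over `k = concat(ckv, k_pe)` and `v = ckv` broadcast across heads, with softmax scale `192 ** (-0.5)` and (after this change) causal masking. Below is a rough, self-contained sketch of such a reference, written under those assumptions rather than copied from the real helper, whose masking and dtype handling may differ:

```python
import torch

def dense_attention_ref(q, k, v, sm_scale, causal=True):
    # Reference attention over an unpaged KV cache.
    # q: [q_len, H, D_qk], k: [kv_len, H, D_qk], v: [kv_len, H, D_v]
    q_len, kv_len = q.shape[0], k.shape[0]
    # score every query against every cached key, per head
    logits = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale
    if causal:
        # align the mask to the end of the cache: the last query token
        # is allowed to attend to every cached position
        offset = kv_len - q_len
        qi = torch.arange(q_len, device=q.device)[:, None]
        ki = torch.arange(kv_len, device=q.device)[None, :]
        logits.masked_fill_(ki > qi + offset, float("-inf"))
    attn = torch.softmax(logits, dim=-1)
    return torch.einsum("hqk,khd->qhd", attn, v.float()).to(v.dtype)
```

With the shapes used in the test (q_len = 1, kv_len = 6789 after the re-plan, 128 heads, 576-dim keys, 512-dim values), this yields a `[1, 128, 512]` tensor comparable to `attn_output`.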
@@ -31,13 +31,13 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
    if create.stream:
        async def inner():
            chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-           async for token in interface.inference(input_message,id):
+           async for token in interface.inference(input_message,id,create.temperature,create.top_p):
                chunk.set_token(token)
                yield chunk
        return chat_stream_response(request,inner())
    else:
        comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
        comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-       async for token in interface.inference(input_message,id):
+       async for token in interface.inference(input_message,id,create.temperature,create.top_p):
            comp.append_token(token)
        return comp
@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
    if create.stream:
        async def inner():
-           async for token in interface.inference(create.prompt,id):
+           async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
                d = {'choices':[{'delta':{'content':token}}]}
                yield f"data:{json.dumps(d)}\n\n"
            d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
        return stream_response(request,inner())
    else:
        comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-       async for token in interface.inference(create.prompt,id):
+       async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
            comp.append_token(token)
        return comp
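Both routes above now forward the request's `temperature` and `top_p` into `interface.inference`, so sampling can be tuned per request through the OpenAI-compatible API. Here is a small illustrative client; the host, port, and exact route are assumptions about how the server is typically exposed, not something these hunks specify:

```python
import json
import urllib.request

# Hypothetical local endpoint; adjust host/port/route to your own deployment.
url = "http://localhost:10002/v1/chat/completions"

payload = {
    "model": "DeepSeek-V3",                               # placeholder model name
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.6,                                   # now forwarded to interface.inference
    "top_p": 0.95,                                        # likewise forwarded per request
    "stream": False,
}

req = urllib.request.Request(
    url,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))
```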
@@ -24,13 +24,13 @@ class ArgumentParser:
        parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
        parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
        parser.add_argument("--type", type=str, default=self.cfg.backend_type)
+       parser.add_argument("--chunk_prefill_size", type=int, default=8192)
        # model configs
        # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
        parser.add_argument("--paged", type=bool, default=self.cfg.paged)
        parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
        parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
-       parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
        parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
        parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
        parser.add_argument("--healing", type=bool, default=self.cfg.healing)
...
@@ -23,7 +23,7 @@ class ConfigArgs(BaseModel):
    max_batch_size: int = Field(
        None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
    )
-   max_chunk_size: int = Field(
+   chunk_prefill_size: int = Field(
        None,
        description=(
            "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
...