Unverified commit e095b162, authored by Ying Sheng, committed by GitHub

Add max_prefill_num_token to server arguments (#133)

parent 67be11c7
@@ -82,7 +82,8 @@ class ModelRpcServer(rpyc.Service):
         self.max_total_num_token = self.model_runner.max_total_num_token
         self.max_num_running_seq = self.max_total_num_token // 2
         self.max_prefill_num_token = max(
-            self.model_config.context_len, self.max_total_num_token // 6
+            self.model_config.context_len,
+            self.max_total_num_token // 6 if server_args.max_prefill_num_token is None else server_args.max_prefill_num_token,
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
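With this change, the prefill cap falls back to one sixth of the total KV-cache token budget when --max-prefill-num-token is unset, and in either case it is raised to at least the model's context length. A minimal sketch of that selection logic, using hypothetical numbers for illustration:

from typing import Optional

# Hypothetical values, for illustration only.
CONTEXT_LEN = 4096             # model's maximum context length
MAX_TOTAL_NUM_TOKEN = 120_000  # tokens the KV-cache pool can hold

def effective_prefill_cap(requested: Optional[int]) -> int:
    # Mirror the selection above: use the requested value if given,
    # otherwise one sixth of the token pool; never go below CONTEXT_LEN.
    fallback = MAX_TOTAL_NUM_TOKEN // 6
    return max(CONTEXT_LEN, fallback if requested is None else requested)

print(effective_prefill_cap(None))  # 20000: default pool // 6 wins over 4096
print(effective_prefill_cap(8192))  # 8192: explicit request is respected
print(effective_prefill_cap(1024))  # 4096: clamped up to the context length
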
@@ -430,7 +430,8 @@ class Runtime:
         load_format: str = "auto",
         tokenizer_mode: str = "auto",
         trust_remote_code: bool = True,
-        mem_fraction_static: float = 0.9,
+        mem_fraction_static: float = ServerArgs.mem_fraction_static,
+        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
         tp_size: int = 1,
         model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
@@ -451,6 +452,7 @@ class Runtime:
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            mem_fraction_static=mem_fraction_static,
+           max_prefill_num_token=max_prefill_num_token,
            tp_size=tp_size,
            model_mode=model_mode,
            schedule_heuristic=schedule_heuristic,
@@ -15,6 +15,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
     mem_fraction_static: Optional[float] = None
+    max_prefill_num_token: Optional[int] = None
     tp_size: int = 1
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
@@ -109,6 +110,12 @@ class ServerArgs:
             default=ServerArgs.mem_fraction_static,
             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
+        parser.add_argument(
+            "--max-prefill-num-token",
+            type=int,
+            default=ServerArgs.max_prefill_num_token,
+            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length."
+        )
         parser.add_argument(
             "--tp-size",
             type=int,
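After this patch, the cap can be set either through the new --max-prefill-num-token CLI flag or when constructing Runtime directly. A brief usage sketch; the import path, launcher invocation, and model name below are illustrative assumptions, not part of this diff:

# Illustrative only: import path and model name are assumptions.
from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-hf",
    mem_fraction_static=0.9,
    max_prefill_num_token=8192,  # cap each prefill batch at 8192 tokens
)

# Roughly equivalent on the command line (launcher module assumed):
#   python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf \
#       --max-prefill-num-token 8192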