"docs/source/vscode:/vscode.git/clone" did not exist on "ae3316c846aa9f126e8a7d31144ce0bc9c5ff677"
Unverified commit 9de9a468, authored by psych0v0yager, committed by GitHub

Added the ability to Modify the Context Length (#210)

parent ce3b2610
@@ -57,7 +57,7 @@ class ModelRpcServer(rpyc.Service):
         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path, server_args.trust_remote_code
+            server_args.model_path, server_args.trust_remote_code, context_length=server_args.context_length
         )
         self.model_runner = ModelRunner(
             self.model_config,
...
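
Taken together, the hunks in this commit thread an optional context-length override end to end: a new --context-length CLI flag populates ServerArgs.context_length, Runtime forwards it, and ModelConfig uses it in place of the value from the model's config.json. The hunk above is the call site in the model server. Below is a minimal sketch of that wiring using toy stand-in classes; the names and values are illustrative, not the project's real ones.

    from typing import Optional

    # Toy stand-ins for illustration only; the real classes appear in the
    # diffs on this page.
    class ModelConfig:
        def __init__(self, path: str, trust_remote_code: bool = True,
                     context_length: Optional[int] = None) -> None:
            self.path = path
            self.context_len = context_length  # resolution logic: next hunk

    class ServerArgs:
        model_path = "my-model"        # hypothetical model path
        trust_remote_code = True
        context_length = 2048          # e.g. set via the new --context-length flag

    server_args = ServerArgs()
    config = ModelConfig(
        server_args.model_path,
        server_args.trust_remote_code,
        context_length=server_args.context_length,  # the new keyword argument
    )
    assert config.context_len == 2048
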
@@ -11,14 +11,19 @@ class ModelConfig:
         path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
+        context_length: Optional[int] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
+        if context_length is not None:
+            self.context_len = context_length
+        else:
+            self.context_len = get_context_length(self.hf_config)
+
         # Unify the config keys for hf_config
-        self.context_len = get_context_length(self.hf_config)
         self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
...
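
The behavioral core of the commit is the precedence rule added above: an explicit context_length wins; otherwise get_context_length(hf_config) supplies the model's native limit from config.json. Note that the old unconditional assignment of self.context_len is removed, so the override is not clobbered later in __init__. A runnable sketch of just that rule:

    from typing import Optional

    def resolve_context_len(hf_context_len: int, override: Optional[int] = None) -> int:
        # Mirrors the new ModelConfig logic: an explicit override takes
        # precedence; otherwise fall back to the HF config's value.
        return override if override is not None else hf_context_len

    assert resolve_context_len(4096) == 4096                 # no override: config.json value
    assert resolve_context_len(4096, override=2048) == 2048  # user-supplied cap wins
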
@@ -546,6 +546,7 @@ class Runtime:
         trust_remote_code: bool = True,
         mem_fraction_static: float = ServerArgs.mem_fraction_static,
         max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
+        context_length: int = ServerArgs.context_length,
         tp_size: int = 1,
         model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
@@ -567,6 +568,7 @@ class Runtime:
             trust_remote_code=trust_remote_code,
             mem_fraction_static=mem_fraction_static,
             max_prefill_num_token=max_prefill_num_token,
+            context_length=context_length,
             tp_size=tp_size,
             model_mode=model_mode,
             schedule_heuristic=schedule_heuristic,
...
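
Runtime gains the same keyword and passes it through unchanged into ServerArgs, keeping the programmatic API in sync with the CLI. (One nit visible in the hunk: the parameter is annotated int, but its default, ServerArgs.context_length, is None, so Optional[int] would be more precise.) A minimal sketch of the pass-through, again with stand-ins rather than the real classes:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToyServerArgs:                      # stand-in for the real ServerArgs
        model_path: str
        context_length: Optional[int] = None  # new field added in this commit

    class ToyRuntime:                         # stand-in for the real Runtime
        def __init__(self, model_path: str,
                     context_length: Optional[int] = None) -> None:
            # Runtime.__init__ forwards the kwarg unchanged into ServerArgs.
            self.server_args = ToyServerArgs(model_path=model_path,
                                             context_length=context_length)

    rt = ToyRuntime("my-model", context_length=2048)
    assert rt.server_args.context_length == 2048
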
@@ -16,6 +16,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     mem_fraction_static: Optional[float] = None
     max_prefill_num_token: Optional[int] = None
+    context_length: Optional[int] = None
     tp_size: int = 1
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
@@ -117,6 +118,12 @@ class ServerArgs:
             default=ServerArgs.max_prefill_num_token,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
+        )
         parser.add_argument(
             "--tp-size",
             type=int,
...
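
The flag itself is plain argparse. A self-contained sketch mirroring the new option (the surrounding parser setup is simplified away):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--context-length",
        type=int,
        default=None,  # ServerArgs.context_length defaults to None
        help="The model's maximum context length. Use this to reduce the "
             "context length to save memory.",
    )

    args = parser.parse_args(["--context-length", "2048"])
    assert args.context_length == 2048                    # dashes become an underscore
    assert parser.parse_args([]).context_length is None   # fall back to config.json

With the flag in place, a server started with --context-length 2048 overrides ModelConfig.context_len to 2048; omitting the flag keeps the old behavior of reading the limit from the model's config.json.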