"tutorials/vscode:/vscode.git/clone" did not exist on "47993776dff8b38d456f30e75b2afad1bd6d5df9"
Unverified Commit 08360553 authored by Kai-Hsun Chen, committed by GitHub

[Chore] Rename model_overide_args to model_override_args (#1284)


Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 00b19f19
@@ -197,19 +197,19 @@ if __name__ == "__main__":
         print("Invalid model path. Please specify a valid model path.")
         exit()

-    model_overide_args = {}
-    model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
-    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
-    model_overide_args["num_frames"] = args.num_frames
-    model_overide_args["model_type"] = "llava"
+    model_override_args = {}
+    model_override_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
+    model_override_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_override_args["num_frames"] = args.num_frames
+    model_override_args["model_type"] = "llava"

     if "34b" in args.model_path.lower():
-        model_overide_args["image_token_index"] = 64002
+        model_override_args["image_token_index"] = 64002

     if args.num_frames == 32:
-        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
-        model_overide_args["max_sequence_length"] = 4096 * 2
-        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+        model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_override_args["max_sequence_length"] = 4096 * 2
+        model_override_args["tokenizer_model_max_length"] = 4096 * 2
     elif args.num_frames < 32:
         pass
     else:
@@ -223,7 +223,7 @@ if __name__ == "__main__":
         tokenizer_path=tokenizer_path,
         port=cur_port,
         additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
-        model_overide_args=model_overide_args,
+        model_override_args=model_override_args,
         tp_size=1,
     )
     sgl.set_default_backend(runtime)
...
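For reference, the dict assembled above feeds the renamed keyword of sgl.Runtime. A minimal sketch of the call site after this commit; the model path is an illustrative placeholder, not taken from this diff:

import sglang as sgl

# Override values mirror the example scripts in this diff; the model path
# below is a hypothetical placeholder, not part of this commit.
model_override_args = {
    "architectures": ["LlavaVidForCausalLM"],
    "model_type": "llava",
    "mm_spatial_pool_stride": 2,
    "num_frames": 16,
}

runtime = sgl.Runtime(
    model_path="lmms-lab/LLaVA-NeXT-Video-7B",  # hypothetical example path
    model_override_args=model_override_args,
    tp_size=1,
)
sgl.set_default_backend(runtime)

Since the commit adds no backward-compatibility alias, any caller still passing model_overide_args will now fail with a TypeError.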
@@ -10,17 +10,17 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)

-    model_overide_args = {}
-    model_overide_args["mm_spatial_pool_stride"] = 2
-    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
-    model_overide_args["num_frames"] = 16
-    model_overide_args["model_type"] = "llavavid"
+    model_override_args = {}
+    model_override_args["mm_spatial_pool_stride"] = 2
+    model_override_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_override_args["num_frames"] = 16
+    model_override_args["model_type"] = "llavavid"

-    if model_overide_args["num_frames"] == 32:
-        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
-        model_overide_args["max_sequence_length"] = 4096 * 2
-        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
-        model_overide_args["model_max_length"] = 4096 * 2
+    if model_override_args["num_frames"] == 32:
+        model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_override_args["max_sequence_length"] = 4096 * 2
+        model_override_args["tokenizer_model_max_length"] = 4096 * 2
+        model_override_args["model_max_length"] = 4096 * 2

     if "34b" in args.model_path.lower():
-        model_overide_args["image_token_index"] = 64002
-    launch_server(server_args, model_overide_args, None)
+        model_override_args["image_token_index"] = 64002
+    launch_server(server_args, model_override_args, None)
@@ -62,7 +62,7 @@ def get_config(
     model: str,
     trust_remote_code: bool,
     revision: Optional[str] = None,
-    model_overide_args: Optional[dict] = None,
+    model_override_args: Optional[dict] = None,
 ):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
@@ -70,8 +70,8 @@ def get_config(
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
-    if model_overide_args:
-        config.update(model_overide_args)
+    if model_override_args:
+        config.update(model_override_args)
     return config
...
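The override step above relies on transformers' PretrainedConfig.update, which copies each key of the dict onto the config object as an attribute. A small sketch of the effect, using a generic Hugging Face model chosen only for illustration:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("gpt2")  # illustrative model, not from this diff
model_override_args = {"num_frames": 16}

# PretrainedConfig.update sets each key/value pair as an attribute,
# so downstream code sees the overridden fields on the config.
config.update(model_override_args)
assert config.num_frames == 16

This is also why the overrides can introduce fields such as num_frames or mm_spatial_pool_stride that the base config class never declared.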
@@ -71,12 +71,12 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args,
+        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.model_overide_args = model_overide_args
+        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
@@ -114,7 +114,7 @@ class ControllerMulti:
                 self.server_args,
                 self.port_args,
                 pipe_controller_writer,
-                self.model_overide_args,
+                self.model_override_args,
                 True,
                 gpu_ids,
                 dp_worker_id,
@@ -189,14 +189,14 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args: dict,
+    model_override_args: dict,
 ):
     """Start a controller process."""
     configure_logger(server_args)

    try:
-        controller = ControllerMulti(server_args, port_args, model_overide_args)
+        controller = ControllerMulti(server_args, port_args, model_override_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
...
@@ -40,7 +40,7 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args: dict,
+        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
@@ -76,7 +76,7 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-                model_overide_args,
+                model_override_args,
             )

         # Launch tp rank 0
@@ -85,7 +85,7 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-            model_overide_args,
+            model_override_args,
         )

         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
@@ -126,7 +126,7 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer: multiprocessing.connection.Connection,
-    model_overide_args: dict,
+    model_override_args: dict,
     is_data_parallel_worker: bool = False,
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
@@ -149,7 +149,7 @@ def start_controller_process(
     controller = ControllerSingle(
         server_args,
         port_args,
-        model_overide_args,
+        model_override_args,
         gpu_ids,
         is_data_parallel_worker,
         dp_worker_id,
...
@@ -77,7 +77,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args: dict = None,
+        model_override_args: dict = None,
     ):
         self.server_args = server_args
@@ -95,7 +95,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_overide_args=model_overide_args,
+            model_override_args=model_override_args,
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
...
@@ -76,7 +76,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         nccl_port: int,
-        model_overide_args: dict,
+        model_override_args: dict,
     ):
         suppress_other_loggers()
@@ -93,7 +93,7 @@ class ModelTpServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_overide_args=model_overide_args,
+            model_override_args=model_override_args,
         )
         self.model_runner = ModelRunner(
@@ -876,7 +876,7 @@ def run_tp_server(
     tp_rank: int,
     server_args: ServerArgs,
     nccl_port: int,
-    model_overide_args: dict,
+    model_override_args: dict,
 ):
     """Run a tensor parallel model server."""
     configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -887,7 +887,7 @@ def run_tp_server(
         tp_rank,
         server_args,
         nccl_port,
-        model_overide_args,
+        model_override_args,
     )
     tp_cpu_group = model_server.model_runner.tp_group.cpu_group
@@ -904,14 +904,14 @@ def launch_tp_servers(
     tp_rank_range: List[int],
     server_args: ServerArgs,
     nccl_port: int,
-    model_overide_args: dict,
+    model_override_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
     for i in tp_rank_range:
         proc = multiprocessing.Process(
             target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+            args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
         )
         proc.start()
         procs.append(proc)
...
@@ -33,17 +33,17 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
-        model_overide_args: Optional[dict] = None,
+        model_override_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.model_overide_args = model_overide_args
+        self.model_override_args = model_override_args
         self.hf_config = get_config(
             self.path,
             trust_remote_code,
             revision,
-            model_overide_args=model_overide_args,
+            model_override_args=model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
...
@@ -195,9 +195,9 @@ class ModelRunner:
             monkey_patch_vllm_qvk_linear_loader()

         self.dtype = self.vllm_model_config.dtype
-        if self.model_config.model_overide_args is not None:
+        if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
-                self.model_config.model_overide_args
+                self.model_config.model_override_args
             )

         self.model = get_model(
...
@@ -272,7 +272,7 @@ async def retrieve_file_content(file_id: str):

 def launch_server(
     server_args: ServerArgs,
-    model_overide_args: Optional[dict] = None,
+    model_override_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
     """Launch an HTTP server."""
@@ -317,7 +317,7 @@ def launch_server(
             tp_rank_range,
             server_args,
             ports[3],
-            model_overide_args,
+            model_override_args,
         )

         try:
@@ -328,7 +328,7 @@ def launch_server(
             return

     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
@@ -341,7 +341,7 @@ def launch_server(
     proc_controller = mp.Process(
         target=start_controller_process,
-        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
+        args=(server_args, port_args, pipe_controller_writer, model_override_args),
     )
     proc_controller.start()
@@ -501,7 +501,7 @@ class Runtime:
     def __init__(
         self,
         log_level: str = "error",
-        model_overide_args: Optional[dict] = None,
+        model_override_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
@@ -525,7 +525,7 @@ class Runtime:
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, model_overide_args, pipe_writer),
+            args=(self.server_args, model_override_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
...
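Taken together, every public entry point now accepts only the corrected spelling. A hedged sketch of launching the server programmatically after this commit; the import paths and the ServerArgs construction are assumptions based on sglang's layout at the time, and the model path is a hypothetical placeholder:

from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

# The override dict is applied to the Hugging Face config before weights
# load, via get_config as shown earlier in this diff.
server_args = ServerArgs(model_path="lmms-lab/LLaVA-NeXT-Video-7B")
launch_server(server_args, model_override_args={"num_frames": 16})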