Unverified Commit 08c01d8c authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

fix: register model after engine load (#1145)

parent 8d636ebd
...@@ -84,13 +84,6 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -84,13 +84,6 @@ async def init(runtime: DistributedRuntime, config: Config):
""" """
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
)
arg_map = { arg_map = {
"model_path": config.model_path, "model_path": config.model_path,
...@@ -124,6 +117,14 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -124,6 +117,14 @@ async def init(runtime: DistributedRuntime, config: Config):
engine_args = ServerArgs(**arg_map) engine_args = ServerArgs(**arg_map)
engine_client = sglang.Engine(server_args=engine_args) engine_client = sglang.Engine(server_args=engine_args)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes) # the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked # after the lease is revoked
await endpoint.serve_endpoint(RequestHandler(engine_client).generate) await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
......
...@@ -133,17 +133,18 @@ async def worker(runtime: DistributedRuntime): ...@@ -133,17 +133,18 @@ async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args()) await init(runtime, cmd_line_args())
def _check_and_set_env_value(key, expected, allow_override=False):
if not allow_override and key in os.environ and os.environ[key] != expected:
raise ValueError(
f"{key} is set and doesn't equal expected {expected}. Please unset variable before launch."
)
os.environ.setdefault(key, expected)
async def init(runtime: DistributedRuntime, config: Config): async def init(runtime: DistributedRuntime, config: Config):
""" """
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
)
arg_map = { arg_map = {
"model": config.model_path, "model": config.model_path,
...@@ -170,14 +171,20 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -170,14 +171,20 @@ async def init(runtime: DistributedRuntime, config: Config):
arg_map = {**arg_map, **json_map} # json_map gets precedence arg_map = {**arg_map, **json_map} # json_map gets precedence
# Patch won't start KVCacheEventManager unless these four are set # Patch won't start KVCacheEventManager unless these four are set
os.environ["VLLM_WORKER_ID"] = str(endpoint.lease_id())
os.environ[ component = runtime.namespace(config.namespace).component(config.component)
"VLLM_KV_CAPI_PATH" await component.create_service()
] = "libdynamo_llm_capi.so" # Must be on LD_LIBRARY_PATH endpoint = component.endpoint(config.endpoint)
os.environ["VLLM_KV_NAMESPACE"] = config.namespace
os.environ["VLLM_KV_COMPONENT"] = config.component _check_and_set_env_value("VLLM_WORKER_ID", str(endpoint.lease_id()))
_check_and_set_env_value(
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests "VLLM_KV_CAPI_PATH", "libdynamo_llm_capi.so", allow_override=True
)
_check_and_set_env_value("VLLM_KV_NAMESPACE", config.namespace)
_check_and_set_env_value("VLLM_KV_COMPONENT", config.component)
_check_and_set_env_value(
"VLLM_NO_USAGE_STATS", "1", allow_override=True
) # Avoid internal HTTP requests
engine_args = AsyncEngineArgs(**arg_map) engine_args = AsyncEngineArgs(**arg_map)
model_config = engine_args.create_model_config() model_config = engine_args.create_model_config()
# Load default sampling params from `generation_config.json` # Load default sampling params from `generation_config.json`
...@@ -186,6 +193,9 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -186,6 +193,9 @@ async def init(runtime: DistributedRuntime, config: Config):
engine_context = build_async_engine_client_from_engine_args(engine_args) engine_context = build_async_engine_client_from_engine_args(engine_args)
engine_client = await engine_context.__aenter__() engine_client = await engine_context.__aenter__()
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
)
handler = RequestHandler(component, engine_client, default_sampling_params) handler = RequestHandler(component, engine_client, default_sampling_params)
handler.setup_kv_metrics() handler.setup_kv_metrics()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment