Unverified commit 08c01d8c, authored by Neelay Shah, committed by GitHub
Browse files

fix: register model after engine load (#1145)

parent 8d636ebd
......@@ -84,13 +84,6 @@ async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
)
arg_map = {
"model_path": config.model_path,
......@@ -124,6 +117,14 @@ async def init(runtime: DistributedRuntime, config: Config):
engine_args = ServerArgs(**arg_map)
engine_client = sglang.Engine(server_args=engine_args)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
)
# the server will shut down gracefully (i.e., let already-open TCP streams
# finish) after the lease is revoked
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
......
......@@ -133,17 +133,18 @@ async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
def _check_and_set_env_value(
    key: str, expected: str, allow_override: bool = False
) -> None:
    """Make sure environment variable ``key`` is set, defaulting to ``expected``.

    If ``key`` is not present in ``os.environ`` it is set to ``expected``.
    If it is already present, its existing value is always left untouched;
    what differs is whether a mismatching value is tolerated.

    Args:
        key: Name of the environment variable.
        expected: Value to install when the variable is unset, and the value
            an existing setting is compared against.
        allow_override: When true, a pre-existing value that differs from
            ``expected`` is accepted and kept. When false (the default), such
            a mismatch is treated as an error.

    Raises:
        ValueError: ``key`` is already set to something other than
            ``expected`` and ``allow_override`` is false.
    """
    current = os.environ.get(key)
    mismatch = current is not None and current != expected
    if mismatch and not allow_override:
        raise ValueError(
            f"{key} is set and doesn't equal expected {expected}. Please unset variable before launch."
        )
    # setdefault only writes when the variable is absent, so a value the
    # caller exported beforehand (matching, or allowed to differ) is kept.
    os.environ.setdefault(key, expected)
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
)
arg_map = {
"model": config.model_path,
......@@ -170,14 +171,20 @@ async def init(runtime: DistributedRuntime, config: Config):
arg_map = {**arg_map, **json_map} # json_map gets precedence
# Patch won't start KVCacheEventManager unless these four are set
os.environ["VLLM_WORKER_ID"] = str(endpoint.lease_id())
os.environ[
"VLLM_KV_CAPI_PATH"
] = "libdynamo_llm_capi.so" # Must be on LD_LIBRARY_PATH
os.environ["VLLM_KV_NAMESPACE"] = config.namespace
os.environ["VLLM_KV_COMPONENT"] = config.component
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
_check_and_set_env_value("VLLM_WORKER_ID", str(endpoint.lease_id()))
_check_and_set_env_value(
"VLLM_KV_CAPI_PATH", "libdynamo_llm_capi.so", allow_override=True
)
_check_and_set_env_value("VLLM_KV_NAMESPACE", config.namespace)
_check_and_set_env_value("VLLM_KV_COMPONENT", config.component)
_check_and_set_env_value(
"VLLM_NO_USAGE_STATS", "1", allow_override=True
) # Avoid internal HTTP requests
engine_args = AsyncEngineArgs(**arg_map)
model_config = engine_args.create_model_config()
# Load default sampling params from `generation_config.json`
......@@ -186,6 +193,9 @@ async def init(runtime: DistributedRuntime, config: Config):
engine_context = build_async_engine_client_from_engine_args(engine_args)
engine_client = await engine_context.__aenter__()
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
)
handler = RequestHandler(component, engine_client, default_sampling_params)
handler.setup_kv_metrics()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment