Commit a611726e authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix: propagate env vars from input cli/yaml into process (#208)

parent 9be75482
...@@ -98,7 +98,7 @@ curl -X POST http://localhost:3000/generate \ ...@@ -98,7 +98,7 @@ curl -X POST http://localhost:3000/generate \
-d '{"text": "federer"}' -d '{"text": "federer"}'
``` ```
You should see the following output: You should see the following output
```bash ```bash
federer-mid-back federer-mid-back
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import os
import subprocess
import typing as t
import click
import click_option_group as cog
import rich
import yaml
from bentoml_cli.utils import is_valid_bento_name, is_valid_bento_tag
from rich.syntax import Syntax
from rich.table import Table
from simple_di import Provide, inject
if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
from bentoml._internal.cloud import BentoCloudClient
from bentoml._internal.container import DefaultBuilder
from click import Context, Parameter
DYNAMO_FIGLET = """
██████╗ ██╗ ██╗███╗ ██╗ █████╗ ███╗ ███╗ ██████╗
██╔══██╗╚██╗ ██╔╝████╗ ██║██╔══██╗████╗ ████║██╔═══██╗
██║ ██║ ╚████╔╝ ██╔██╗ ██║███████║██╔████╔██║██║ ██║
██║ ██║ ╚██╔╝ ██║╚██╗██║██╔══██║██║╚██╔╝██║██║ ██║
██████╔╝ ██║ ██║ ╚████║██║ ██║██║ ╚═╝ ██║╚██████╔╝
╚═════╝ ╚═╝ ╚═╝ ╚═══╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═════╝
"""
ALLOWED_PLATFORMS = [
"windows",
"linux",
"macos",
"x86_64-pc-windows-msvc",
"i686-pc-windows-msvc",
"x86_64-unknown-linux-gnu",
"aarch64-apple-darwin",
"x86_64-apple-darwin",
"aarch64-unknown-linux-gnu",
"aarch64-unknown-linux-musl",
"x86_64-unknown-linux-musl",
"x86_64-manylinux_2_17",
"x86_64-manylinux_2_28",
"x86_64-manylinux_2_31",
"x86_64-manylinux_2_32",
"x86_64-manylinux_2_33",
"x86_64-manylinux_2_34",
"x86_64-manylinux_2_35",
"x86_64-manylinux_2_36",
"x86_64-manylinux_2_37",
"x86_64-manylinux_2_38",
"x86_64-manylinux_2_39",
"x86_64-manylinux_2_40",
"aarch64-manylinux_2_17",
"aarch64-manylinux_2_28",
"aarch64-manylinux_2_31",
"aarch64-manylinux_2_32",
"aarch64-manylinux_2_33",
"aarch64-manylinux_2_34",
"aarch64-manylinux_2_35",
"aarch64-manylinux_2_36",
"aarch64-manylinux_2_37",
"aarch64-manylinux_2_38",
"aarch64-manylinux_2_39",
"aarch64-manylinux_2_40",
]
def parse_delete_targets_argument_callback(
    ctx: Context,
    params: Parameter,
    value: t.Any,  # pylint: disable=unused-argument
) -> list[str]:
    """Click callback that normalizes delete targets into a list of strings.

    The raw argument tuple is joined and then split either on commas (if any
    are present) or on whitespace. Each resulting entry must be a valid bento
    name or "name:version" tag; otherwise ``click.BadParameter`` is raised.
    Returns an empty list when no value was supplied.
    """
    if value is None:
        return []
    joined = " ".join(value)
    # Comma-separated wins over whitespace-separated; `split(None)` is the
    # whitespace split equivalent of `split()`.
    separator = "," if "," in joined else None
    targets = [token.strip() for token in joined.split(separator)]
    for target in targets:
        if not (is_valid_bento_tag(target) or is_valid_bento_name(target)):
            raise click.BadParameter(
                f'Bad formatting: "{target}". Please present a valid bento bundle name or "name:version" tag. For list of bento bundles, separate delete targets by ",", for example: "my_service:v1,my_service:v2,classifier"'
            )
    return targets
def bento_management_commands() -> click.Group:
    """Assemble and return the ``bentos`` click command group.

    All bentoml imports are deferred into this factory so that importing this
    module stays cheap; the returned group bundles the get/list/delete/export/
    import/pull/push/build commands.
    """
    import bentoml
    from bentoml import Tag
    from bentoml._internal.configuration import get_quiet_mode
    from bentoml._internal.configuration.containers import BentoMLContainer
    from bentoml._internal.utils import human_readable_size
    from bentoml.bentos import build_bentofile, import_bento
    from bentoml_cli.utils import BentoMLCommandGroup

    @click.group(cls=BentoMLCommandGroup)
    def bentos():
        """Commands for managing Bento bundles."""
        pass

    @bentos.command()
    @click.argument("bento_tag", type=click.STRING)
    @click.option(
        "-o",
        "--output",
        type=click.Choice(["json", "yaml", "path"]),
        default="yaml",
    )
    @inject
    def get(
        bento_tag: str,
        output: str,
        bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
    ) -> None:  # type: ignore
        """Print Bento details by providing the bento_tag.
        \b
        bentoml get iris_classifier:qojf5xauugwqtgxi
        bentoml get iris_classifier:qojf5xauugwqtgxi --output=json
        """
        bento = bento_store.get(bento_tag)
        if output == "path":
            rich.print(bento.path)
        elif output == "json":
            info = json.dumps(bento.info.to_dict(), indent=2, default=str)
            rich.print_json(info)
        else:
            # Default: syntax-highlighted YAML dump of the bento metadata.
            info = yaml.dump(bento.info.to_dict(), indent=2, sort_keys=False)
            rich.print(Syntax(info, "yaml", background_color="default"))

    @bentos.command(name="list")
    @click.argument("bento_name", type=click.STRING, required=False)
    @click.option(
        "-o",
        "--output",
        type=click.Choice(["json", "yaml", "table"]),
        default="table",
    )
    @inject
    def list_bentos(
        bento_name: str,
        output: str,
        bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
    ) -> None:  # type: ignore
        """List Bentos in local store
        \b
        # show all bentos saved
        $ bentoml list
        \b
        # show all versions of bento with the name FraudDetector
        $ bentoml list FraudDetector
        """
        bentos = bento_store.list(bento_name)
        res: list[dict[str, str]] = []
        # Newest first; model size is the total footprint minus the bento
        # archive itself.
        for bento in sorted(bentos, key=lambda x: x.info.creation_time, reverse=True):
            bento_size = bento.file_size
            model_size = bento.total_size() - bento_size
            res.append(
                {
                    "tag": str(bento.tag),
                    "size": human_readable_size(bento_size),
                    "model_size": human_readable_size(model_size),
                    "creation_time": bento.info.creation_time.astimezone().strftime(
                        "%Y-%m-%d %H:%M:%S"
                    ),
                }
            )
        if output == "json":
            info = json.dumps(res, indent=2)
            rich.print(info)
        elif output == "yaml":
            info = t.cast(str, yaml.safe_dump(res, indent=2))
            rich.print(Syntax(info, "yaml", background_color="default"))
        else:
            table = Table(box=None)
            table.add_column("Tag")
            table.add_column("Size")
            table.add_column("Model Size")
            table.add_column("Creation Time")
            for bento in res:
                table.add_row(
                    bento["tag"],
                    bento["size"],
                    bento["model_size"],
                    bento["creation_time"],
                )
            rich.print(table)

    @bentos.command()
    @click.argument(
        "delete_targets",
        nargs=-1,
        callback=parse_delete_targets_argument_callback,
        required=True,
    )
    @click.option(
        "-y",
        "--yes",
        "--assume-yes",
        is_flag=True,
        help="Skip confirmation when deleting a specific bento bundle",
    )
    @inject
    def delete(
        delete_targets: list[str],
        yes: bool,
        bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
    ) -> None:  # type: ignore
        """Delete Bento in local bento store.
        \b
        Examples:
            * Delete single bento bundle by "name:version", e.g: `bentoml delete IrisClassifier:v1`
            * Bulk delete all bento bundles with a specific name, e.g.: `bentoml delete IrisClassifier`
            * Bulk delete multiple bento bundles by name and version, separated by ",", e.g.: `bentoml delete Irisclassifier:v1,MyPredictService:v2`
            * Bulk delete multiple bento bundles by name and version, separated by " ", e.g.: `bentoml delete Irisclassifier:v1 MyPredictService:v2`
            * Bulk delete without confirmation, e.g.: `bentoml delete IrisClassifier --yes`
        """

        def delete_target(target: str) -> None:
            # A bare name (no version) selects every stored version of it.
            tag = Tag.from_str(target)
            if tag.version is None:
                to_delete_bentos = bento_store.list(target)
            else:
                to_delete_bentos = [bento_store.get(tag)]
            for bento in to_delete_bentos:
                if yes:
                    delete_confirmed = True
                else:
                    delete_confirmed = click.confirm(f"delete bento {bento.tag}?")
                if delete_confirmed:
                    bento_store.delete(bento.tag)
                    rich.print(f"{bento} deleted.")

        for target in delete_targets:
            delete_target(target)

    @bentos.command()
    @click.argument("bento_tag", type=click.STRING)
    @click.argument(
        "out_path",
        type=click.STRING,
        default="",
        required=False,
    )
    @inject
    def export(
        bento_tag: str,
        out_path: str,
        bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
    ) -> None:  # type: ignore
        """Export a Bento to an external file archive
        \b
        Arguments:
            BENTO_TAG: bento identifier
            OUT_PATH: output path of exported bento.
                If out_path argument is not provided, bento is exported to name-version.bento in the current directory.
                Beside the native .bento format, we also support ('tar'), tar.gz ('gz'), tar.xz ('xz'), tar.bz2 ('bz2'), and zip.
        \b
        Examples:
            bentoml export FraudDetector:20210709_DE14C9
            bentoml export FraudDetector:20210709_DE14C9 ./my_bento.bento
            bentoml export FraudDetector:latest ./my_bento.bento
            bentoml export FraudDetector:latest s3://mybucket/bentos/my_bento.bento
        """
        bento = bento_store.get(bento_tag)
        # bento.export returns the resolved destination path.
        out_path = bento.export(out_path)
        rich.print(f"{bento} exported to {out_path}.")

    @bentos.command(name="import")
    @click.argument("bento_path", type=click.STRING)
    def import_bento_(bento_path: str) -> None:  # type: ignore
        """Import a previously exported Bento archive file
        \b
        Arguments:
            BENTO_PATH: path of Bento archive file
        \b
        Examples:
            bentoml import ./my_bento.bento
            bentoml import s3://mybucket/bentos/my_bento.bento
        """
        bento = import_bento(bento_path)
        rich.print(f"{bento} imported.")

    @bentos.command()
    @click.argument("bento_tag", type=click.STRING)
    @click.option(
        "-f",
        "--force",
        is_flag=True,
        default=False,
        help="Force pull from remote Bento store to local and overwrite even if it already exists in local",
    )
    @click.option("--with-models", is_flag=True, default=False, help="Pull models too")
    @inject
    def pull(
        bento_tag: str,
        force: bool,
        with_models: bool,
        cloud_client: BentoCloudClient = Provide[BentoMLContainer.bentocloud_client],
    ) -> None:  # type: ignore
        """Pull Bento from a remote Bento store server."""
        cloud_client.bento.pull(bento_tag, force=force, with_models=with_models)

    @bentos.command()
    @click.argument("bento_tag", type=click.STRING)
    @click.option(
        "-f",
        "--force",
        is_flag=True,
        default=False,
        help="Forced push to remote Bento store even if it exists in remote",
    )
    @click.option(
        "-t",
        "--threads",
        default=10,
        help="Number of threads to use for upload",
    )
    @inject
    def push(
        bento_tag: str,
        force: bool,
        threads: int,
        bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
        cloud_client: BentoCloudClient = Provide[BentoMLContainer.bentocloud_client],
    ) -> None:  # type: ignore
        """Push Bento to a remote Bento store server."""
        bento_obj = bento_store.get(bento_tag)
        if not bento_obj:
            raise click.ClickException(f"Bento {bento_tag} not found in local store")
        cloud_client.bento.push(bento_obj, force=force, threads=threads)

    @bentos.command()
    @click.argument("build_ctx", type=click.Path(), default=".")
    @click.option(
        "-f", "--bentofile", help="Path to bentofile. Default to 'bentofile.yaml'"
    )
    @click.option(
        "--version",
        type=click.STRING,
        default=None,
        help="Bento version. By default the version will be generated.",
    )
    @click.option(
        "--label",
        "labels",
        type=click.STRING,
        multiple=True,
        help="(multiple)Bento labels",
        metavar="KEY=VALUE",
    )
    @click.option(
        "-o",
        "--output",
        type=click.Choice(["tag", "default"]),
        default="default",
        show_default=True,
        help="Output log format. '-o tag' to display only bento tag.",
    )
    @cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name="Utilities options")
    @cog.optgroup.option(
        "--containerize",
        default=False,
        is_flag=True,
        type=click.BOOL,
        help="Whether to containerize the Bento after building. '--containerize' is the shortcut of 'bentoml build && bentoml containerize'.",
    )
    @cog.optgroup.option(
        "--push",
        default=False,
        is_flag=True,
        type=click.BOOL,
        help="Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.",
    )
    @click.option(
        "--force", is_flag=True, default=False, help="Forced push to BentoCloud"
    )
    @click.option("--threads", default=10, help="Number of threads to use for upload")
    @click.option(
        "--platform",
        default=None,
        help="Platform to build for",
        type=click.Choice(ALLOWED_PLATFORMS),
    )
    def build(  # type: ignore
        build_ctx: str,
        bentofile: str | None,
        version: str | None,
        labels: tuple[str, ...],
        output: t.Literal["tag", "default"],
        push: bool,
        force: bool,
        threads: int,
        containerize: bool,
        platform: str | None,
    ):
        """Build a new Bento from current directory."""
        from bentoml._internal.configuration import set_quiet_mode
        from bentoml._internal.log import configure_logging

        # '-o tag' mode emits only the machine-readable tag line.
        if output == "tag":
            set_quiet_mode()
        configure_logging()
        # Parse repeated --label KEY=VALUE options into a dict.
        labels_dict: dict[str, t.Any] = {}
        for label in labels:
            key, label_value = label.split("=", 1)
            labels_dict[key] = label_value
        # A "module:attr"-style build_ctx is treated as the service target,
        # with the build context falling back to the current directory.
        service: str | None = None
        if ":" in build_ctx:
            service = build_ctx
            build_ctx = "."
        bento = build_bentofile(
            bentofile,
            service=service,
            version=version,
            labels=labels_dict or None,
            build_ctx=build_ctx,
            platform=platform,
        )
        containerize_cmd = f"dynamo containerize {bento.tag}"
        # push_cmd = f"dynamo push {bento.tag}"
        # NOTE: Don't remove the return statement here, since we will need this
        # for usage stats collection if users are opt-in.
        if output == "tag":
            rich.print(f"__tag__:{bento.tag}")
        else:
            if not get_quiet_mode():
                rich.print(DYNAMO_FIGLET)
                rich.print(f"[green]Successfully built {bento.tag}.")
                next_steps = []
                if not containerize:
                    next_steps.append(
                        "\n\n* Containerize your Bento with `dynamo containerize`:\n"
                        f"    $ {containerize_cmd} [or dynamo build --containerize]"
                    )
                # if not push:
                #     next_steps.append(
                #         "\n\n* Push to BentoCloud with `bentoml push`:\n"
                #         f"    $ {push_cmd} [or bentoml build --push]"
                #     )
                if next_steps:
                    rich.print(f"\n[blue]Next steps: {''.join(next_steps)}[/]")
        if push:
            if not get_quiet_mode():
                rich.print(f"\n[magenta]Pushing {bento} to BentoCloud...[/]")
            cloud_client = BentoMLContainer.bentocloud_client.get()
            cloud_client.bento.push(bento, force=force, threads=threads)
        elif containerize:
            backend: DefaultBuilder = t.cast(
                "DefaultBuilder", os.getenv("BENTOML_CONTAINERIZE_BACKEND", "docker")
            )
            try:
                bentoml.container.health(backend)
            except subprocess.CalledProcessError:
                raise bentoml.exceptions.BentoMLException(
                    f"Backend {backend} is not healthy"
                )
            bentoml.container.build(bento.tag, backend=backend)
        return bento

    return bentos
# Module-level entry point: the fully wired `bentos` command group.
bento_command = bento_management_commands()
...@@ -68,11 +68,17 @@ def _parse_service_arg(arg_name: str, arg_value: str) -> tuple[str, str, t.Any]: ...@@ -68,11 +68,17 @@ def _parse_service_arg(arg_name: str, arg_value: str) -> tuple[str, str, t.Any]:
parts = arg_name.split(".") parts = arg_name.split(".")
service = parts[0] service = parts[0]
# Handle nested keys (e.g., ServiceArgs.workers or ServiceArgs.envs.CUDA_VISIBLE_DEVICES)
nested_keys = parts[1:] nested_keys = parts[1:]
# Parse value based on type # Special case: if this is a ServiceArgs.envs.* path, keep value as string
if (
len(nested_keys) >= 2
and nested_keys[0] == "ServiceArgs"
and nested_keys[1] == "envs"
):
value: t.Union[str, int, float, bool, dict, list] = arg_value
else:
# Parse value based on type for non-env vars
try: try:
value = json.loads(arg_value) value = json.loads(arg_value)
except json.JSONDecodeError: except json.JSONDecodeError:
......
...@@ -188,18 +188,34 @@ def create_dynamo_watcher( ...@@ -188,18 +188,34 @@ def create_dynamo_watcher(
if worker_envs: if worker_envs:
args.extend(["--worker-env", json.dumps(worker_envs)]) args.extend(["--worker-env", json.dumps(worker_envs)])
# Update env to include ServiceConfig # Update env to include ServiceConfig and service-specific environment variables
worker_env = env.copy() if env else {} worker_env = env.copy() if env else {}
# Pass through the main service config
if "DYNAMO_SERVICE_CONFIG" in os.environ: if "DYNAMO_SERVICE_CONFIG" in os.environ:
worker_env["DYNAMO_SERVICE_CONFIG"] = os.environ["DYNAMO_SERVICE_CONFIG"] worker_env["DYNAMO_SERVICE_CONFIG"] = os.environ["DYNAMO_SERVICE_CONFIG"]
# Create the watcher with dependency map in environment # Get service-specific environment variables from DYNAMO_SERVICE_ENVS
if "DYNAMO_SERVICE_ENVS" in os.environ:
try:
service_envs = json.loads(os.environ["DYNAMO_SERVICE_ENVS"])
if svc.name in service_envs:
service_args = service_envs[svc.name].get("ServiceArgs", {})
if "envs" in service_args:
worker_env.update(service_args["envs"])
logger.info(
f"Added service-specific environment variables for {svc.name}"
)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse DYNAMO_SERVICE_ENVS: {e}")
# Create the watcher with updated environment
watcher = create_watcher( watcher = create_watcher(
name=f"dynamo_service_{svc.name}", name=f"dynamo_service_{svc.name}",
args=args, args=args,
numprocesses=num_workers, numprocesses=num_workers,
working_dir=working_dir, working_dir=working_dir,
env=worker_env, # Use updated environment env=worker_env,
) )
return watcher, socket, uri return watcher, socket, uri
......
...@@ -19,7 +19,7 @@ import os ...@@ -19,7 +19,7 @@ import os
import bentoml import bentoml
# TODO: "dynamo:latest-vllm" image will not be available to image builder in k8s # TODO: "dynamo:latest-vllm-dev" image will not be available to image builder in k8s
# so We'd consider publishing the base image for releases to public nvcr.io registry. # so We'd consider publishing the base image for releases to public nvcr.io registry.
image_name = os.getenv("DYNAMO_IMAGE", "dynamo:latest-vllm") image_name = os.getenv("DYNAMO_IMAGE", "dynamo:latest-vllm-dev")
DYNAMO_IMAGE = bentoml.images.PythonImage(base_image=image_name) DYNAMO_IMAGE = bentoml.images.PythonImage(base_image=image_name)
...@@ -177,11 +177,24 @@ class DynamoService(Service[T]): ...@@ -177,11 +177,24 @@ class DynamoService(Service[T]):
return next_service return next_service
def _remove_service_args(self, service_name: str): def _remove_service_args(self, service_name: str):
"""Remove ServiceArgs from the environment config after using them""" """Remove ServiceArgs from the environment config after using them, preserving envs"""
config_str = os.environ.get("DYNAMO_SERVICE_CONFIG") config_str = os.environ.get("DYNAMO_SERVICE_CONFIG")
if config_str: if config_str:
config = json.loads(config_str) config = json.loads(config_str)
if service_name in config and "ServiceArgs" in config[service_name]: if service_name in config and "ServiceArgs" in config[service_name]:
# Save envs to separate env var before removing ServiceArgs
service_args = config[service_name]["ServiceArgs"]
if "envs" in service_args:
service_envs = os.environ.get("DYNAMO_SERVICE_ENVS", "{}")
envs_config = json.loads(service_envs)
if service_name not in envs_config:
envs_config[service_name] = {}
envs_config[service_name]["ServiceArgs"] = {
"envs": service_args["envs"]
}
os.environ["DYNAMO_SERVICE_ENVS"] = json.dumps(envs_config)
# Remove ServiceArgs from main config
del config[service_name]["ServiceArgs"] del config[service_name]["ServiceArgs"]
os.environ["DYNAMO_SERVICE_CONFIG"] = json.dumps(config) os.environ["DYNAMO_SERVICE_CONFIG"] = json.dumps(config)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
from components.processor import Processor
from components.routerless.worker import VllmWorkerRouterLess
from components.worker import VllmWorker
from pydantic import BaseModel
from dynamo.sdk import depends, service
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
class FrontendConfig(BaseModel):
    """Configuration for the Frontend service, built from the "Frontend" section
    of the ServiceConfig."""

    model: str  # model name registered with the HTTP frontend via llmctl
    endpoint: str  # endpoint the model's chat requests are routed to
    port: int = 8080  # listen port for the `http` server process
@service(
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
    image=DYNAMO_IMAGE,
)
# todo this should be called ApiServer
class Frontend:
    """HTTP frontend service: registers the configured model with `llmctl`
    and then runs the `http` server process in the foreground."""

    worker = depends(VllmWorker)
    worker_routerless = depends(VllmWorkerRouterLess)
    processor = depends(Processor)

    def __init__(self):
        config = ServiceConfig.get_instance()
        frontend_config = FrontendConfig(**config.get("Frontend", {}))
        # Remove any stale registration before (re-)adding the model so the
        # llmctl state reflects the current config.
        subprocess.run(
            ["llmctl", "http", "remove", "chat-models", frontend_config.model]
        )
        subprocess.run(
            [
                "llmctl",
                "http",
                "add",
                "chat-models",
                frontend_config.model,
                frontend_config.endpoint,
            ]
        )
        print("Starting HTTP server")
        # Use Popen instead of subprocess.run so the long-running server child
        # can be terminated cleanly on Ctrl-C instead of being orphaned.
        process = subprocess.Popen(
            ["http", "-p", str(frontend_config.port)], stdout=None, stderr=None
        )
        try:
            process.wait()
        except KeyboardInterrupt:
            process.terminate()
            process.wait()
...@@ -22,6 +22,7 @@ from pydantic import BaseModel ...@@ -22,6 +22,7 @@ from pydantic import BaseModel
from dynamo.sdk import depends, service from dynamo.sdk import depends, service
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
class FrontendConfig(BaseModel): class FrontendConfig(BaseModel):
...@@ -33,6 +34,7 @@ class FrontendConfig(BaseModel): ...@@ -33,6 +34,7 @@ class FrontendConfig(BaseModel):
@service( @service(
resources={"cpu": "10", "memory": "20Gi"}, resources={"cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
image=DYNAMO_IMAGE,
) )
# todo this should be called ApiServer # todo this should be called ApiServer
class Frontend: class Frontend:
...@@ -58,6 +60,12 @@ class Frontend: ...@@ -58,6 +60,12 @@ class Frontend:
] ]
) )
subprocess.run( print("Starting HTTP server")
process = subprocess.Popen(
["http", "-p", str(frontend_config.port)], stdout=None, stderr=None ["http", "-p", str(frontend_config.port)], stdout=None, stderr=None
) )
try:
process.wait()
except KeyboardInterrupt:
process.terminate()
process.wait()
...@@ -31,14 +31,7 @@ from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest ...@@ -31,14 +31,7 @@ from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from dynamo.llm import KvMetricsPublisher from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import ( from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
async_on_start,
depends,
dynamo_context,
dynamo_endpoint,
server_context,
service,
)
@service( @service(
...@@ -90,13 +83,6 @@ class VllmWorker: ...@@ -90,13 +83,6 @@ class VllmWorker:
os.environ["VLLM_KV_NAMESPACE"] = "dynamo" os.environ["VLLM_KV_NAMESPACE"] = "dynamo"
os.environ["VLLM_KV_COMPONENT"] = class_name os.environ["VLLM_KV_COMPONENT"] = class_name
vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}") vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
# note: worker_index is 1-based, but CUDA_VISIBLE_DEVICES is 0-based
gpu_idx = (
self.engine_args.cuda_visible_device_offset
+ server_context.worker_index
- 1
)
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_idx}"
self.metrics_publisher = KvMetricsPublisher() self.metrics_publisher = KvMetricsPublisher()
@async_on_start @async_on_start
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment