Unverified commit dea5f887, authored by Tzu-Ling Kan and committed by GitHub
Browse files

feat: Metrics labels for multimodal. (#2835)


Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
parent 8d54eb7e
......@@ -261,7 +261,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
generate_endpoint.serve_endpoint(
handler.generate, metrics_labels=[("model", config.model)]
),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
......
......@@ -332,7 +332,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
generate_endpoint.serve_endpoint(
handler.generate, metrics_labels=[("model", config.model)]
),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
......
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import List, Optional, Tuple
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase
......@@ -48,9 +48,15 @@ class NullStatLogger(StatLoggerBase):
class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
def __init__(self, component: Component, dp_rank: int) -> None:
def __init__(
self,
component: Component,
dp_rank: int,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.inner = WorkerMetricsPublisher()
self.inner.create_endpoint(component)
metrics_labels = metrics_labels or []
self.inner.create_endpoint(component, metrics_labels)
self.dp_rank = dp_rank
self.num_gpu_block = 1
self.request_total_slots = 1
......@@ -141,15 +147,23 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
class StatLoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM."""
def __init__(self, component: Component, dp_rank: int = 0) -> None:
def __init__(
self,
component: Component,
dp_rank: int = 0,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.component = component
self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank
self.metrics_labels = metrics_labels or []
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank:
return NullStatLogger()
logger = DynamoStatLoggerPublisher(self.component, dp_rank)
logger = DynamoStatLoggerPublisher(
self.component, dp_rank, metrics_labels=self.metrics_labels
)
self.created_logger = logger
return logger
......
......@@ -308,7 +308,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
generate_endpoint.serve_endpoint(
handler.generate, metrics_labels=[("model", config.model)]
),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
......
......@@ -25,7 +25,6 @@ from typing import Tuple
import torch
import uvloop
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
......@@ -107,14 +106,15 @@ class VllmBaseWorker:
def __init__(
self,
args: argparse.Namespace,
engine_args: AsyncEngineArgs,
component: Component,
endpoint: Endpoint,
config: Config,
):
self.enable_disagg = args.enable_disagg
self.endpoint = args.endpoint
self.downstream_endpoint = args.downstream_endpoint
self.engine_args = engine_args
self.engine_args = config.engine_args
self.config = config
self.setup_vllm_engine(component, endpoint)
async def async_init(self, runtime: DistributedRuntime):
......@@ -142,6 +142,7 @@ class VllmBaseWorker:
self.stats_logger = StatLoggerFactory(
component,
self.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", self.config.model)],
)
self.engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
......@@ -444,20 +445,24 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
if args.worker_type in ["prefill", "encode_prefill"]:
handler: VllmBaseWorker = VllmPDWorker(
args, config.engine_args, component, generate_endpoint
args, component, generate_endpoint, config
)
elif args.worker_type == "decode":
handler = VllmDecodeWorker(
args, config.engine_args, component, generate_endpoint
)
handler = VllmDecodeWorker(args, component, generate_endpoint, config)
await handler.async_init(runtime)
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
metrics_labels = [("model", config.model)]
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
generate_endpoint.serve_endpoint(
handler.generate, metrics_labels=metrics_labels
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=metrics_labels
),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment