"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "e99e467384001e284e0722a33362866b10fed65b"
Unverified commit dea5f887, authored by Tzu-Ling Kan, committed by GitHub
Browse files

feat: Metrics labels for multimodal. (#2835)


Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
parent 8d54eb7e
...@@ -261,7 +261,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co ...@@ -261,7 +261,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
try: try:
await asyncio.gather( await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate), generate_endpoint.serve_endpoint(
handler.generate, metrics_labels=[("model", config.model)]
),
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to serve endpoints: {e}") logger.error(f"Failed to serve endpoints: {e}")
......
...@@ -332,7 +332,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co ...@@ -332,7 +332,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
try: try:
await asyncio.gather( await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate), generate_endpoint.serve_endpoint(
handler.generate, metrics_labels=[("model", config.model)]
),
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to serve endpoints: {e}") logger.error(f"Failed to serve endpoints: {e}")
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import List, Optional, Tuple
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase from vllm.v1.metrics.loggers import StatLoggerBase
...@@ -48,9 +48,15 @@ class NullStatLogger(StatLoggerBase): ...@@ -48,9 +48,15 @@ class NullStatLogger(StatLoggerBase):
class DynamoStatLoggerPublisher(StatLoggerBase): class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface.""" """Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
def __init__(self, component: Component, dp_rank: int) -> None: def __init__(
self,
component: Component,
dp_rank: int,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.inner = WorkerMetricsPublisher() self.inner = WorkerMetricsPublisher()
self.inner.create_endpoint(component) metrics_labels = metrics_labels or []
self.inner.create_endpoint(component, metrics_labels)
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.num_gpu_block = 1 self.num_gpu_block = 1
self.request_total_slots = 1 self.request_total_slots = 1
...@@ -141,15 +147,23 @@ class DynamoStatLoggerPublisher(StatLoggerBase): ...@@ -141,15 +147,23 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
class StatLoggerFactory: class StatLoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM.""" """Factory for creating stat logger publishers. Required by vLLM."""
def __init__(self, component: Component, dp_rank: int = 0) -> None: def __init__(
self,
component: Component,
dp_rank: int = 0,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.component = component self.component = component
self.created_logger: Optional[DynamoStatLoggerPublisher] = None self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.metrics_labels = metrics_labels or []
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase: def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank: if self.dp_rank != dp_rank:
return NullStatLogger() return NullStatLogger()
logger = DynamoStatLoggerPublisher(self.component, dp_rank) logger = DynamoStatLoggerPublisher(
self.component, dp_rank, metrics_labels=self.metrics_labels
)
self.created_logger = logger self.created_logger = logger
return logger return logger
......
...@@ -308,7 +308,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co ...@@ -308,7 +308,9 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
try: try:
await asyncio.gather( await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate), generate_endpoint.serve_endpoint(
handler.generate, metrics_labels=[("model", config.model)]
),
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to serve endpoints: {e}") logger.error(f"Failed to serve endpoints: {e}")
......
...@@ -25,7 +25,6 @@ from typing import Tuple ...@@ -25,7 +25,6 @@ from typing import Tuple
import torch import torch
import uvloop import uvloop
from vllm.distributed.kv_events import ZmqEventPublisher from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -107,14 +106,15 @@ class VllmBaseWorker: ...@@ -107,14 +106,15 @@ class VllmBaseWorker:
def __init__( def __init__(
self, self,
args: argparse.Namespace, args: argparse.Namespace,
engine_args: AsyncEngineArgs,
component: Component, component: Component,
endpoint: Endpoint, endpoint: Endpoint,
config: Config,
): ):
self.enable_disagg = args.enable_disagg self.enable_disagg = args.enable_disagg
self.endpoint = args.endpoint self.endpoint = args.endpoint
self.downstream_endpoint = args.downstream_endpoint self.downstream_endpoint = args.downstream_endpoint
self.engine_args = engine_args self.engine_args = config.engine_args
self.config = config
self.setup_vllm_engine(component, endpoint) self.setup_vllm_engine(component, endpoint)
async def async_init(self, runtime: DistributedRuntime): async def async_init(self, runtime: DistributedRuntime):
...@@ -142,6 +142,7 @@ class VllmBaseWorker: ...@@ -142,6 +142,7 @@ class VllmBaseWorker:
self.stats_logger = StatLoggerFactory( self.stats_logger = StatLoggerFactory(
component, component,
self.engine_args.data_parallel_rank or 0, self.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", self.config.model)],
) )
self.engine_client = AsyncLLM.from_vllm_config( self.engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config, vllm_config=vllm_config,
...@@ -444,20 +445,24 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co ...@@ -444,20 +445,24 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
if args.worker_type in ["prefill", "encode_prefill"]: if args.worker_type in ["prefill", "encode_prefill"]:
handler: VllmBaseWorker = VllmPDWorker( handler: VllmBaseWorker = VllmPDWorker(
args, config.engine_args, component, generate_endpoint args, component, generate_endpoint, config
) )
elif args.worker_type == "decode": elif args.worker_type == "decode":
handler = VllmDecodeWorker( handler = VllmDecodeWorker(args, component, generate_endpoint, config)
args, config.engine_args, component, generate_endpoint
)
await handler.async_init(runtime) await handler.async_init(runtime)
logger.info(f"Starting to serve the {args.endpoint} endpoint...") logger.info(f"Starting to serve the {args.endpoint} endpoint...")
metrics_labels = [("model", config.model)]
try: try:
await asyncio.gather( await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate), generate_endpoint.serve_endpoint(
clear_endpoint.serve_endpoint(handler.clear_kv_blocks), handler.generate, metrics_labels=metrics_labels
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=metrics_labels
),
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to serve endpoints: {e}") logger.error(f"Failed to serve endpoints: {e}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment