Commit 75f89dc4 (unverified), authored by youkaichao, committed by GitHub
Browse files

[torch.compile] add a flag to track batchsize statistics (#11059)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent e7391949
...@@ -69,6 +69,7 @@ if TYPE_CHECKING: ...@@ -69,6 +69,7 @@ if TYPE_CHECKING:
VLLM_DISABLED_KERNELS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False VLLM_USE_V1: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
def get_default_cache_root(): def get_default_cache_root():
...@@ -452,6 +453,8 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -452,6 +453,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, enable multiprocessing in LLM for the V1 code path. # If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING": "VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))), lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
"VLLM_LOG_BATCHSIZE_INTERVAL":
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
} }
# end-env-vars-definition # end-env-vars-definition
......
import time
from collections import Counter
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
# Module-level state for the optional batch-size statistics that
# set_forward_context collects and periodically logs.

# Tracking is enabled only when VLLM_LOG_BATCHSIZE_INTERVAL is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# Maps each observed batch size to the number of times it has been seen.
batchsize_counter: Counter = Counter()
# time.monotonic() timestamp of the most recent distribution log; starts at 0.
last_logging_time: float = 0
# Minimum gap between two distribution logs. It is compared against
# time.monotonic() differences, so the unit is seconds.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
@dataclass @dataclass
...@@ -26,7 +37,26 @@ def get_forward_context() -> ForwardContext: ...@@ -26,7 +37,26 @@ def get_forward_context() -> ForwardContext:
@contextmanager @contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig): def set_forward_context(context: Any, vllm_config: VllmConfig):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc.""" can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global track_batchsize, batchsize_counter
global last_logging_time, batchsize_logging_interval
if track_batchsize and context is not None:
if hasattr(context, "num_prefill_tokens"):
# for v0 attention backends
batchsize = context.num_prefill_tokens + context.num_decode_tokens
else:
# for v1 attention backends
batchsize = context.num_input_tokens
batchsize_counter[batchsize] += 1
if time.monotonic() - last_logging_time > batchsize_logging_interval:
last_logging_time = time.monotonic()
sorted_data = sorted(batchsize_counter.items(),
key=lambda x: x[1],
reverse=True)
logger.info("Batchsize distribution (batchsize, count): %s",
sorted_data)
global _forward_context global _forward_context
prev_context = _forward_context prev_context = _forward_context
_forward_context = ForwardContext( _forward_context = ForwardContext(
......
...@@ -56,6 +56,7 @@ class FlashAttentionMetadata: ...@@ -56,6 +56,7 @@ class FlashAttentionMetadata:
seq_start_loc: torch.Tensor seq_start_loc: torch.Tensor
block_table: torch.Tensor block_table: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
num_input_tokens: int = 0 # Number of tokens including padding.
class FlashAttentionImpl(AttentionImpl): class FlashAttentionImpl(AttentionImpl):
......
...@@ -445,6 +445,8 @@ class GPUModelRunner: ...@@ -445,6 +445,8 @@ class GPUModelRunner:
# Eager mode. # Eager mode.
num_input_tokens = num_scheduled_tokens num_input_tokens = num_scheduled_tokens
attn_metadata.num_input_tokens = num_input_tokens
# Get the inputs embeds. # Get the inputs embeds.
if encoder_outputs: if encoder_outputs:
inputs_embeds = self.model.get_input_embeddings( inputs_embeds = self.model.get_input_embeddings(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment