Unverified Commit 24e6ad3f authored by Chen Zhang's avatar Chen Zhang Committed by GitHub
Browse files

[V1] Remove num_input_tokens from attn_metadata (#17193)


Signed-off-by: default avatarChen Zhang <zhangch99@outlook.com>
parent 2ef5d106
......@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
if vllm_config.parallel_config.data_parallel_size > 1:
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
if attn_metadata is not None:
if hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = attn_metadata.num_input_tokens
if attn_metadata is not None and hasattr(attn_metadata,
"num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends or no attn_metadata
batchsize = num_tokens
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = batchsize
......@@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = attn_metadata.num_input_tokens
batchsize = num_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
......
......@@ -94,9 +94,6 @@ class FlashAttentionMetadata:
scheduler_metadata: Optional[torch.Tensor] = None
prefix_scheduler_metadata: Optional[torch.Tensor] = None
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# for local attention
@dataclass
class LocalAttentionMetadata:
......
......@@ -183,9 +183,6 @@ class FlashInferMetadata:
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
@property
def query_start_loc(self):
# The GPUModelRunner expects to be able to access this property.
......
......@@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]):
num_decode_tokens: int
num_prefills: int
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# The dimension of the attention heads
head_dim: Optional[int] = None
......
......@@ -1036,7 +1036,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
attn_metadata.num_input_tokens = num_input_tokens
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
......@@ -1088,7 +1087,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata, self.vllm_config):
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
output = self.model(
input_ids=input_ids,
positions=positions,
......
......@@ -769,7 +769,10 @@ class TPUModelRunner:
xm.mark_step()
num_reqs = self.input_batch.num_reqs
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=scheduler_output.total_num_scheduled_tokens):
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment