Unverified commit 24e6ad3f, authored by Chen Zhang, committed by GitHub
Browse files

[V1] Remove num_input_tokens from attn_metadata (#17193)


Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent 2ef5d106
...@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any, ...@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
if vllm_config.parallel_config.data_parallel_size > 1: if vllm_config.parallel_config.data_parallel_size > 1:
dp_size = vllm_config.parallel_config.data_parallel_size dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
if attn_metadata is not None: if attn_metadata is not None and hasattr(attn_metadata,
if hasattr(attn_metadata, "num_prefill_tokens"): "num_prefill_tokens"):
# for v0 attention backends # for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \ batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens attn_metadata.num_decode_tokens
else: else:
# for v1 attention backends # for v1 attention backends or no attn_metadata
batchsize = attn_metadata.num_input_tokens
else:
batchsize = num_tokens batchsize = num_tokens
num_tokens_across_dp = [0] * dp_size num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = batchsize num_tokens_across_dp[dp_rank] = batchsize
...@@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any, ...@@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
attn_metadata.num_decode_tokens attn_metadata.num_decode_tokens
else: else:
# for v1 attention backends # for v1 attention backends
batchsize = attn_metadata.num_input_tokens batchsize = num_tokens
# we use synchronous scheduling right now, # we use synchronous scheduling right now,
# adding a sync point here should not affect # adding a sync point here should not affect
# scheduling of the next batch # scheduling of the next batch
......
...@@ -94,9 +94,6 @@ class FlashAttentionMetadata: ...@@ -94,9 +94,6 @@ class FlashAttentionMetadata:
scheduler_metadata: Optional[torch.Tensor] = None scheduler_metadata: Optional[torch.Tensor] = None
prefix_scheduler_metadata: Optional[torch.Tensor] = None prefix_scheduler_metadata: Optional[torch.Tensor] = None
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# for local attention # for local attention
@dataclass @dataclass
class LocalAttentionMetadata: class LocalAttentionMetadata:
......
...@@ -183,9 +183,6 @@ class FlashInferMetadata: ...@@ -183,9 +183,6 @@ class FlashInferMetadata:
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
@property @property
def query_start_loc(self): def query_start_loc(self):
# The GPUModelRunner expects to be able to access this property. # The GPUModelRunner expects to be able to access this property.
......
...@@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]): ...@@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]):
num_decode_tokens: int num_decode_tokens: int
num_prefills: int num_prefills: int
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# The dimension of the attention heads # The dimension of the attention heads
head_dim: Optional[int] = None head_dim: Optional[int] = None
......
...@@ -1036,7 +1036,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1036,7 +1036,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_input_tokens = round_up(num_scheduled_tokens, tp_size) num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else: else:
num_input_tokens = num_scheduled_tokens num_input_tokens = num_scheduled_tokens
attn_metadata.num_input_tokens = num_input_tokens
# _prepare_inputs may reorder the batch, so we must gather multi # _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order # modal outputs after that to ensure the correct order
...@@ -1088,7 +1087,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1088,7 +1087,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Run the decoder. # Run the decoder.
# Use persistent buffers for CUDA graphs. # Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata, self.vllm_config): with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
output = self.model( output = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
......
...@@ -769,7 +769,10 @@ class TPUModelRunner: ...@@ -769,7 +769,10 @@ class TPUModelRunner:
xm.mark_step() xm.mark_step()
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
# Run the decoder # Run the decoder
with set_forward_context(attn_metadata, self.vllm_config): with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=scheduler_output.total_num_scheduled_tokens):
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=self.position_ids, positions=self.position_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment