Commit 8a1e7a3d authored by zhuwenwen
Browse files

Add VLLM_USE_PA_PRINT_PARAM to print the V1 flash-attention (FA) tensor sizes and call parameters

parent 74a444b5
...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt1.beta.' + sha[:7] version = 'das.opt1.rc1.' + sha[:7]
else: else:
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt1.beta' version = 'das.opt1.rc1'
# dtk version # dtk version
......
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Optional ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Optional
import numpy as np import numpy as np
import torch import torch
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType, AttentionMetadata, AttentionType,
...@@ -122,6 +123,7 @@ class FlashAttentionMetadata: ...@@ -122,6 +123,7 @@ class FlashAttentionMetadata:
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
max_seq_len: int max_seq_len: int
seq_lens: torch.Tensor seq_lens: torch.Tensor
# seq_lens_tensor: torch.Tensor
block_table: torch.Tensor block_table: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
...@@ -226,6 +228,7 @@ class FlashAttentionMetadataBuilder( ...@@ -226,6 +228,7 @@ class FlashAttentionMetadataBuilder(
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
# seq_lens_tensor = common_attn_metadata.seq_lens_tensor
block_table = self.block_table block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs] block_table_tensor = block_table.get_device_tensor()[:num_reqs]
...@@ -388,6 +391,7 @@ class FlashAttentionMetadataBuilder( ...@@ -388,6 +391,7 @@ class FlashAttentionMetadataBuilder(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
seq_lens=seq_lens, seq_lens=seq_lens,
# seq_lens_tensor=seq_lens_tensor,
block_table=block_table_tensor, block_table=block_table_tensor,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
use_cascade=use_cascade, use_cascade=use_cascade,
...@@ -590,6 +594,12 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -590,6 +594,12 @@ class FlashAttentionImpl(AttentionImpl):
num_splits=attn_metadata.max_num_splits, num_splits=attn_metadata.max_num_splits,
) )
else: else:
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:")
print(f"q.shape = {query[:num_actual_tokens].unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
vllm_flash_attn_varlen_func( vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens], q=query[:num_actual_tokens],
k=key_cache, k=key_cache,
...@@ -613,6 +623,53 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -613,6 +623,53 @@ class FlashAttentionImpl(AttentionImpl):
# num_splits=attn_metadata.max_num_splits, # num_splits=attn_metadata.max_num_splits,
is_prefix_cache=False, is_prefix_cache=False,
) )
# if num_actual_tokens > 1:
# vllm_flash_attn_varlen_func(
# q=query[:num_actual_tokens],
# k=key_cache,
# v=value_cache,
# out=output[:num_actual_tokens],
# cu_seqlens_q=cu_seqlens_q,
# max_seqlen_q=max_seqlen_q,
# seqused_k=seqused_k,
# max_seqlen_k=max_seqlen_k,
# softmax_scale=self.scale,
# causal=True,
# alibi_slopes=self.alibi_slopes,
# window_size=self.sliding_window,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# scheduler_metadata=scheduler_metadata,
# # fa_version=self.vllm_flash_attn_version,
# # q_descale=layer._q_scale.expand(descale_shape),
# # k_descale=layer._k_scale.expand(descale_shape),
# # v_descale=layer._v_scale.expand(descale_shape),
# # num_splits=attn_metadata.max_num_splits,
# is_prefix_cache=False,
# )
# else:
# from flash_attn import vllm_flash_attn_with_kvcache
# if envs.VLLM_USE_PA_PRINT_PARAM:
# print("PA SIZE:")
# print(f"q.shape = {query[:num_actual_tokens].unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}")
# print(f"cache_seqlens.shape = {attn_metadata.seq_lens_tensor.shape}, block_tables.shape = {block_table.shape}")
# print(f"softmax_scale = {self.scale:.3f}, window_size = {self.sliding_window}, softcap = {self.logits_soft_cap}, alibi_slopes = {self.alibi_slopes}")
# output[:num_actual_tokens] = vllm_flash_attn_with_kvcache(
# q=query[:num_actual_tokens].unsqueeze(1),
# k_cache=key_cache,
# v_cache=value_cache,
# cache_seqlens=attn_metadata.seq_lens_tensor,
# softmax_scale=self.scale,
# causal=True,
# alibi_slopes=self.alibi_slopes,
# window_size=self.sliding_window,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# # k_scale=layer._k_scale.expand(descale_shape),
# # v_scale=layer._v_scale.expand(descale_shape),
# ).squeeze(1)
return output return output
assert not use_local_attn, ( assert not use_local_attn, (
......
...@@ -4,7 +4,7 @@ import abc ...@@ -4,7 +4,7 @@ import abc
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -35,6 +35,8 @@ class CommonAttentionMetadata: ...@@ -35,6 +35,8 @@ class CommonAttentionMetadata:
seq_lens: torch.Tensor seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens """(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens""" and newly scheduled tokens"""
# seq_lens_tensor: torch.Tensor
# """seq_lens stored as a tensor."""
num_reqs: int num_reqs: int
"""Number of requests""" """Number of requests"""
num_actual_tokens: int num_actual_tokens: int
......
...@@ -247,6 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -247,6 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens = torch.zeros(self.max_num_reqs, self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
# self.seq_lens_tensor = torch.zeros_like(self.seq_lens)
self.slot_mapping = torch.zeros(self.max_num_tokens, self.slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int64, dtype=torch.int64,
device=self.device) device=self.device)
...@@ -700,10 +701,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -700,10 +701,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc = self.query_start_loc[:num_reqs + 1] query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs] seq_lens = self.seq_lens[:num_reqs]
# seq_lens_tensor = self.seq_lens_tensor[:num_reqs]
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
seq_lens=seq_lens, seq_lens=seq_lens,
# seq_lens_tensor=seq_lens_tensor,
num_reqs=num_reqs, num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens, num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens,
...@@ -2018,11 +2021,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2018,11 +2021,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True) non_blocking=True)
seq_lens = self.seq_lens[:num_reqs] seq_lens = self.seq_lens[:num_reqs]
# seq_lens_tensor = self.seq_lens_tensor[:num_reqs]
num_speculative_tokens = 0 if self.speculative_config is None else self.speculative_config.num_lookahead_slots num_speculative_tokens = 0 if self.speculative_config is None else self.speculative_config.num_lookahead_slots
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
seq_lens=seq_lens, seq_lens=seq_lens,
# seq_lens_tensor=seq_lens_tensor,
num_reqs=num_reqs, num_reqs=num_reqs,
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
max_query_len=num_tokens, max_query_len=num_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment