Commit 8a1e7a3d authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_PA_PRINT_PARAM to print v1 fa size

parent 74a444b5
......@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'):
version = 'das.opt1.beta.' + sha[:7]
version = 'das.opt1.rc1.' + sha[:7]
else:
if (major, minor) >= ('2', '5'):
version = 'das.opt1.beta'
version = 'das.opt1.rc1'
# dtk version
......
......@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Optional
import numpy as np
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
......@@ -122,6 +123,7 @@ class FlashAttentionMetadata:
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
# seq_lens_tensor: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
......@@ -226,6 +228,7 @@ class FlashAttentionMetadataBuilder(
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
# seq_lens_tensor = common_attn_metadata.seq_lens_tensor
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
......@@ -388,6 +391,7 @@ class FlashAttentionMetadataBuilder(
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
# seq_lens_tensor=seq_lens_tensor,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
......@@ -590,6 +594,12 @@ class FlashAttentionImpl(AttentionImpl):
num_splits=attn_metadata.max_num_splits,
)
else:
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:")
print(f"q.shape = {query[:num_actual_tokens].unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
......@@ -613,6 +623,53 @@ class FlashAttentionImpl(AttentionImpl):
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache=False,
)
# if num_actual_tokens > 1:
# vllm_flash_attn_varlen_func(
# q=query[:num_actual_tokens],
# k=key_cache,
# v=value_cache,
# out=output[:num_actual_tokens],
# cu_seqlens_q=cu_seqlens_q,
# max_seqlen_q=max_seqlen_q,
# seqused_k=seqused_k,
# max_seqlen_k=max_seqlen_k,
# softmax_scale=self.scale,
# causal=True,
# alibi_slopes=self.alibi_slopes,
# window_size=self.sliding_window,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# scheduler_metadata=scheduler_metadata,
# # fa_version=self.vllm_flash_attn_version,
# # q_descale=layer._q_scale.expand(descale_shape),
# # k_descale=layer._k_scale.expand(descale_shape),
# # v_descale=layer._v_scale.expand(descale_shape),
# # num_splits=attn_metadata.max_num_splits,
# is_prefix_cache=False,
# )
# else:
# from flash_attn import vllm_flash_attn_with_kvcache
# if envs.VLLM_USE_PA_PRINT_PARAM:
# print("PA SIZE:")
# print(f"q.shape = {query[:num_actual_tokens].unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}")
# print(f"cache_seqlens.shape = {attn_metadata.seq_lens_tensor.shape}, block_tables.shape = {block_table.shape}")
# print(f"softmax_scale = {self.scale:.3f}, window_size = {self.sliding_window}, softcap = {self.logits_soft_cap}, alibi_slopes = {self.alibi_slopes}")
# output[:num_actual_tokens] = vllm_flash_attn_with_kvcache(
# q=query[:num_actual_tokens].unsqueeze(1),
# k_cache=key_cache,
# v_cache=value_cache,
# cache_seqlens=attn_metadata.seq_lens_tensor,
# softmax_scale=self.scale,
# causal=True,
# alibi_slopes=self.alibi_slopes,
# window_size=self.sliding_window,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# # k_scale=layer._k_scale.expand(descale_shape),
# # v_scale=layer._v_scale.expand(descale_shape),
# ).squeeze(1)
return output
assert not use_local_attn, (
......
......@@ -4,7 +4,7 @@ import abc
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, Optional
import numpy as np
import torch
......@@ -35,6 +35,8 @@ class CommonAttentionMetadata:
seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
# seq_lens_tensor: torch.Tensor
# """seq_lens stored as a tensor."""
num_reqs: int
"""Number of requests"""
num_actual_tokens: int
......
......@@ -247,6 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
# self.seq_lens_tensor = torch.zeros_like(self.seq_lens)
self.slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
......@@ -700,10 +701,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
# seq_lens_tensor = self.seq_lens_tensor[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
seq_lens=seq_lens,
# seq_lens_tensor=seq_lens_tensor,
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
......@@ -2018,11 +2021,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
seq_lens = self.seq_lens[:num_reqs]
# seq_lens_tensor = self.seq_lens_tensor[:num_reqs]
num_speculative_tokens = 0 if self.speculative_config is None else self.speculative_config.num_lookahead_slots
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
seq_lens=seq_lens,
# seq_lens_tensor=seq_lens_tensor,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment