Commit 693d5ed4 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove seq_lens_tensor of v1 fa

parent e326b648
...@@ -123,7 +123,6 @@ class FlashAttentionMetadata: ...@@ -123,7 +123,6 @@ class FlashAttentionMetadata:
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
max_seq_len: int max_seq_len: int
seq_lens: torch.Tensor seq_lens: torch.Tensor
# seq_lens_tensor: torch.Tensor
block_table: torch.Tensor block_table: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
...@@ -228,7 +227,6 @@ class FlashAttentionMetadataBuilder( ...@@ -228,7 +227,6 @@ class FlashAttentionMetadataBuilder(
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
# seq_lens_tensor = common_attn_metadata.seq_lens_tensor
block_table = self.block_table block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs] block_table_tensor = block_table.get_device_tensor()[:num_reqs]
...@@ -391,7 +389,6 @@ class FlashAttentionMetadataBuilder( ...@@ -391,7 +389,6 @@ class FlashAttentionMetadataBuilder(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
seq_lens=seq_lens, seq_lens=seq_lens,
# seq_lens_tensor=seq_lens_tensor,
block_table=block_table_tensor, block_table=block_table_tensor,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
use_cascade=use_cascade, use_cascade=use_cascade,
...@@ -623,53 +620,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -623,53 +620,6 @@ class FlashAttentionImpl(AttentionImpl):
# num_splits=attn_metadata.max_num_splits, # num_splits=attn_metadata.max_num_splits,
is_prefix_cache=False, is_prefix_cache=False,
) )
# if num_actual_tokens > 1:
# vllm_flash_attn_varlen_func(
# q=query[:num_actual_tokens],
# k=key_cache,
# v=value_cache,
# out=output[:num_actual_tokens],
# cu_seqlens_q=cu_seqlens_q,
# max_seqlen_q=max_seqlen_q,
# seqused_k=seqused_k,
# max_seqlen_k=max_seqlen_k,
# softmax_scale=self.scale,
# causal=True,
# alibi_slopes=self.alibi_slopes,
# window_size=self.sliding_window,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# scheduler_metadata=scheduler_metadata,
# # fa_version=self.vllm_flash_attn_version,
# # q_descale=layer._q_scale.expand(descale_shape),
# # k_descale=layer._k_scale.expand(descale_shape),
# # v_descale=layer._v_scale.expand(descale_shape),
# # num_splits=attn_metadata.max_num_splits,
# is_prefix_cache=False,
# )
# else:
# from flash_attn import vllm_flash_attn_with_kvcache
# if envs.VLLM_USE_PA_PRINT_PARAM:
# print("PA SIZE:")
# print(f"q.shape = {query[:num_actual_tokens].unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}")
# print(f"cache_seqlens.shape = {attn_metadata.seq_lens_tensor.shape}, block_tables.shape = {block_table.shape}")
# print(f"softmax_scale = {self.scale:.3f}, window_size = {self.sliding_window}, softcap = {self.logits_soft_cap}, alibi_slopes = {self.alibi_slopes}")
# output[:num_actual_tokens] = vllm_flash_attn_with_kvcache(
# q=query[:num_actual_tokens].unsqueeze(1),
# k_cache=key_cache,
# v_cache=value_cache,
# cache_seqlens=attn_metadata.seq_lens_tensor,
# softmax_scale=self.scale,
# causal=True,
# alibi_slopes=self.alibi_slopes,
# window_size=self.sliding_window,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# # k_scale=layer._k_scale.expand(descale_shape),
# # v_scale=layer._v_scale.expand(descale_shape),
# ).squeeze(1)
return output return output
assert not use_local_attn, ( assert not use_local_attn, (
......
...@@ -4,7 +4,7 @@ import abc ...@@ -4,7 +4,7 @@ import abc
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, Optional from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
import numpy as np import numpy as np
import torch import torch
...@@ -35,8 +35,6 @@ class CommonAttentionMetadata: ...@@ -35,8 +35,6 @@ class CommonAttentionMetadata:
seq_lens: torch.Tensor seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens """(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens""" and newly scheduled tokens"""
# seq_lens_tensor: torch.Tensor
# """seq_lens stored as a tensor."""
num_reqs: int num_reqs: int
"""Number of requests""" """Number of requests"""
num_actual_tokens: int num_actual_tokens: int
......
...@@ -248,7 +248,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -248,7 +248,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens = torch.zeros(self.max_num_reqs, self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
# self.seq_lens_tensor = torch.zeros_like(self.seq_lens)
self.slot_mapping = torch.zeros(self.max_num_tokens, self.slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int64, dtype=torch.int64,
device=self.device) device=self.device)
...@@ -702,7 +701,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -702,7 +701,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc = self.query_start_loc[:num_reqs + 1] query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs] seq_lens = self.seq_lens[:num_reqs]
# seq_lens_tensor = self.seq_lens_tensor[:num_reqs]
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
...@@ -2033,7 +2031,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2033,7 +2031,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True) non_blocking=True)
seq_lens = self.seq_lens[:num_reqs] seq_lens = self.seq_lens[:num_reqs]
# seq_lens_tensor = self.seq_lens_tensor[:num_reqs]
num_speculative_tokens = 0 if self.speculative_config is None else self.speculative_config.num_lookahead_slots num_speculative_tokens = 0 if self.speculative_config is None else self.speculative_config.num_lookahead_slots
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment