Commit b40f2ffc authored by zhuwenwen's avatar zhuwenwen
Browse files

Add fa pad conditions and automatic switching strategy

parent 4d821524
......@@ -276,7 +276,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8192
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8000
self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO
if self.use_triton_flash_attn:
if self.use_flash_attn_auto:
......@@ -286,7 +286,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func_ck = flash_attn_varlen_func
logger.debug("When SEQ_LEN > 8192, Use Triton FA in ROCmBackend, otherwise Use CK FA")
logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA")
else:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
......@@ -410,7 +410,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata.seq_lens,
make_attn_mask=False) # type: ignore
if self.use_flash_attn_auto:
if prefill_meta.max_prefill_seq_len > 8192:
if prefill_meta.max_prefill_seq_len > 8000:
out = self.attn_func_triton(
q=query,
k=key,
......
......@@ -11,6 +11,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
......@@ -178,12 +179,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "False").lower() in
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
("true", "1")),
# Internal flag to enable Dynamo graph capture
......
......@@ -178,6 +178,12 @@ class BaiChuanAttention(nn.Module):
cache_config=cache_config,
quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward(
self,
positions: torch.Tensor,
......@@ -186,7 +192,7 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1':
if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
......@@ -341,6 +347,12 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......
......@@ -100,6 +100,11 @@ class GLMAttention(nn.Module):
cache_config=cache_config,
quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward(
self,
hidden_states: torch.Tensor,
......@@ -108,7 +113,7 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1':
if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
......@@ -366,6 +371,12 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......
......@@ -167,6 +167,11 @@ class LlamaAttention(nn.Module):
cache_config=cache_config,
quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward(
self,
positions: torch.Tensor,
......@@ -175,7 +180,7 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1':
if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
......@@ -417,8 +422,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
self.quant_method = None
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
......
......@@ -117,6 +117,11 @@ class QWenAttention(nn.Module):
cache_config=cache_config,
quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward(
self,
positions: torch.Tensor,
......@@ -125,7 +130,7 @@ class QWenAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
if os.environ.get('FA_PAD') == '1':
if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
......
......@@ -149,6 +149,11 @@ class Qwen2Attention(nn.Module):
cache_config=cache_config,
quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward(
self,
positions: torch.Tensor,
......@@ -157,7 +162,7 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1':
if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
......
......@@ -900,6 +900,32 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs = 1
batch_size = 0
import vllm.envs as envs
if envs.VLLM_USE_FLASH_ATTN_AUTO:
for group_id in range(1):
seq_len = 8000
batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=dummy_multi_modal_data,
)
seqs.append(seq)
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment