Commit b40f2ffc authored by zhuwenwen
Browse files

Add FA_PAD conditions and an automatic Triton/CK flash-attention switching strategy

parent 4d821524
...@@ -276,7 +276,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -276,7 +276,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = False self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8192 # NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8000
self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
if self.use_flash_attn_auto: if self.use_flash_attn_auto:
...@@ -286,7 +286,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -286,7 +286,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func_ck = flash_attn_varlen_func self.attn_func_ck = flash_attn_varlen_func
logger.debug("When SEQ_LEN > 8192, Use Triton FA in ROCmBackend, otherwise Use CK FA") logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA")
else: else:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 # from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention) # triton_attention)
...@@ -410,7 +410,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -410,7 +410,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata.seq_lens, attn_metadata.seq_lens,
make_attn_mask=False) # type: ignore make_attn_mask=False) # type: ignore
if self.use_flash_attn_auto: if self.use_flash_attn_auto:
if prefill_meta.max_prefill_seq_len > 8192: if prefill_meta.max_prefill_seq_len > 8000:
out = self.attn_func_triton( out = self.attn_func_triton(
q=query, q=query,
k=key, k=key,
......
...@@ -11,6 +11,7 @@ if TYPE_CHECKING: ...@@ -11,6 +11,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
...@@ -178,12 +179,12 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -178,12 +179,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention # flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN": "VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control vllm to automatically switch between Triton FA and CK FA # flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO": "VLLM_USE_FLASH_ATTN_AUTO":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "False").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
("true", "1")), ("true", "1")),
# Internal flag to enable Dynamo graph capture # Internal flag to enable Dynamo graph capture
......
...@@ -177,6 +177,12 @@ class BaiChuanAttention(nn.Module): ...@@ -177,6 +177,12 @@ class BaiChuanAttention(nn.Module):
self.scaling, self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -186,7 +192,7 @@ class BaiChuanAttention(nn.Module): ...@@ -186,7 +192,7 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states) qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI": if self.postion_embedding != "ALIBI":
...@@ -341,6 +347,12 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -341,6 +347,12 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
quant_config=quant_config) quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......
...@@ -99,6 +99,11 @@ class GLMAttention(nn.Module): ...@@ -99,6 +99,11 @@ class GLMAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -108,7 +113,7 @@ class GLMAttention(nn.Module): ...@@ -108,7 +113,7 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
...@@ -366,6 +371,12 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -366,6 +371,12 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = self.transformer.output_layer self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......
...@@ -166,6 +166,11 @@ class LlamaAttention(nn.Module): ...@@ -166,6 +166,11 @@ class LlamaAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -175,7 +180,7 @@ class LlamaAttention(nn.Module): ...@@ -175,7 +180,7 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
...@@ -417,8 +422,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -417,8 +422,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.sampler = Sampler() self.sampler = Sampler()
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.quant_method = None
self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
......
...@@ -116,6 +116,11 @@ class QWenAttention(nn.Module): ...@@ -116,6 +116,11 @@ class QWenAttention(nn.Module):
self.scaling, self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -125,7 +130,7 @@ class QWenAttention(nn.Module): ...@@ -125,7 +130,7 @@ class QWenAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
...@@ -262,7 +267,7 @@ class QWenLMHeadModel(nn.Module): ...@@ -262,7 +267,7 @@ class QWenLMHeadModel(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
......
...@@ -148,6 +148,11 @@ class Qwen2Attention(nn.Module): ...@@ -148,6 +148,11 @@ class Qwen2Attention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -157,7 +162,7 @@ class Qwen2Attention(nn.Module): ...@@ -157,7 +162,7 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
...@@ -356,7 +361,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -356,7 +361,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
......
...@@ -900,6 +900,32 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -900,6 +900,32 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs = 1 max_num_seqs = 1
batch_size = 0 batch_size = 0
import vllm.envs as envs
if envs.VLLM_USE_FLASH_ATTN_AUTO:
for group_id in range(1):
seq_len = 8000
batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=dummy_multi_modal_data,
)
seqs.append(seq)
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment