Commit fa973559 authored by zhuwenwen's avatar zhuwenwen
Browse files

Add fa pad conditions and automatic switching strategy

parent a528f350
...@@ -228,7 +228,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -228,7 +228,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = False self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8192 # NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen >= 8000
self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
if self.use_flash_attn_auto: if self.use_flash_attn_auto:
...@@ -340,7 +340,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -340,7 +340,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# prompt, and they have the same length. # prompt, and they have the same length.
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
if self.use_flash_attn_auto: if self.use_flash_attn_auto:
if prefill_meta.max_prefill_seq_len > 8192: if prefill_meta.max_prefill_seq_len >= 8000:
out = self.attn_func_triton( out = self.attn_func_triton(
q=query, q=query,
k=key, k=key,
......
...@@ -9,6 +9,7 @@ if TYPE_CHECKING: ...@@ -9,6 +9,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
...@@ -130,12 +131,12 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -130,12 +131,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention # flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN": "VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control vllm to automatically switch between Triton FA and CK FA # flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO": "VLLM_USE_FLASH_ATTN_AUTO":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "False").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
("true", "1")), ("true", "1")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
......
...@@ -174,6 +174,11 @@ class BaiChuanAttention(nn.Module): ...@@ -174,6 +174,11 @@ class BaiChuanAttention(nn.Module):
self.scaling, self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -183,7 +188,7 @@ class BaiChuanAttention(nn.Module): ...@@ -183,7 +188,7 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states) qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI": if self.postion_embedding != "ALIBI":
...@@ -333,6 +338,12 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -333,6 +338,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......
...@@ -97,6 +97,11 @@ class GLMAttention(nn.Module): ...@@ -97,6 +97,11 @@ class GLMAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -106,7 +111,7 @@ class GLMAttention(nn.Module): ...@@ -106,7 +111,7 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
...@@ -360,6 +365,12 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -360,6 +365,12 @@ class ChatGLMForCausalLM(nn.Module):
self.lm_head_weight = self.transformer.output_layer.weight self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......
...@@ -152,6 +152,11 @@ class LlamaAttention(nn.Module): ...@@ -152,6 +152,11 @@ class LlamaAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -161,7 +166,7 @@ class LlamaAttention(nn.Module): ...@@ -161,7 +166,7 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
...@@ -367,8 +372,8 @@ class LlamaForCausalLM(nn.Module): ...@@ -367,8 +372,8 @@ class LlamaForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
......
...@@ -114,6 +114,11 @@ class QWenAttention(nn.Module): ...@@ -114,6 +114,11 @@ class QWenAttention(nn.Module):
self.scaling, self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -123,7 +128,7 @@ class QWenAttention(nn.Module): ...@@ -123,7 +128,7 @@ class QWenAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
...@@ -246,7 +251,7 @@ class QWenLMHeadModel(nn.Module): ...@@ -246,7 +251,7 @@ class QWenLMHeadModel(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
......
...@@ -144,6 +144,11 @@ class Qwen2Attention(nn.Module): ...@@ -144,6 +144,11 @@ class Qwen2Attention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -153,7 +158,7 @@ class Qwen2Attention(nn.Module): ...@@ -153,7 +158,7 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
...@@ -327,7 +332,7 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -327,7 +332,7 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
......
...@@ -804,6 +804,30 @@ class ModelRunner: ...@@ -804,6 +804,30 @@ class ModelRunner:
max_num_seqs = min( max_num_seqs = min(
max_num_seqs, max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size)) int(max_num_batched_tokens / vlm_config.image_feature_size))
import vllm.envs as envs
if envs.VLLM_USE_FLASH_ATTN_AUTO:
for group_id in range(1):
seq_len = 8000
if vlm_config is None:
seq_data = SequenceData([0] * seq_len)
dummy_multi_modal_data = None
else:
seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \
.dummy_data_for_profiling(seq_len, model_config, vlm_config)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=dummy_multi_modal_data,
)
seqs.append(seq)
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment