Commit 408f663a authored by zhuwenwen's avatar zhuwenwen
Browse files

remove the automatic switching strategy of fa

parent aa1e273a
...@@ -281,31 +281,19 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -281,31 +281,19 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = False self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8000
self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
if self.use_flash_attn_auto: # from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
from vllm.attention.ops.flash_attn_triton_mqa_gqa import ( # triton_attention)
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func) flash_attn_varlen_func)
self.attn_func_triton = flash_attn_varlen_func self.attn_func = flash_attn_varlen_func # triton_attention
logger.debug("Using Triton FA in ROCmBackend")
from flash_attn import flash_attn_varlen_func # noqa: F401 if self.sliding_window != (-1, -1):
self.attn_func_cu = flash_attn_varlen_func logger.warning("ROCm Triton FA does not currently support "
logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA") "sliding window attention. If using half "
else: "precision, please try using the ROCm CK "
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 "FA backend instead by setting the env var "
# triton_attention) "`VLLM_USE_TRITON_FLASH_ATTN=0`")
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func = flash_attn_varlen_func # triton_attention
logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support "
"sliding window attention. If using half "
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
else: else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn # if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either # either
...@@ -414,47 +402,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -414,47 +402,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query.dtype, query.dtype,
attn_metadata.seq_lens, attn_metadata.seq_lens,
make_attn_mask=False) # type: ignore make_attn_mask=False) # type: ignore
if self.use_flash_attn_auto:
if prefill_meta.max_prefill_seq_len > 8000:
out = self.attn_func_triton(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
if envs.VLLM_USE_CL_FLASH_ATTN:
out = self.attn_func_cu(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
out = self.attn_func_cu(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
# out = self.attn_func( # out = self.attn_func(
# query, # query,
# key, # key,
...@@ -466,17 +414,17 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -466,17 +414,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# self.scale, # self.scale,
# attn_masks, # attn_masks,
# ) # )
out = self.attn_func( out = self.attn_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len, max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len, max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
) )
elif self.use_naive_attn: elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
......
...@@ -202,12 +202,13 @@ def which_attn_to_use( ...@@ -202,12 +202,13 @@ def which_attn_to_use(
# AMD GPUs. # AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
# if selected_backend == _Backend.ROCM_FLASH: if selected_backend == _Backend.ROCM_FLASH:
# if current_platform.get_device_capability()[0] != 9: # if current_platform.get_device_capability()[0] != 9:
# # not Instinct series GPUs. if torch.cuda.get_device_capability()[0] != 9:
# logger.info("flash_attn is not supported on NAVI GPUs.") # not Instinct series GPUs.
# else: logger.info("flash_attn is not supported on NAVI GPUs.")
# logger.info("%s is not supported in AMD GPUs.", selected_backend) else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH return _Backend.ROCM_FLASH
# FlashAttn in NVIDIA GPUs. # FlashAttn in NVIDIA GPUs.
......
...@@ -13,7 +13,6 @@ if TYPE_CHECKING: ...@@ -13,7 +13,6 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_CL_FLASH_ATTN: bool = False VLLM_USE_CL_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
...@@ -196,17 +195,12 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -196,17 +195,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention # flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN": "VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
("true", "1")), ("true", "1")),
# flag to control if vllm should use cutlass flash attention # flag to control if vllm should use cutlass flash attention
"VLLM_USE_CL_FLASH_ATTN": "VLLM_USE_CL_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_CL_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_CL_FLASH_ATTN", "True").lower() in
("true", "1")),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
......
...@@ -23,7 +23,6 @@ def get_model_architecture( ...@@ -23,7 +23,6 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM']
use_triton_fa_architectures = ['DeepseekV2ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
...@@ -35,10 +34,6 @@ def get_model_architecture( ...@@ -35,10 +34,6 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0' os.environ['FA_PAD'] = '0'
if any(arch in architectures for arch in use_triton_fa_architectures):
os.environ['VLLM_USE_TRITON_FLASH_ATTN'] = '1'
os.environ['VLLM_USE_FLASH_ATTN_AUTO'] = '0'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
......
...@@ -1179,33 +1179,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1179,33 +1179,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs = 1 max_num_seqs = 1
batch_size = 0 batch_size = 0
import vllm.envs as envs
if envs.VLLM_USE_FLASH_ATTN_AUTO:
for group_id in range(1):
if max_num_batched_tokens >= 8000:
seq_len = 8000
else:
seq_len = max_num_batched_tokens
batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=dummy_multi_modal_data,
)
seqs.append(seq)
max_num_batched_tokens -= seq_len
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment