Commit 408f663a authored by zhuwenwen's avatar zhuwenwen
Browse files

remove the automatic switching strategy of fa

parent aa1e273a
......@@ -281,18 +281,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8000
self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO
if self.use_triton_flash_attn:
if self.use_flash_attn_auto:
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func_triton = flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func_cu = flash_attn_varlen_func
logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA")
else:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
......@@ -305,7 +294,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
......@@ -414,47 +402,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query.dtype,
attn_metadata.seq_lens,
make_attn_mask=False) # type: ignore
if self.use_flash_attn_auto:
if prefill_meta.max_prefill_seq_len > 8000:
out = self.attn_func_triton(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
if envs.VLLM_USE_CL_FLASH_ATTN:
out = self.attn_func_cu(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
out = self.attn_func_cu(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
# out = self.attn_func(
# query,
# key,
......
......@@ -202,12 +202,13 @@ def which_attn_to_use(
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
# if selected_backend == _Backend.ROCM_FLASH:
if selected_backend == _Backend.ROCM_FLASH:
# if current_platform.get_device_capability()[0] != 9:
# # not Instinct series GPUs.
# logger.info("flash_attn is not supported on NAVI GPUs.")
# else:
# logger.info("%s is not supported in AMD GPUs.", selected_backend)
if torch.cuda.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH
# FlashAttn in NVIDIA GPUs.
......
......@@ -13,7 +13,6 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_CL_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0
......@@ -196,17 +195,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control if vllm should use cutlass flash attention
"VLLM_USE_CL_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_CL_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
lambda: (os.environ.get("VLLM_USE_CL_FLASH_ATTN", "True").lower() in
("true", "1")),
# flag to control vllm to use optimized kernels
......
......@@ -23,7 +23,6 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM']
use_triton_fa_architectures = ['DeepseekV2ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1'
......@@ -36,10 +35,6 @@ def get_model_architecture(
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
if any(arch in architectures for arch in use_triton_fa_architectures):
os.environ['VLLM_USE_TRITON_FLASH_ATTN'] = '1'
os.environ['VLLM_USE_FLASH_ATTN_AUTO'] = '0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors"]
......
......@@ -1179,33 +1179,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs = 1
batch_size = 0
import vllm.envs as envs
if envs.VLLM_USE_FLASH_ATTN_AUTO:
for group_id in range(1):
if max_num_batched_tokens >= 8000:
seq_len = 8000
else:
seq_len = max_num_batched_tokens
batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=dummy_multi_modal_data,
)
seqs.append(seq)
max_num_batched_tokens -= seq_len
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment