Commit 1c18cce0 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix tests and update the usage of fa

parent b40f2ffc
......@@ -364,6 +364,8 @@ def test_multi_query_kv_attention(
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = output.squeeze(0)
......
......@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch
import triton
import triton.language as tl
from vllm.utils import is_hip
def seeded_uniform(
......@@ -69,9 +70,15 @@ def seeded_uniform(
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if philox_block_size >= 8192:
num_warps = 32
if is_hip():
num_warps = 16
else:
num_warps = 32
elif philox_block_size >= 4096:
num_warps = 16
if is_hip():
num_warps = 8
else:
num_warps = 16
elif philox_block_size >= 2048:
num_warps = 8
......
......@@ -6,6 +6,7 @@ import triton.language as tl
from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import is_hip
_EPS: tl.constexpr = 1e-6
......@@ -266,9 +267,15 @@ def _sample(probs: torch.Tensor,
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if block_size >= 8192:
num_warps = 32
if is_hip():
num_warps = 16
else:
num_warps = 32
elif block_size >= 4096:
num_warps = 16
if is_hip():
num_warps = 8
else:
num_warps = 16
elif block_size >= 2048:
num_warps = 8
......
......@@ -23,6 +23,7 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM']
use_triton_fa_architectures = ['DeepseekV2ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1'
......@@ -34,6 +35,10 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0'
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
if any(arch in architectures for arch in use_triton_fa_architectures):
os.environ['VLLM_USE_TRITON_FLASH_ATTN'] = '1'
os.environ['VLLM_USE_FLASH_ATTN_AUTO'] = '0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
......
......@@ -903,7 +903,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
import vllm.envs as envs
if envs.VLLM_USE_FLASH_ATTN_AUTO:
for group_id in range(1):
seq_len = 8000
if max_num_batched_tokens >= 8000:
seq_len = 8000
else:
seq_len = max_num_batched_tokens
batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment