Commit 1c18cce0 authored by zhuwenwen
Browse files

Fix tests and update the usage of flash attention (FA)

parent b40f2ffc
...@@ -364,6 +364,8 @@ def test_multi_query_kv_attention( ...@@ -364,6 +364,8 @@ def test_multi_query_kv_attention(
attn_bias=attn_bias, attn_bias=attn_bias,
p=0.0, p=0.0,
scale=scale, scale=scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
) )
output = output.squeeze(0) output = output.squeeze(0)
......
...@@ -3,6 +3,7 @@ from typing import Optional, Union ...@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.utils import is_hip
def seeded_uniform( def seeded_uniform(
...@@ -69,8 +70,14 @@ def seeded_uniform( ...@@ -69,8 +70,14 @@ def seeded_uniform(
# Manual tuning. This seems to give best performance on A100 for # Manual tuning. This seems to give best performance on A100 for
# simple kernels like this. # simple kernels like this.
if philox_block_size >= 8192: if philox_block_size >= 8192:
if is_hip():
num_warps = 16
else:
num_warps = 32 num_warps = 32
elif philox_block_size >= 4096: elif philox_block_size >= 4096:
if is_hip():
num_warps = 8
else:
num_warps = 16 num_warps = 16
elif philox_block_size >= 2048: elif philox_block_size >= 2048:
num_warps = 8 num_warps = 8
......
...@@ -6,6 +6,7 @@ import triton.language as tl ...@@ -6,6 +6,7 @@ import triton.language as tl
from vllm.model_executor.layers.ops.rand import seeded_uniform from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import is_hip
_EPS: tl.constexpr = 1e-6 _EPS: tl.constexpr = 1e-6
...@@ -266,8 +267,14 @@ def _sample(probs: torch.Tensor, ...@@ -266,8 +267,14 @@ def _sample(probs: torch.Tensor,
# Manual tuning. This seems to give best performance on A100 for # Manual tuning. This seems to give best performance on A100 for
# simple kernels like this. # simple kernels like this.
if block_size >= 8192: if block_size >= 8192:
if is_hip():
num_warps = 16
else:
num_warps = 32 num_warps = 32
elif block_size >= 4096: elif block_size >= 4096:
if is_hip():
num_warps = 8
else:
num_warps = 16 num_warps = 16
elif block_size >= 2048: elif block_size >= 2048:
num_warps = 8 num_warps = 8
......
...@@ -23,6 +23,7 @@ def get_model_architecture( ...@@ -23,6 +23,7 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM']
use_triton_fa_architectures = ['DeepseekV2ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
...@@ -35,6 +36,10 @@ def get_model_architecture( ...@@ -35,6 +36,10 @@ def get_model_architecture(
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0' os.environ['FA_PAD'] = '0'
if any(arch in architectures for arch in use_triton_fa_architectures):
os.environ['VLLM_USE_TRITON_FLASH_ATTN'] = '1'
os.environ['VLLM_USE_FLASH_ATTN_AUTO'] = '0'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None if (model_config.quantization is not None
......
...@@ -903,7 +903,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -903,7 +903,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
import vllm.envs as envs import vllm.envs as envs
if envs.VLLM_USE_FLASH_ATTN_AUTO: if envs.VLLM_USE_FLASH_ATTN_AUTO:
for group_id in range(1): for group_id in range(1):
if max_num_batched_tokens >= 8000:
seq_len = 8000 seq_len = 8000
else:
seq_len = max_num_batched_tokens
batch_size += seq_len batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment