"lib/vscode:/vscode.git/clone" did not exist on "84e71e27d36e3db7168e673137ac9d6d10537efe"
Commit 3a6764a4 authored by zhuwenwen
Browse files

fix fa and triton tests

parent 2dbefd03
...@@ -359,6 +359,8 @@ def test_multi_query_kv_attention( ...@@ -359,6 +359,8 @@ def test_multi_query_kv_attention(
attn_bias=attn_bias, attn_bias=attn_bias,
p=0.0, p=0.0,
scale=scale, scale=scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
) )
output = output.squeeze(0) output = output.squeeze(0)
......
...@@ -3,6 +3,7 @@ from typing import Optional, Union ...@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.utils import is_hip
def seeded_uniform( def seeded_uniform(
...@@ -69,8 +70,14 @@ def seeded_uniform( ...@@ -69,8 +70,14 @@ def seeded_uniform(
# Manual tuning. This seems to give best performance on A100 for # Manual tuning. This seems to give best performance on A100 for
# simple kernels like this. # simple kernels like this.
if philox_block_size >= 8192: if philox_block_size >= 8192:
if is_hip():
num_warps = 16
else:
num_warps = 32 num_warps = 32
elif philox_block_size >= 4096: elif philox_block_size >= 4096:
if is_hip():
num_warps = 8
else:
num_warps = 16 num_warps = 16
elif philox_block_size >= 2048: elif philox_block_size >= 2048:
num_warps = 8 num_warps = 8
......
...@@ -6,6 +6,7 @@ import triton ...@@ -6,6 +6,7 @@ import triton
import triton.language as tl import triton.language as tl
from vllm.model_executor.layers.ops.rand import seeded_uniform from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.utils import is_hip
_EPS = 1e-6 _EPS = 1e-6
...@@ -278,8 +279,14 @@ def _sample(probs: torch.Tensor, ...@@ -278,8 +279,14 @@ def _sample(probs: torch.Tensor,
# Manual tuning. This seems to give best performance on A100 for # Manual tuning. This seems to give best performance on A100 for
# simple kernels like this. # simple kernels like this.
if block_size >= 8192: if block_size >= 8192:
if is_hip():
num_warps = 16
else:
num_warps = 32 num_warps = 32
elif block_size >= 4096: elif block_size >= 4096:
if is_hip():
num_warps = 8
else:
num_warps = 16 num_warps = 16
elif block_size >= 2048: elif block_size >= 2048:
num_warps = 8 num_warps = 8
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment