Unverified Commit 25caa7a8 authored by jacky.cheng, committed by GitHub

[AMD] Support Wave attention backend with AMD GPU optimizations (#8660)


Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Co-authored-by: Harsh Menon <harsh@nod-labs.com>
Co-authored-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Co-authored-by: Stanley Winata <stanley@nod-labs.com>
Co-authored-by: Ivan Butygin <ivan.butygin@gmail.com>
Co-authored-by: nithinsubbiah <nithinsubbiah@gmail.com>
Co-authored-by: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com>
Co-authored-by: Ivan Butygin <ibutygin@amd.com>
parent 03d11449
@@ -14,6 +14,7 @@ You can test them according to your needs.
| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |
| **TRTLLM MLA** | ✅ | ❌ | ✅ | ✅ | ❌ |
| **Ascend** | ✅ | ❌ | ✅ | ❌ | ❌ |
| **Wave** | ✅ | ❌ | ❌ | ❌ | ❌ |
**Notes:**
- TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to FlashInfer MLA backend.
@@ -70,6 +71,10 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
```
- Wave
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend wave
```
## Steps to add a new attention backend
To add a new attention backend, you can learn from the existing backends
...
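For orientation, a minimal backend skeleton might look like the sketch below. This is a hypothetical outline only: the base-class import path and the method signatures are assumptions patterned on how the existing backends (e.g. AiterAttnBackend, AscendAttnBackend, and the WaveAttnBackend added in this PR) are wired into ModelRunner, not a verified interface.

```python
# Hypothetical skeleton only; names and signatures are assumptions based on the
# pattern of existing backends, not a verified interface.
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend


class MyAttnBackend(AttentionBackend):
    def __init__(self, model_runner):
        super().__init__()
        self.model_runner = model_runner

    def init_forward_metadata(self, forward_batch):
        # Precompute per-batch metadata (e.g. kv_indptr / kv_indices) reused by the kernels.
        raise NotImplementedError

    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # Write k/v into the KV cache, then run the prefill/extend kernel.
        raise NotImplementedError

    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # Run the single-token decode kernel against the cached KV.
        raise NotImplementedError
```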
@@ -82,6 +82,7 @@ srt_hip = [
"sglang[runtime_common]",
"torch",
"petit_kernel==0.0.2",
"wave-lang==1.0.1",
]
# CPU: torch wheel for CPU needs to be installed from https://download.pytorch.org/whl/cpu
...
This diff is collapsed.
"""
Memory-efficient attention for decoding.
It supports page size = 1.
"""
import functools
import logging
import os
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.constraints import GenericDot, MMAOperand, MMAType
from wave_lang.kernel.wave.templates.paged_decode_attention import (
get_paged_decode_attention_kernels,
get_paged_decode_intermediate_arrays_shapes,
paged_decode_attention_shape,
)
from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
logger = logging.getLogger(__name__)

dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
@functools.lru_cache(maxsize=4096)
def get_wave_kernel(
shape: paged_decode_attention_shape,
max_kv_splits,
input_dtype,
output_dtype,
logit_cap,
):
mha = (shape.num_query_heads // shape.num_kv_heads) == 1
# Get the kernels (either compile or load from cache).
if mha:
mfma_variant = (
GenericDot(along_dim=MMAOperand.M, k_vec_size=4, k_mult=1),
GenericDot(along_dim=MMAOperand.M, k_vec_size=1, k_mult=64),
)
else:
mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
(
phase_0,
phase_1,
hyperparams_0,
hyperparams_1,
dynamic_symbols_0,
dynamic_symbols_1,
) = get_paged_decode_attention_kernels(
shape,
mfma_variant,
max_kv_splits,
input_dtype=input_dtype,
output_dtype=output_dtype,
logit_cap=logit_cap,
)
hyperparams_0.update(get_default_scheduling_params())
hyperparams_1.update(get_default_scheduling_params())
options = WaveCompileOptions(
subs=hyperparams_0,
canonicalize=True,
run_bench=False,
use_buffer_load_ops=True,
use_buffer_store_ops=True,
waves_per_eu=2,
dynamic_symbols=dynamic_symbols_0,
wave_runtime=True,
)
options = set_default_run_config(options)
phase_0 = wave_compile(options, phase_0)
options = WaveCompileOptions(
subs=hyperparams_1,
canonicalize=True,
run_bench=False,
use_buffer_load_ops=False,
use_buffer_store_ops=False,
waves_per_eu=4,
dynamic_symbols=dynamic_symbols_1,
wave_runtime=True,
)
options = set_default_run_config(options)
phase_1 = wave_compile(options, phase_1)
return phase_0, phase_1
def decode_attention_intermediate_arrays_shapes(
num_seqs, head_size_kv, num_query_heads, max_kv_splits
):
# Not all fields are used, but we need to pass them to the function
shape = paged_decode_attention_shape(
num_query_heads=num_query_heads,
num_kv_heads=0,
head_size=0,
head_size_kv=head_size_kv,
block_size=0,
num_seqs=num_seqs,
)
return get_paged_decode_intermediate_arrays_shapes(shape, max_kv_splits)
def decode_attention_wave(
q,
k_buffer,
v_buffer,
o,
b_req_idx,
req_to_token,
attn_logits,
attn_logits_max,
num_kv_splits,
max_kv_splits,
sm_scale,
logit_cap,
):
num_seqs, num_query_heads, head_size = q.shape
_, num_kv_heads, _ = k_buffer.shape
_, _, head_size_kv = v_buffer.shape
block_size = 32  # fixed block size passed to the paged decode attention shape
shape = paged_decode_attention_shape(
num_query_heads,
num_kv_heads,
head_size,
head_size_kv,
block_size,
num_seqs,
)
phase_0, phase_1 = get_wave_kernel(
shape, max_kv_splits, q.dtype, o.dtype, logit_cap
)
mb_qk = phase_0(
q,
k_buffer,
v_buffer,
b_req_idx,
req_to_token,
attn_logits,
attn_logits_max,
)
if dump_generated_mlir:
filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb_qk.module_op.get_asm())
mb_sv = phase_1(attn_logits, attn_logits_max, b_req_idx, o)
if dump_generated_mlir:
filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb_sv.module_op.get_asm())
def decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
b_req_idx,
req_to_token,
attn_logits,
attn_logits_max,
num_kv_splits,
max_kv_splits,
sm_scale,
logit_cap=0.0,
):
decode_attention_wave(
q,
k_buffer,
v_buffer,
o,
b_req_idx,
req_to_token,
attn_logits,
attn_logits_max,
num_kv_splits,
max_kv_splits,
sm_scale,
logit_cap,
)
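For reference, a minimal invocation of this module looks like the sketch below. It mirrors the grouped-decode unit test added later in this PR; the concrete shapes, dtypes, and index layout are illustrative assumptions rather than the only supported configuration.

```python
# Illustrative sketch only: shapes, dtypes, and index layouts mirror the kernel test
# in this PR and are not the only supported configuration.
import torch

from sglang.srt.layers.attention.wave_ops.decode_attention import (
    decode_attention_fwd,
    decode_attention_intermediate_arrays_shapes,
)

B, H_Q, H_KV, D, D_V, seq_len, max_kv_splits = 2, 16, 1, 64, 64, 128, 8
dtype, dev = torch.float16, "cuda"

q = torch.randn(B, H_Q, D, dtype=dtype, device=dev)  # one new token per sequence
k_buffer = torch.randn(B * seq_len, H_KV, D, dtype=dtype, device=dev)
v_buffer = torch.randn(B * seq_len, H_KV, D_V, dtype=dtype, device=dev)
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=dev)

req_to_token = torch.arange(B * seq_len, dtype=torch.int32, device=dev)
kv_indptr = torch.arange(0, (B + 1) * seq_len, seq_len, dtype=torch.int32, device=dev)
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=dev)

# The module exposes a helper to size the split-KV intermediate buffers.
logits_shape, logits_max_shape = decode_attention_intermediate_arrays_shapes(
    B, D_V, H_Q, max_kv_splits
)
attn_logits = torch.empty(logits_shape, dtype=torch.float32, device=dev)
attn_logits_max = torch.empty(logits_max_shape, dtype=torch.float32, device=dev)

decode_attention_fwd(
    q, k_buffer, v_buffer, o,
    kv_indptr, req_to_token,
    attn_logits, attn_logits_max,
    num_kv_splits, max_kv_splits,
    sm_scale=1.0 / D**0.5,
)
```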
"""
Memory-efficient attention for extend (prefill with a cached KV prefix).
It supports page size = 1.
"""
import functools
import os
import torch
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.constraints import MMAType
from wave_lang.kernel.wave.scheduling.schedule import SchedulingType
from wave_lang.kernel.wave.templates.attention_common import AttentionShape
from wave_lang.kernel.wave.templates.extend_attention import get_extend_attention_kernel
from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
@functools.lru_cache
def get_wave_kernel(
shape: AttentionShape,
q_shape: tuple[int],
k_shape: tuple[int],
v_shape: tuple[int],
k_cache_shape: tuple[int],
v_cache_shape: tuple[int],
o_shape: tuple[int],
input_dtype: torch.dtype,
output_dtype: torch.dtype,
size_dtype: torch.dtype,
is_causal: bool,
logit_cap: float,
layer_scaling: float,
):
assert shape.num_query_heads % shape.num_kv_heads == 0
mfma_variant = (MMAType.F32_16x16x32_K8_F16, MMAType.F32_16x16x16_F16)
(
extend_attention,
hyperparams,
dynamic_symbols,
) = get_extend_attention_kernel(
shape,
mfma_variant,
q_shape,
k_shape,
v_shape,
k_cache_shape,
v_cache_shape,
o_shape,
input_dtype=input_dtype,
output_dtype=output_dtype,
size_dtype=size_dtype,
is_causal=is_causal,
layer_scaling=layer_scaling,
logit_cap=logit_cap,
)
hyperparams.update(get_default_scheduling_params())
options = WaveCompileOptions(
subs=hyperparams,
canonicalize=True,
run_bench=False,
schedule=SchedulingType.NONE,
use_scheduling_barriers=False,
dynamic_symbols=dynamic_symbols,
use_buffer_load_ops=True,
use_buffer_store_ops=True,
waves_per_eu=2,
denorm_fp_math_f32="preserve-sign",
gpu_native_math_precision=True,
wave_runtime=True,
)
options = set_default_run_config(options)
extend_attention = wave_compile(options, extend_attention)
return extend_attention
def extend_attention_wave(
q_extend,
k_extend,
v_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_indptr,
max_seq_len,
output,
is_causal=True,
layer_scaling=None,
logit_cap=0,
):
shape = AttentionShape(
num_query_heads=q_extend.shape[1],
num_kv_heads=k_extend.shape[1],
head_size=q_extend.shape[2],
head_size_kv=k_extend.shape[2],
num_seqs=kv_indptr.shape[0] - 1,
max_seq_len=max_seq_len,
)
# Run the wave kernel.
extend_attention = get_wave_kernel(
shape,
q_extend.shape,
k_extend.shape,
v_extend.shape,
k_buffer.shape,
v_buffer.shape,
output.shape,
input_dtype=q_extend.dtype,
output_dtype=output.dtype,
size_dtype=qo_indptr.dtype,
is_causal=is_causal,
layer_scaling=layer_scaling,
logit_cap=logit_cap,
)
mb = extend_attention(
q_extend,
k_extend,
v_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
max_seq_len,
output,
)
if dump_generated_mlir:
shape_list = [
q_extend.shape[0],
q_extend.shape[1],
k_extend.shape[1],
q_extend.shape[2],
k_extend.shape[2],
]
filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
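The extend path is exercised end to end by the kernel test later in this PR; the sketch below condenses that setup into a minimal call. The sizes are illustrative assumptions; the layout (page size 1, ragged sequences described by indptr/indices tensors) follows the test.

```python
# Illustrative sketch only: sizes are assumptions; the ragged page-size-1 layout
# mirrors the kernel test in this PR.
import torch

from sglang.srt.layers.attention.wave_ops.extend_attention import extend_attention_wave

B, H_Q, H_KV, D = 2, 6, 1, 128
prefix_len, extend_len = 512, 128
dtype, dev = torch.float16, "cuda"

total_per_seq = prefix_len + extend_len
k_buffer = torch.randn(B * total_per_seq, H_KV, D, dtype=dtype, device=dev)
v_buffer = torch.randn(B * total_per_seq, H_KV, D, dtype=dtype, device=dev)
q_extend = torch.randn(B * extend_len, H_Q, D, dtype=dtype, device=dev)
k_extend = torch.randn(B * extend_len, H_KV, D, dtype=dtype, device=dev)
v_extend = torch.randn(B * extend_len, H_KV, D, dtype=dtype, device=dev)
output = torch.empty(B * extend_len, H_Q, D, dtype=dtype, device=dev)

# qo_indptr: offsets of each sequence's new tokens; kv_indptr/kv_indices: cached prefix.
qo_indptr = torch.arange(0, (B + 1) * extend_len, extend_len, dtype=torch.int32, device=dev)
kv_indptr = torch.arange(0, (B + 1) * prefix_len, prefix_len, dtype=torch.int32, device=dev)
kv_indices = torch.cat(
    [torch.arange(i * total_per_seq, i * total_per_seq + prefix_len) for i in range(B)]
).to(dtype=torch.int32, device=dev)

extend_attention_wave(
    q_extend, k_extend, v_extend, k_buffer, v_buffer,
    qo_indptr, kv_indptr, kv_indices,
    custom_mask=None, mask_indptr=None,
    max_seq_len=extend_len, output=output, is_causal=True,
)
```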
"""
Memory-efficient attention for prefill.
It supports page size = 1.
"""
import math
import os
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.constraints import MMAType
from wave_lang.kernel.wave.templates.attention_common import AttentionShape
from wave_lang.kernel.wave.templates.prefill_attention import (
get_prefill_attention_kernel,
)
from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
def prefill_attention_wave(
q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True
):
shape = AttentionShape(
num_query_heads=q.shape[1],
num_kv_heads=k.shape[1],
head_size=q.shape[2],
head_size_kv=k.shape[2],
num_seqs=b_seq_len.shape[0],
max_seq_len=max_seq_len,
total_seq_len=q.shape[0],
)
assert shape.num_query_heads % shape.num_kv_heads == 0
output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv)
# Run the wave kernel.
mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
(prefill, hyperparams) = get_prefill_attention_kernel(
shape,
mfma_variant,
q.shape,
k.shape,
v.shape,
output_shape,
input_dtype=q.dtype,
output_dtype=o.dtype,
size_dtype=b_seq_len.dtype,
)
hyperparams.update(get_default_scheduling_params())
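# Fold the softmax scaling into q before launching the kernel: 1/sqrt(head_size) is the
# standard attention scale; the log2(e) factor is assumed to match an exp2-based softmax
# inside the compiled kernel.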
log2e = 1.44269504089
dk_sqrt = math.sqrt(1.0 / shape.head_size)
options = WaveCompileOptions(
subs=hyperparams,
canonicalize=True,
run_bench=False,
use_scheduling_barriers=False,
)
options = set_default_run_config(options)
prefill = wave_compile(options, prefill)
mb = prefill(
q * dk_sqrt * log2e,
k,
v,
b_start_loc,
b_seq_len,
o,
)
if dump_generated_mlir:
shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]]
filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
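A minimal call into this module looks like the following sketch, lifted almost directly from the context-attention test added later in this PR; the specific sizes are illustrative.

```python
# Illustrative sketch only; mirrors _test_context_attention_once in this PR's kernel tests.
import torch

from sglang.srt.layers.attention.wave_ops.prefill_attention import prefill_attention_wave

seq_lens = [128, 256]
num_heads, kv_heads, head_dim = 4, 1, 128
dtype, dev = torch.float16, "cuda"

q = torch.randn(sum(seq_lens), num_heads, head_dim, dtype=dtype, device=dev)
k = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device=dev)
v = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device=dev)
o = torch.zeros(sum(seq_lens), num_heads, head_dim, dtype=dtype, device=dev)

# b_start_loc marks where each sequence begins in the packed token dimension.
b_start_loc = torch.tensor([0, seq_lens[0]], device=dev)
b_seq_len = torch.tensor(seq_lens, device=dev)

prefill_attention_wave(q, k, v, o, b_start_loc, b_seq_len, max(seq_lens), is_causal=False)
```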
@@ -1487,6 +1487,10 @@ class ModelRunner:
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
return AiterAttnBackend(self)
elif self.server_args.attention_backend == "wave":
from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
return WaveAttnBackend(self)
elif backend_str == "ascend":
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
...
@@ -1323,6 +1323,7 @@ class ServerArgs:
"trtllm_mla",
"trtllm_mha",
"dual_chunk_flash_attn",
"wave",
]
parser.add_argument(
"--attention-backend",
...
@@ -196,6 +196,8 @@ suite_amd = {
TestFile("test_torch_native_attention_backend.py", 123),
TestFile("test_triton_attention_backend.py", 150),
# TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701
TestFile("test_wave_attention_kernels.py", 2),
TestFile("test_wave_attention_backend.py", 150),
],
"per-commit-2-gpu-amd": [
TestFile("lora/test_lora_tp.py", 116),
...
"""
Usage:
python3 -m unittest test_wave_attention_backend.TestWaveAttnBackend.test_mmlu
"""
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
is_in_ci,
popen_launch_server,
run_bench_one_batch,
)
class TestWaveAttnBackend(unittest.TestCase):
def test_latency(self):
_, output_throughput, _ = run_bench_one_batch(
DEFAULT_MODEL_NAME_FOR_TEST,
[
"--attention-backend",
"wave",
"--enable-torch-compile",
],
)
if is_in_ci():
self.assertGreater(output_throughput, 153)
def _test_mmlu(self):
model = DEFAULT_MODEL_NAME_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--attention-backend", "wave"],
)
try:
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreaterEqual(metrics["score"], 0.65)
finally:
kill_process_tree(process.pid)
if __name__ == "__main__":
unittest.main()
import random
import unittest
import torch
from sglang.srt.layers.attention.triton_ops.decode_attention import (
decode_attention_fwd_grouped as triton_decode_attention_fwd_grouped,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import (
extend_attention_fwd,
redundant_attention,
)
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)
from sglang.srt.layers.attention.wave_ops.decode_attention import (
decode_attention_intermediate_arrays_shapes,
decode_attention_wave,
)
from sglang.srt.layers.attention.wave_ops.extend_attention import extend_attention_wave
from sglang.srt.layers.attention.wave_ops.prefill_attention import (
prefill_attention_wave,
)
class TestWaveAttention(unittest.TestCase):
def _set_all_seeds(self, seed):
"""Set all random seeds for reproducibility."""
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def setUp(self):
# Set seeds before each test method
self._set_all_seeds(42)
def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
dtype = torch.float16
extend_seq_len = 1024
b_seq_len_prefix = torch.full(
(B,), N_CTX // B, dtype=torch.int32, device="cuda"
)
b_seq_len_extend = torch.full(
(B,), extend_seq_len, dtype=torch.int32, device="cuda"
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = torch.zeros(
(b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
)
for i in range(B):
kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
)
total_token_num = torch.sum(b_seq_len).item()
extend_token_num = torch.sum(b_seq_len_extend).item()
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
o_extend_mask = torch.empty(
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
)
o_redundant = torch.empty(
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
)
b_seq_len_extend = b_seq_len - b_seq_len_prefix
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
custom_mask = None
mask_indptr = None
redundant_attention(
q_extend,
o_redundant,
k_buffer,
v_buffer,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
max_len_in_batch,
)
is_causal = True
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
is_causal,
mask_indptr,
max_len_extend,
)
o_wave = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
extend_attention_wave(
q_extend,
k_extend,
v_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_indptr,
max_len_extend,
o_wave,
is_causal=is_causal,
)
self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2))
self.assertTrue(torch.allclose(o_wave, o_redundant, rtol=1e-2))
def test_extend_attention(self):
# Define the varying parameter values
attention_values = [128]
# Loop through the values and call the method
for value in attention_values:
self._test_extend_attention_once(32, 16384, 6, 1, value)
def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
dtype = torch.float16
seq_len = S # This represents the number of tokens already in the sequence
total_tokens = B * seq_len
sm_scale = 1.0 / (D**0.5)
max_kv_splits = 8
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
# k_buffer and v_buffer represent all previous tokens
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
# o will have the same shape as q
o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
req_to_token = torch.arange(total_tokens, device="cuda", dtype=torch.int32)
b_req_idx = torch.zeros(B + 1, device="cuda", dtype=torch.int32)
b_seq_len = torch.full((B,), seq_len, device="cuda", dtype=torch.int32)
b_req_idx[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)
attn_logits = torch.empty(
(B, H_Q, max_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)
attn_lse = torch.empty(
(B, H_Q, max_kv_splits),
dtype=torch.float32,
device="cuda",
)
logit_cap = 0.0
triton_decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o_triton,
b_req_idx,
req_to_token,
attn_logits,
attn_lse,
num_kv_splits,
max_kv_splits,
sm_scale,
logit_cap,
)
attn_logits_shape, attn_logits_max_shape = (
decode_attention_intermediate_arrays_shapes(B, D_V, H_Q, max_kv_splits)
)
attn_logits = torch.empty(
attn_logits_shape,
dtype=torch.float32,
device="cuda",
)
attn_logits_max = torch.empty(
attn_logits_max_shape,
dtype=torch.float32,
device="cuda",
)
decode_attention_wave(
q,
k_buffer,
v_buffer,
o,
b_req_idx,
req_to_token,
attn_logits,
attn_logits_max,
num_kv_splits,
max_kv_splits,
sm_scale,
logit_cap,
)
cos_sim = torch.nn.functional.cosine_similarity(
o.flatten(), o_triton.flatten(), dim=0
)
print(cos_sim.item())
self.assertTrue(cos_sim.item() > 0.99)
self.assertTrue(torch.allclose(o, o_triton, atol=3e-2))
def test_grouped_decode_attention(self):
seq_lens = [5, 100, 128, 500]
configs = [
(2, 16, 16, 64, 64),
(2, 16, 1, 64, 64),
(2, 128, 1, 80, 80),
(32, 128, 2, 512, 512),
(2, 128, 2, 512, 512),
(2, 128, 1, 576, 512),
]
for S in seq_lens:
for B, H_Q, H_KV, D, D_V in configs:
self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)
def _test_context_attention_once(self, head_dim, is_causal):
# Set up a simple test case
dtype = torch.float16
num_heads = 4
kv_heads = 1
seq_lens = [128, 256]
max_seq_len = max(seq_lens)
# Create random input tensors
q = torch.randn(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda")
k = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device="cuda")
v = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device="cuda")
o_triton = torch.zeros(
sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda"
)
o = torch.zeros(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda")
# Create b_start_loc and b_seq_len tensors
b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
b_seq_len = torch.tensor(seq_lens, device="cuda")
context_attention_fwd(
q, k, v, o_triton, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
)
prefill_attention_wave(
q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
)
cos_sim = torch.nn.functional.cosine_similarity(
o.flatten(), o_triton.flatten(), dim=0
)
print(cos_sim.item())
self.assertTrue(torch.allclose(o, o_triton, atol=3e-2))
self.assertTrue(cos_sim.item() > 1 - (1e-5))
def test_context_attention(self):
head_dim = [128, 96]
for dim in head_dim:
for is_causal in [False]:
self._test_context_attention_once(dim, is_causal)
if __name__ == "__main__":
unittest.main()