Commit 1ee40c4d authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_TRITON_OPT_MLA to use optimized MLA attention

update mla optest
parent e3a0d6bb
...@@ -508,10 +508,10 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -508,10 +508,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
version = 'das.opt1.cust1.' + sha[:7] version = 'das.opt1.' + sha[:7]
else: else:
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
version = 'das.opt1.cust1' version = 'das.opt1'
# dtk version # dtk version
...@@ -696,7 +696,7 @@ package_data = { ...@@ -696,7 +696,7 @@ package_data = {
"model_executor/layers/fused_moe/configs/*.json", "model_executor/layers/fused_moe/configs/*.json",
"model_executor/layers/quantization/utils/configs/*.json", "model_executor/layers/quantization/utils/configs/*.json",
"benchmarks/*.py", "benchmarks/*.py",
"model_executor/layers/quantization/configs/w8a8/*.json", "attention/backends/configs/*.json",
"model_executor/layers/quantization/configs/awq/*.json" "model_executor/layers/quantization/configs/awq/*.json"
] ]
} }
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
import pytest import pytest
import torch import torch
import triton
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_decode_attention import decode_attention_fwd, decode_attention_v1, decode_attention_v2
def cdiv(a, b): def cdiv(a, b):
return (a + b - 1) // b return (a + b - 1) // b
...@@ -25,13 +25,13 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -25,13 +25,13 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
sm_scale = 1.0 / (D_QK**0.5) sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 8 num_kv_splits = 8
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) # 向上取整:65, (1027+16-1)//16
req_to_page = torch.randint(0, req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE, CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1), (B, num_pages_per_batch, 1), #shape为(B, num_pages_per_batch, 1)的tensor,大小取值为0 至cache_size//page_size
device="cuda") device="cuda")
req_to_token = req_to_page * PAGE_SIZE req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) # 维度扩展,从torch.Size([3, 65, 1])扩展至torch.Size([3, 65, 16])
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view( req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1) 1, 1, -1)
req_to_token = req_to_token.view(B, -1) req_to_token = req_to_token.view(B, -1)
...@@ -47,14 +47,22 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -47,14 +47,22 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
# o will have the same shape as q # o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda") b_seq_len = torch.full((B, ), seq_len, device="cuda")
b_start_loc = torch.arange(0, k_buffer.shape[0] * PAGE_SIZE, k_buffer.shape[0] * PAGE_SIZE // q.shape[0], device="cuda").to(torch.int32)
attn_logits_v1 = torch.empty(
(q.shape[1], k_buffer.shape[0]*PAGE_SIZE),
dtype=torch.float16,
device="cuda")
attn_logits = torch.empty( attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1), (B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32, dtype=torch.float32,
device="cuda", device="cuda",
) )
quantiles = [0.5, 0.2, 0.8]
# Call the original implementation. # Call the original implementation.
decode_attention_fwd( decode_attention_fwd(
...@@ -87,5 +95,81 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -87,5 +95,81 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
sm_scale, sm_scale,
PAGE_SIZE, PAGE_SIZE,
) )
assert torch.allclose(o, o1) assert torch.allclose(o, o1)
# v0_tc_ms, v0_tc_min_ms, v0_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_fwd(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention ori kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v0_tc_ms)
decode_attention_v1(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_start_loc,
b_seq_len,
attn_logits_v1,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2)
# v1_tc_ms, v1_tc_min_ms, v1_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention v1 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v1_tc_ms)
decode_attention_v2(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2)
# v2_tc_ms, v2_tc_min_ms, v2_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_v2(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention v2 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v2_tc_ms)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import functools
import json
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -33,6 +36,57 @@ if TYPE_CHECKING: ...@@ -33,6 +36,57 @@ if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
from vllm.logger import init_logger
logger = init_logger(__name__)
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)
json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
raise ValueError("Please surpport default config can match 16 1 576 512")
# If no optimized configuration is available, we will use the default
# configuration
return None
class TritonMLABackend(AttentionBackend): class TritonMLABackend(AttentionBackend):
...@@ -735,12 +789,15 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): ...@@ -735,12 +789,15 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1) PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# TODO
# config = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
# Run MQA # Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.seq_lens_tensor, attn_logits, decode_meta.seq_lens_tensor, attn_logits,
attn_metadata.num_kv_splits, self.scale, attn_metadata.num_kv_splits, self.scale, # config,
PAGE_SIZE) PAGE_SIZE)
return self._v_up_proj_and_o_proj(o) return self._v_up_proj_and_o_proj(o)
...@@ -30,11 +30,12 @@ It supports page size >= 1. ...@@ -30,11 +30,12 @@ It supports page size >= 1.
import os import os
import logging import logging
import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm import envs
is_hip_ = current_platform.is_rocm() is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0" os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0"
...@@ -221,7 +222,6 @@ def _decode_att_m_fwd( ...@@ -221,7 +222,6 @@ def _decode_att_m_fwd(
PAGE_SIZE=page_size, PAGE_SIZE=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
num_warps=num_warps, num_warps=num_warps,
num_stages=2,
Lk=Lk, Lk=Lk,
Lv=Lv, Lv=Lv,
) )
...@@ -458,7 +458,6 @@ def _decode_grouped_att_m_fwd( ...@@ -458,7 +458,6 @@ def _decode_grouped_att_m_fwd(
PAGE_SIZE=page_size, PAGE_SIZE=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
num_warps=4, num_warps=4,
num_stages=2,
Lk=Lk, Lk=Lk,
Lv=Lv, Lv=Lv,
**extra_kargs, **extra_kargs,
...@@ -541,7 +540,7 @@ def _decode_softmax_reducev_fwd( ...@@ -541,7 +540,7 @@ def _decode_softmax_reducev_fwd(
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = { extra_kargs = {
"waves_per_eu": 4, "waves_per_eu": 0,
"matrix_instr_nonkdim": 16, "matrix_instr_nonkdim": 16,
"kpack": 2 "kpack": 2
} }
...@@ -560,7 +559,6 @@ def _decode_softmax_reducev_fwd( ...@@ -560,7 +559,6 @@ def _decode_softmax_reducev_fwd(
BLOCK_DV=BLOCK_DV, BLOCK_DV=BLOCK_DV,
Lv=Lv, Lv=Lv,
num_warps=4, num_warps=4,
num_stages=2,
**extra_kargs, **extra_kargs,
) )
...@@ -623,6 +621,806 @@ def decode_attention_fwd_grouped( ...@@ -623,6 +621,806 @@ def decode_attention_fwd_grouped(
num_kv_splits) num_kv_splits)
# opt
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"]
)
@triton.jit
def _decode_v1_kernel_stage1_use_tc(
Q,
K_Buffer,
sm_scale,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
att_stride_h,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
SPLIT_K: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_k_id = tl.program_id(2)
reduce_dtype = Att_Out.dtype.element_ty
if BLOCK_H < kv_group_num:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
# cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_req_idx = cur_batch
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(
Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
).to(reduce_dtype)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
split_k_start = kv_len_per_split * split_k_id
split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
for start_n in range(split_k_start, split_k_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
offs_n // PAGE_SIZE,
mask=offs_n < split_k_end,
other=0,
)
k_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
k_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
other=0.0,
).to(reduce_dtype)
qk = tl.dot(q, k)
if BLOCK_DPE > 0:
offs_buf_kpe = (
k_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=offs_n[None, :] < split_k_end,
other=0.0,
).to(reduce_dtype)
qk += tl.dot(qpe, kpe)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
offs_o = cur_head[:, None] * att_stride_h + (
cur_batch_in_all_start_index + offs_n[None, :]
)
tl.store(
Att_Out + offs_o,
qk,
mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
)
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1),
],
key=["B_Seqlen","stride_logic_h","stride_buf_vbs","stride_buf_vh"]
)
@triton.jit
def _decode_v1_kernel_stage2_use_tc(
logits,
V_Buffer,
Out,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
stride_logic_h,
stride_buf_vbs,
stride_buf_vh,
stride_obs,
stride_oh,
stride_req_to_token_b,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_H: tl.constexpr,
PAGE_SIZE: tl.constexpr,
Lv: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_kv_head = tl.program_id(1)
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
mask_h = mask_h & (cur_head < q_head_num)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
cur_batch_req_idx = cur_batch #tl.load(B_req_idx + cur_batch)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
v_ptrs = V_Buffer + offs_buf_v
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
v_page_number = tl.load(
Req_to_tokens
+ cur_batch_req_idx * stride_req_to_token_b
+ (start_n + offs_n) // PAGE_SIZE,
mask=(start_n + offs_n) < cur_batch_seq_len,
other=0,
)
v_loc = v_page_number * PAGE_SIZE + (start_n + offs_n) % PAGE_SIZE
offs_qk = cur_head[:, None] * stride_logic_h + (
cur_batch_start_loc + start_n + offs_n[None, :]
)
qk = tl.load(
logits + offs_qk,
mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
other=float("-inf"),
) #[head, block_n]
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
e_sum = e_sum * old_scale + tl.sum(p, 1)
v = tl.load(
v_ptrs + v_loc[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
) #[block_n,head_dim]
p = p.to(v.dtype)
acc = acc * old_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
acc = acc / e_sum[:, None]
off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))
def _decode_v1_stage1_use_tc(
q,
k_buffer,
att_out,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
sm_scale,
page_size,
num_kv_splits,
logit_cap,
):
Lk = k_buffer.shape[-1]
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lk == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
else:
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DPE = 0
# batch, head_num = B_req_idx.shape[0], q.shape[1]
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
SPLIT_K = num_kv_splits
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = lambda META: (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
SPLIT_K,
)
_decode_v1_kernel_stage1_use_tc[grid](
q,
k_buffer,
sm_scale,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
att_out.stride(0),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_H=BLOCK_H,
SPLIT_K=SPLIT_K,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
Lk=Lk,
kpack=2,
)
return _decode_v1_kernel_stage1_use_tc.best_config
def _decode_v1_stage2_use_tc(
logits,
v_buffer,
o,
req_to_tokens,
#b_req_idx,
b_start_loc,
b_seq_len,
page_size,
):
batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logits.shape[0] // v_buffer.shape[-2]
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
Lv = v_buffer.shape[-1]
BLOCK_DMODEL = triton.next_power_of_2(Lv)
_decode_v1_kernel_stage2_use_tc[grid](
logits,
v_buffer,
o,
req_to_tokens,
#b_req_idx,
b_start_loc,
b_seq_len,
logits.stride(0),
v_buffer.stride(-3),
v_buffer.stride(-2),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_H=BLOCK_H,
PAGE_SIZE=page_size,
Lv=Lv,
)
return _decode_v1_kernel_stage2_use_tc.best_config
def decode_attention_v1(
q,
k_buffer,
v_buffer,
o,
req_to_token,
#b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
# GQA/MQA/MLA
_decode_v1_stage1_best_config = _decode_v1_stage1_use_tc(
q,
k_buffer,
attn_logits,
req_to_token,
#b_req_idx,
b_start_loc,
b_seq_len,
sm_scale,
page_size,
num_kv_splits,
logit_cap,
)
_decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
attn_logits,
v_buffer,
o,
req_to_token,
#b_req_idx,
b_start_loc,
b_seq_len,
page_size,
)
return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
)
@triton.jit
def _decode_v2_kernel_stage1_use_tc(
Q,
K_Buffer,
V_Buffer,
sm_scale,
Req_to_tokens,
# B_req_idx,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_kv_id = tl.program_id(2)
if BLOCK_H < kv_group_num:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk
mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
# cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_req_idx = cur_batch
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
)
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0,
)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
)
offs_buf_v = (
kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_mid_o = (
cur_batch * stride_mid_ob
+ cur_head[:, None] * stride_mid_oh
+ split_kv_id * stride_mid_os
+ offs_dv[None, :]
)
tl.store(
Att_Out + offs_mid_o,
acc / e_sum[:, None],
mask=(mask_h[:, None]) & (mask_dv[None, :]),
)
offs_mid_o_1 = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ split_kv_id * stride_mid_os
+ Lv
)
tl.store(
Att_Out + offs_mid_o_1,
e_max + tl.log(e_sum),
mask=mask_h,
)
def _decode_v2_stage1_use_tc(
q,
k_buffer,
v_buffer,
att_out,
Req_to_tokens,
# B_req_idx,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
):
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lk == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
else:
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
# batch, head_num = B_req_idx.shape[0], q.shape[1]
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_H = 16
NUM_KV_SPLITS = num_kv_splits
grid = lambda META: (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
NUM_KV_SPLITS,
)
_decode_v2_kernel_stage1_use_tc[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
# B_req_idx,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
v_buffer.stride(-3),
v_buffer.stride(-2),
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
Lk=Lk,
Lv=Lv,
kpack=2,
)
return _decode_v2_kernel_stage1_use_tc.best_config
@triton.autotune(
configs=[
triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=2, num_stages=1),
triton.Config({}, num_warps=4, num_stages=1),
triton.Config({}, num_warps=8, num_stages=1),
],
key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"]
)
@triton.jit
def _decode_v2_kernel_stage2(
Mid_O,
O,
B_Seqlen,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lv
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
if split_kv_end > split_kv_start:
tv = tl.load(
Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
)
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
mask=mask_d,
)
def _decode_v2_stage2_use_tc(
logits,
q,
o,
v_buffer,
b_seq_len,
num_kv_splits,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)
NUM_KV_SPLITS = num_kv_splits
grid = (batch, head_num)
_decode_v2_kernel_stage2[grid](
logits,
o,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
)
return _decode_v2_kernel_stage2.best_config
def decode_attention_v2(
q,
k_buffer,
v_buffer,
o,
req_to_token,
# b_req_idx,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
_decode_v2_stage1_best_config = _decode_v2_stage1_use_tc(
q,
k_buffer,
v_buffer,
attn_logits,
req_to_token,
# b_req_idx,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
def decode_attention_fwd( def decode_attention_fwd(
q, q,
k_buffer, k_buffer,
...@@ -633,12 +1431,13 @@ def decode_attention_fwd( ...@@ -633,12 +1431,13 @@ def decode_attention_fwd(
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
# config,
page_size=1, page_size=1,
logit_cap=0.0, logit_cap=0.0,
): ):
assert num_kv_splits == attn_logits.shape[2] assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2] kv_group_num = q.shape[1] // v_buffer.shape[-2]
b_start_loc = torch.arange(0, k_buffer.shape[0] * page_size, k_buffer.shape[0] * page_size // q.shape[0], device="cuda").to(torch.int32)
if kv_group_num == 1: if kv_group_num == 1:
# MHA # MHA
decode_attention_fwd_normal( decode_attention_fwd_normal(
...@@ -656,16 +1455,88 @@ def decode_attention_fwd( ...@@ -656,16 +1455,88 @@ def decode_attention_fwd(
) )
else: else:
# GQA/MQA/MLA # GQA/MQA/MLA
decode_attention_fwd_grouped( if envs.VLLM_USE_TRITON_OPT_MLA:
q, decode_attention_v2(
k_buffer, q,
v_buffer, k_buffer,
o, v_buffer,
req_to_token, o,
b_seq_len, req_to_token,
attn_logits, b_seq_len,
num_kv_splits, attn_logits,
sm_scale, num_kv_splits,
page_size, sm_scale,
logit_cap, page_size,
) logit_cap,
)
# attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16,
# device="cuda")
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits, # sub
# sm_scale,
# page_size,
# logit_cap,
# )
# TODO
# if best_config['kernel_kind'] == 'v1_2stages_tc':
# attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16,
# device="cuda")
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits,
# sm_scale,
# config,
# page_size,
# logit_cap,
# )
# elif best_config['kernel_kind'] == 'v2_tc':
# decode_attention_v2(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# config,
# page_size,
# logit_cap,
# )
# else:
# print("Unknown mla kernel kind: ", best_config['kernel_kind'])
else:
decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
\ No newline at end of file
...@@ -15,6 +15,7 @@ if TYPE_CHECKING: ...@@ -15,6 +15,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
...@@ -573,6 +574,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -573,6 +574,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vLLM will disable the MLA attention optimizations. # If set, vLLM will disable the MLA attention optimizations.
"VLLM_MLA_DISABLE": "VLLM_MLA_DISABLE":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
# If set, vLLM will use optimized MLA attention optimizations.
"VLLM_USE_TRITON_OPT_MLA":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))),
# Flag that can control whether or not we perform matrix-absorption for MLA # Flag that can control whether or not we perform matrix-absorption for MLA
# decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the # decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment