Commit b0dfa004 authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_TRITON_CAT to opt torch cat

parent a99300bd
...@@ -195,6 +195,7 @@ if TYPE_CHECKING: ...@@ -195,6 +195,7 @@ if TYPE_CHECKING:
VLLM_USE_APEX_RN: bool = False VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False VLLM_USE_LIGHT_OP: bool = False
VLLM_USE_TRITON_CAT: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1342,6 +1343,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1342,6 +1343,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_LIGHT_OP": "VLLM_USE_LIGHT_OP":
lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe
"VLLM_USE_TRITON_CAT":
lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -224,6 +224,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -224,6 +224,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
split_decodes_and_prefills) split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.concatv3Tritonfinal import concat_helper
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -1217,8 +1218,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1217,8 +1218,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), if envs.VLLM_USE_TRITON_CAT:
dim=-1) k = concat_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=-1)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
attn_output, attn_softmax_lse = self._run_prefill_context_chunk( attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
prefill=prefill_metadata, prefill=prefill_metadata,
...@@ -1267,7 +1272,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1267,7 +1272,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) if envs.VLLM_USE_TRITON_CAT:
k = concat_helper((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
output = self._run_prefill_new_tokens( output = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill, prefill=attn_metadata.prefill,
......
import triton
import triton.language as tl
import torch
from functools import reduce
import pytest
import torch
import math
@pytest.mark.parametrize("shape_pair,dim", [
(((4, 8, 512), (4, 8, 64)), 2),
(((8, 8, 512), (8, 8, 64)), 2),
(((16, 8, 512), (16, 8, 64)), 2),
(((32, 8, 512), (32, 8, 64)), 2),
(((64, 8, 512), (64, 8, 64)), 2),
(((128, 8, 512), (128, 8, 64)), 2),
(((256, 8, 512), (256, 8, 64)), 2),
(((512, 8, 512), (512, 8, 64)), 2),
(((672, 8, 512), (672, 8, 64)), 2),
(((768, 8, 512), (768, 8, 64)), 2),
(((896, 8, 512), (896, 8, 64)), 2),
(((1024, 8, 512), (1024, 8, 64)), 2),
(((4, 16, 512), (4, 16, 64)), 2),
(((8, 16, 512), (8, 16, 64)), 2),
(((16, 16, 512), (16, 16, 64)), 2),
(((32, 16, 512), (32, 16, 64)), 2),
(((64, 16, 512), (64, 16, 64)), 2),
(((128, 16, 512), (128, 16, 64)), 2),
(((256, 16, 512), (256, 16, 64)), 2),
(((512, 16, 512), (512, 16, 64)), 2),
(((672, 16, 512), (672, 16, 64)), 2),
(((768, 16, 512), (768, 16, 64)), 2),
(((896, 16, 512), (896, 16, 64)), 2),
(((1024, 16, 512), (1024, 16, 64)), 2),
(((4, 32, 512), (4, 32, 64)), 2),
(((8, 32, 512), (8, 32, 64)), 2),
(((16, 32, 512), (16, 32, 64)), 2),
(((32, 32, 512), (32, 32, 64)), 2),
(((64, 32, 512), (64, 32, 64)), 2),
(((128, 32, 512), (128, 32, 64)), 2),
(((256, 32, 512), (256, 32, 64)), 2),
(((512, 32, 512), (512, 32, 64)), 2),
(((672, 32, 512), (672, 32, 64)), 2),
(((768, 32, 512), (768, 32, 64)), 2),
(((896, 32, 512), (896, 32, 64)), 2),
(((1024, 32, 512), (1024, 32, 64)), 2),
])
def test_concat_Acc(shape_pair, dim):
torch.manual_seed(1)
shape1, shape2 = shape_pair
x = torch.randn(*shape1, device='cuda', dtype=torch.bfloat16)
y = torch.randn(*shape2, device='cuda', dtype=torch.bfloat16)
expected = torch.cat([x,y], dim=dim)
result = concat_helper(x, y, dim=dim)
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
@triton.jit
def concat_kernel(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block,
section_num,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)
for sub_section_index in range(Per_block):
sub_offset = block_idx * Per_block + sub_section_index
if sub_offset <= section_num-1:
C_ptr_block_start = C_ptr + sub_offset * C_section_numel
A_ptr_block_start = A_ptr + sub_offset * A_section_numel
B_ptr_block_start = B_ptr + sub_offset * B_section_numel
for offset in range(0, A_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < A_section_numel
val_from_A = tl.load(A_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + offset_idx, val_from_A, mask=mask)
for offset in range(0, B_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < B_section_numel
val_from_B = tl.load(B_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask)
def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
A = A.contiguous()
B = B.contiguous()
output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
if dim!=0 :
block_num = reduce(lambda x, y: x * y, output_shape[:dim])
unit_offset_A, unit_offset_B, unit_offset_C = A.stride(dim-1),B.stride(dim-1),C.stride(dim-1)
Per_block = 1
if (A.shape[1]==8 and A.shape[0] > 128) or ( A.shape[1]==16 and A.shape[0] > 96) or ( A.shape[1]==32 and A.shape[0] > 64):
Per_block = 2
num_blocks = math.ceil(block_num/Per_block)
concat_kernel[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block,
block_num,
BLOCK_SIZE=1024)
return C
assert False, "not support"
configs = []
configs.append(
triton.testing.Benchmark(
x_names=['size'],
x_vals=[4,8,16,32,64,96,128,256,512,768,1024],
x_log=True,
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
styles=[('blue', '-'), ('green', '-')],
ylabel='s',
plot_name='concat-dim2',
args={"dim":2},
),
)
@triton.testing.perf_report(configs)
def benchmark(size, provider, dim):
x = torch.rand([size,8,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,8,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs)
def benchmark_16(size, provider, dim):
x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs)
def benchmark_32(size, provider, dim):
x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
if __name__ == '__main__':
benchmark.run(save_path="./triton_test_8",print_data=True)
benchmark_16.run(save_path="./triton_test_16",print_data=True)
benchmark_32.run(save_path="./triton_test_32",print_data=True)
\ No newline at end of file
...@@ -19,6 +19,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, ...@@ -19,6 +19,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonMetadataBuilder) MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs
from vllm.v1.attention.backends.mla.concatv3Tritonfinal import concat_helper
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -176,8 +178,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -176,8 +178,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
q = torch.cat([q_nope, q_pe], dim=-1)\ if envs.VLLM_USE_TRITON_CAT:
.unsqueeze(1) # Add seqlen dim of 1 (decode) q = concat_helper(q_nope, q_pe, dim=-1)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache( o, _ = flash_mla_with_kvcache(
q=q, q=q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment