Unverified Commit 82653f66 authored by DefTruth, committed by GitHub

feat: Add a unified merge_state API (#5428)

parent 22da3d97
sglang/srt/layers/attention/merge_state.py

from typing import Optional, Tuple

import torch

from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()
if _is_cuda:
    # Only import the CUDA kernel when it can actually be used, so that
    # the Triton fallback below also works on platforms without sgl_kernel.
    from sgl_kernel import merge_state_v2


# Automatically fall back to the Triton kernel in some cases
# (e.g., on AMD GPUs, when the head dimension is not a multiple
# of 4 or 8, or in FP8 precision).
def _supported_dtypes(o: torch.Tensor) -> bool:
    return o.dtype in [torch.float32, torch.half, torch.bfloat16]


def _supported_headdim(o: torch.Tensor) -> bool:
    # o has shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    headdim = o.shape[2]
    if o.dtype == torch.float32:
        return headdim % 4 == 0
    return headdim % 8 == 0


def merge_state(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if (
        _is_cuda
        and _supported_dtypes(prefix_output)
        and _supported_headdim(prefix_output)
    ):
        return merge_state_v2(
            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
        )
    else:
        # Fall back to the Triton kernel
        return merge_state_triton(
            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
        )
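For reference, a minimal usage sketch of the unified API. The import path, device, dtypes, and shapes here are illustrative assumptions (partial attention outputs laid out as [NUM_TOKENS, NUM_HEADS, HEAD_SIZE], their log-sum-exp values as [NUM_TOKENS, NUM_HEADS]), not part of this commit:

# Usage sketch (illustrative; assumes a CUDA device and the module path above).
import torch

from sglang.srt.layers.attention.merge_state import merge_state

NUM_TOKENS, NUM_HEADS, HEAD_SIZE = 16, 32, 128

# Partial attention outputs and log-sum-exp values for the same query
# tokens from two chunks of the KV sequence (e.g., a cached prefix and
# a newly computed suffix).
v_a = torch.randn(NUM_TOKENS, NUM_HEADS, HEAD_SIZE, dtype=torch.half, device="cuda")
s_a = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
v_b = torch.randn(NUM_TOKENS, NUM_HEADS, HEAD_SIZE, dtype=torch.half, device="cuda")
s_b = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")

# Dispatches to the CUDA merge_state_v2 kernel when supported,
# otherwise falls back to the Triton implementation.
v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)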
sglang/srt/layers/attention/triton_ops/merge_state.py

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl
@triton.jit
def merge_state_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    # One program per (token, head) pair.
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    # Load the two log-sum-exp values; +inf marks a fully masked chunk,
    # which must contribute nothing, so map it to -inf.
    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    # Subtract the max before exponentiating for numerical stability.
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    out_se = tl.exp(p_lse) + tl.exp(s_lse)

    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)

    # The head dimension is padded to the next power of two for tl.arange;
    # the mask restricts loads/stores to the real HEAD_SIZE elements.
    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(
        prefix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )
    s_out = tl.load(
        suffix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )

    # Blend the two partial outputs with softmax-style weights:
    # out = (exp(p_lse) * v_a + exp(s_lse) * v_b) / (exp(p_lse) + exp(s_lse))
    p_scale = tl.exp(p_lse) / out_se
    s_scale = tl.exp(s_lse) / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(
        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
        out,
        mask=head_mask,
    )
def merge_state_triton(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Avoid creating new tensors if they are already provided
    if output is None:
        output = torch.empty_like(prefix_output)
    if output_lse is None:
        output_lse = torch.empty_like(prefix_lse)

    num_tokens = output.shape[0]
    num_query_heads = output.shape[1]
    head_size = output.shape[2]
    padded_head_size = triton.next_power_of_2(head_size)

    # Launch one program per (token, head) pair. Note that OUTPUT_LSE is
    # always True here, because output_lse is allocated above when the
    # caller does not provide it.
    merge_state_kernel[(num_tokens, num_query_heads)](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        output_lse is not None,
    )
    return output, output_lse
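For intuition, both backends compute the standard log-sum-exp merge of two partial attention results: out = (exp(s_a) * v_a + exp(s_b) * v_b) / (exp(s_a) + exp(s_b)) and out_lse = log(exp(s_a) + exp(s_b)), evaluated after subtracting max(s_a, s_b) for numerical stability. Below is a minimal pure-PyTorch sketch of that math, useful as a cross-check against the Triton kernel; merge_state_ref is a hypothetical helper, not part of this commit:

# Reference sketch of the merge math (not part of this commit).
import torch

def merge_state_ref(v_a, s_a, v_b, s_b):
    # v_a, v_b: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]; s_a, s_b: [NUM_TOKENS, NUM_HEADS]
    # Mirror the kernel's sanitization: +inf LSE marks a masked chunk.
    s_a = s_a.masked_fill(s_a == float("inf"), float("-inf"))
    s_b = s_b.masked_fill(s_b == float("inf"), float("-inf"))
    max_lse = torch.maximum(s_a, s_b)
    p = torch.exp(s_a - max_lse)  # unnormalized weight of the prefix chunk
    s = torch.exp(s_b - max_lse)  # unnormalized weight of the suffix chunk
    se = p + s
    out = v_a * (p / se).unsqueeze(-1) + v_b * (s / se).unsqueeze(-1)
    out_lse = torch.log(se) + max_lse
    return out.to(v_a.dtype), out_lse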