Unverified commit fbbe16fa, authored by Binyao Jiang, committed by GitHub

[GDN] Fuse b.sigmoid(), fused_gdn_gating and unsqueeze into one kernel: up to 0.85% e2e speedup (#12508)
parent 6a1a64fa
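
At the call site, this change collapses three separate steps (the old fused_gdn_gating launch, b.sigmoid(), and the two unsqueeze calls) into a single Triton kernel launch. A rough plain-PyTorch sketch of the semantics being fused, for orientation only (the helper name gdn_gating_reference is made up here and is not part of the change):

import torch
import torch.nn.functional as F

def gdn_gating_reference(A_log, a, b, dt_bias):
    # What the old code path did in three separate steps, per the diff below:
    g = -A_log.float().exp() * F.softplus(a.float() + dt_bias)  # old fused_gdn_gating kernel
    beta = b.sigmoid()                                           # separate elementwise op
    return g.unsqueeze(0), beta.unsqueeze(0)                     # separate unsqueeze calls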
New file (imported below as sglang.srt.layers.attention.fla.fused_gdn_gating):

from typing import Tuple

import torch
import triton
import triton.language as tl


# Computes, in a single kernel launch:
#   g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
#   beta_output = b.sigmoid()
@triton.jit
def fused_gdn_gating_kernel(
    g,
    beta_output,
    A_log,
    a,
    b,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
):
    # One program per (batch index, sequence position, block of heads).
    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
    mask = head_off < NUM_HEADS
    # A_log and dt_bias are per-head; a and b are per (batch, head).
    blk_A_log = tl.load(A_log + head_off, mask=mask)
    blk_a = tl.load(a + off, mask=mask)
    blk_b = tl.load(b + off, mask=mask)
    blk_bias = tl.load(dt_bias + head_off, mask=mask)
    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
    # Numerically stable softplus: fall back to the identity for large inputs.
    softplus_x = tl.where(
        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
    )
    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
    blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
    tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask)


def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    batch, num_heads = a.shape
    seq_len = 1
    grid = (batch, seq_len, triton.cdiv(num_heads, 8))
    # Outputs are allocated with a leading sequence dimension of 1, so the
    # caller no longer needs a separate unsqueeze.
    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
    fused_gdn_gating_kernel[grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        seq_len,
        num_heads,
        beta,
        threshold,
        8,
        num_warps=1,
    )
    return g, beta_output
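
A quick way to sanity-check the new kernel is to compare it against the same math written with plain PyTorch ops. The snippet below is a hypothetical check, not part of the commit: the shapes, dtypes, and tolerance are assumptions, and it needs a CUDA device with Triton installed.

import torch
import torch.nn.functional as F

from sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating

batch, num_heads = 4, 16
A_log = torch.randn(num_heads, device="cuda")
a = torch.randn(batch, num_heads, device="cuda")
b = torch.randn(batch, num_heads, device="cuda")
dt_bias = torch.randn(num_heads, device="cuda")

# Both outputs already carry the leading sequence dimension: (1, batch, num_heads).
g, beta = fused_gdn_gating(A_log, a, b, dt_bias)

g_ref = -A_log.exp() * F.softplus(a + dt_bias)
beta_ref = b.sigmoid()
assert torch.allclose(g.squeeze(0), g_ref, atol=1e-4)
assert torch.allclose(beta.squeeze(0), beta_ref, atol=1e-4)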
@@ -5,6 +5,7 @@ from einops import rearrange
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
+from sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating
 from sglang.srt.layers.attention.fla.fused_recurrent import (
     fused_recurrent_gated_delta_rule_update,
 )
@@ -30,7 +31,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.models.qwen3_next import fused_gdn_gating
 from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import is_cuda, is_npu
@@ -697,11 +697,7 @@ class GDNAttnBackend(MambaAttnBackendBase):
         key = key.view(1, actual_seq_len, num_heads, head_k_dim)
         value = value.view(1, actual_seq_len, num_value_heads, head_v_dim)
-        beta = b.sigmoid()
-        g = fused_gdn_gating(A_log, a, dt_bias)
-        g = g.unsqueeze(0)
-        beta = beta.unsqueeze(0)
+        g, beta = fused_gdn_gating(A_log, a, b, dt_bias)
         if is_target_verify:
             core_attn_out = fused_recurrent_gated_delta_rule_update(
...
@@ -190,51 +190,6 @@ def fused_qkvzba_split_reshape_cat(
     return mixed_qkv, z, b, a
-
-
-# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
-@triton.jit
-def fused_gdn_gating_kernel(
-    g,
-    A_log,
-    a,
-    dt_bias,
-    seq_len,
-    NUM_HEADS: tl.constexpr,
-    beta: tl.constexpr,
-    threshold: tl.constexpr,
-    BLK_HEADS: tl.constexpr,
-):
-    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
-    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
-    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
-    mask = head_off < NUM_HEADS
-    blk_A_log = tl.load(A_log + head_off, mask=mask)
-    blk_a = tl.load(a + off, mask=mask)
-    blk_bias = tl.load(dt_bias + head_off, mask=mask)
-    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
-    softplus_x = tl.where(
-        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
-    )
-    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
-    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
-
-
-def fused_gdn_gating(
-    A_log: torch.Tensor,
-    a: torch.Tensor,
-    dt_bias: torch.Tensor,
-    beta: float = 1.0,
-    threshold: float = 20.0,
-) -> torch.Tensor:
-    batch, num_heads = a.shape
-    seq_len = 1
-    grid = (batch, seq_len, triton.cdiv(num_heads, 8))
-    g = torch.empty_like(a, dtype=torch.float32)
-    fused_gdn_gating_kernel[grid](
-        g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
-    )
-    return g

 class Qwen3GatedDeltaNet(nn.Module):
     def __init__(
         self,
...