Unverified Commit c4768dcf authored by Jiangyun Zhu's avatar Jiangyun Zhu Committed by GitHub
Browse files

[Kernel] Fix fused_gdn_gating (#28343)


Signed-off-by: default avatarzjy0516 <riverclouds.zhu@qq.com>
parent a65a934e
...@@ -1367,8 +1367,10 @@ def fused_gdn_gating_kernel( ...@@ -1367,8 +1367,10 @@ def fused_gdn_gating_kernel(
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
# compute beta_output = sigmoid(b) # compute beta_output = sigmoid(b)
blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32))) blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask) tl.store(
beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask
)
def fused_gdn_gating( def fused_gdn_gating(
...@@ -1389,7 +1391,7 @@ def fused_gdn_gating( ...@@ -1389,7 +1391,7 @@ def fused_gdn_gating(
seq_len = 1 seq_len = 1
grid = (batch, seq_len, triton.cdiv(num_heads, 8)) grid = (batch, seq_len, triton.cdiv(num_heads, 8))
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device) beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
fused_gdn_gating_kernel[grid]( fused_gdn_gating_kernel[grid](
g, g,
beta_output, beta_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment