Unverified Commit db71c38f authored by Yi Zhang's avatar Yi Zhang Committed by GitHub
Browse files

Scale kkt after reduction (#10604)

parent 7a68b422
......@@ -74,8 +74,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_beta[:, None]
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
b_A += tl.dot(b_k, tl.trans(b_k))
if USE_G:
p_g = tl.make_block_ptr(
......@@ -85,6 +84,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * safe_exp(b_g_diff)
b_A *= b_beta[:, None]
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment