"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "edf22c052e0d91eca4687ee678b06a485f78666d"
Unverified Commit db71c38f authored by Yi Zhang's avatar Yi Zhang Committed by GitHub
Browse files

Scale kkt after reduction (#10604)

parent 7a68b422
...@@ -74,8 +74,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( ...@@ -74,8 +74,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
(1, 0), (1, 0),
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_beta[:, None] b_A += tl.dot(b_k, tl.trans(b_k))
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
if USE_G: if USE_G:
p_g = tl.make_block_ptr( p_g = tl.make_block_ptr(
...@@ -85,6 +84,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( ...@@ -85,6 +84,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
b_g_diff = b_g[:, None] - b_g[None, :] b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * safe_exp(b_g_diff) b_A = b_A * safe_exp(b_g_diff)
b_A *= b_beta[:, None]
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0) b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
p_A = tl.make_block_ptr( p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0) A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment