"megatron/model/vscode:/vscode.git/clone" did not exist on "b8428a7fa5003a312d47173599400cbce980b14c"
Commit 908a5b22 authored by Tri Dao's avatar Tri Dao
Browse files

Set num_warps=4 for headdim=64 in Triton fw (h/t Michael Benesty)

parent 74797571
@@ -44,14 +44,15 @@ import triton
 import triton.language as tl
-@triton.autotune(
-    configs=[
-        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
-        # This config has a race condition when EVEN_M == False, disabling it for now.
-        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
-    ],
-    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
-)
+# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
+# @triton.autotune(
+#     configs=[
+#         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
+#         # This config has a race condition when EVEN_M == False, disabling it for now.
+#         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
+#     ],
+#     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
+# )
 @triton.heuristics(
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
@@ -617,8 +618,8 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
     o = torch.empty_like(q)
     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
-    # BLOCK = 128
-    # num_warps = 4 if d <= 64 else 8
+    BLOCK = 128
+    num_warps = 4 if d <= 64 else 8
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
     _fwd_kernel[grid](
         q, k, v, bias, o,
@@ -634,9 +635,9 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
         bias_type, causal, BLOCK_HEADDIM,
-        # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-        # num_warps=num_warps,
-        # num_stages=1,
+        BLOCK_M=BLOCK, BLOCK_N=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
     )
     return o, lse, softmax_scale  # softmax_scale could have been updated
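
For reference, a minimal sketch (not part of this commit) of how the manual launch configuration above resolves for a few head dimensions; it assumes the same selection rule as the diff and Triton's next_power_of_2 helper, with illustrative head-dimension values:

    import triton

    # Illustrative head dimensions; the rule below mirrors the change in this commit.
    for d in (32, 48, 64, 96, 128):
        BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)  # padded head dim the kernel works with
        num_warps = 4 if d <= 64 else 8                      # 4 warps up to headdim 64, else 8
        print(f"headdim={d:3d} -> BLOCK_HEADDIM={BLOCK_HEADDIM:3d}, num_warps={num_warps}")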