Commit 6b5f271c authored by Tri Dao's avatar Tri Dao
Browse files

[Triton] Avoid einops repeat by using Tensor.expand

parent 88c4e5db
......@@ -38,8 +38,6 @@ import math
import torch
from einops import rearrange, repeat
import triton
import triton.language as tl
......@@ -605,11 +603,7 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
else:
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
' or (seqlen_q, seqlen_k)')
if bias.shape[:2] == (1, nheads):
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
elif bias.shape[:2] == (batch, 1):
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
......@@ -684,11 +678,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
else:
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
' or (seqlen_q, seqlen_k)')
if bias.shape[:2] == (1, nheads):
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
elif bias.shape[:2] == (batch, 1):
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
# BLOCK_M = 128
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment