Commit 0d148a7d authored by Gustaf Ahdritz

Add FlashAttention support to msa.py

parent bcb0b70f
@@ -89,12 +89,14 @@ class MSAAttention(nn.Module):
     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
-        biases: List[torch.Tensor],
+        biases: Optional[List[torch.Tensor]],
         chunk_size: int,
         use_memory_efficient_kernel: bool,
         use_lma: bool,
+        use_flash: bool,
+        flash_mask: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        def fn(m, biases):
+        def fn(m, biases, flash_mask):
             m = self.layer_norm_m(m)
             return self.mha(
                 q_x=m,
@@ -102,14 +104,23 @@ class MSAAttention(nn.Module):
                 biases=biases,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
                 use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=flash_mask,
             )
 
+        inputs = {"m": m}
+        if(biases is not None):
+            inputs["biases"] = biases
+        else:
+            fn = partial(fn, biases=None)
+        if(use_flash and flash_mask is not None):
+            inputs["flash_mask"] = flash_mask
+        else:
+            fn = partial(fn, flash_mask=None)
+
         return chunk_layer(
             fn,
-            {
-                "m": m,
-                "biases": biases,
-            },
+            inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2])
         )
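The new _chunk routes each optional tensor one of two ways: if it is present it goes into the dict that chunk_layer slices along the batch dims, and if it is absent it is bound to None with functools.partial so the chunking utility never has to slice a None. A minimal runnable sketch of that pattern follows; toy_chunk_layer and attn_like are invented stand-ins for illustration, not OpenFold's chunk_layer or attention call.

# Toy stand-in for a chunking utility: slice every tensor in `inputs`
# along dim 0, apply fn to each slice, and concatenate the results.
from functools import partial
from typing import Dict, Optional

import torch


def toy_chunk_layer(fn, inputs: Dict[str, torch.Tensor], chunk_size: int) -> torch.Tensor:
    n = next(iter(inputs.values())).shape[0]
    out = []
    for start in range(0, n, chunk_size):
        chunk = {k: v[start:start + chunk_size] for k, v in inputs.items()}
        out.append(fn(**chunk))
    return torch.cat(out, dim=0)


def attn_like(m: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    # Placeholder for the attention call; just applies the optional bias.
    return m if bias is None else m + bias


m = torch.randn(8, 4)
bias = torch.randn(8, 4)

# Optional tensor present: put it in the input dict so it is sliced alongside m.
with_bias = toy_chunk_layer(attn_like, {"m": m, "bias": bias}, chunk_size=3)

# Optional tensor absent: bind it to None with partial so the chunker only
# ever sees real tensors in its input dict.
without_bias = toy_chunk_layer(partial(attn_like, bias=None), {"m": m}, chunk_size=3)

assert torch.allclose(with_bias, m + bias)
assert torch.allclose(without_bias, m)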
@@ -175,6 +186,7 @@ class MSAAttention(nn.Module):
         m, mask_bias, z = self._prep_inputs(
             m, z, mask, inplace_safe=inplace_safe
         )
 
+        m = self.layer_norm_m(m)
         q, k, v = self.mha._prep_qkv(m, m)
         return m, q, k, v, mask_bias, z
@@ -210,6 +222,7 @@ class MSAAttention(nn.Module):
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
+        use_flash: bool = False,
         inplace_safe: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
@@ -235,15 +248,19 @@ class MSAAttention(nn.Module):
                 chunk_logits=_chunk_logits,
                 checkpoint=_checkpoint_chunks,
                 inplace_safe=inplace_safe,
-            )
-
-        m, mask_bias, z = self._prep_inputs(
-            m, z, mask, inplace_safe=inplace_safe
-        )
-
-        biases = [mask_bias]
-        if(z is not None):
-            biases.append(z)
+            )
+
+        if(use_flash):
+            assert z is None
+            biases = None
+        else:
+            m, mask_bias, z = self._prep_inputs(
+                m, z, mask, inplace_safe=inplace_safe
+            )
+
+            biases = [mask_bias]
+            if(z is not None):
+                biases.append(z)
 
         if chunk_size is not None:
             m = self._chunk(
@@ -252,6 +269,8 @@ class MSAAttention(nn.Module):
                 chunk_size,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
                 use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=mask,
             )
         else:
             m = self.layer_norm_m(m)
@@ -261,6 +280,8 @@ class MSAAttention(nn.Module):
                 biases=biases,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
                 use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=mask,
             )
 
         return m
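The use_flash branch asserts z is None and drops the bias list because the padding information travels as flash_mask instead: a key-padding mask can be applied either as an additive bias on the attention logits (the biases=[mask_bias] path) or directly as a boolean mask, which is the form a flash-style kernel consumes, whereas an arbitrary pair bias z has no such mask form. A small plain-PyTorch sketch (no flash kernel; the 1e9 constant is an assumed stand-in for the large finite value a mask bias would use) showing the two formulations agree:

import torch

torch.manual_seed(0)
q = torch.randn(1, 4, 5, 8)   # [batch, heads, queries, channels]
k = torch.randn(1, 4, 5, 8)
v = torch.randn(1, 4, 5, 8)
mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0])  # last two keys are padding

logits = torch.einsum("...qc,...kc->...qk", q, k) / 8 ** 0.5

# (a) padding expressed as an additive bias on the logits (the biases path)
mask_bias = 1e9 * (mask - 1)
out_bias = torch.softmax(logits + mask_bias, dim=-1) @ v

# (b) padding expressed as a boolean key mask, the form a flash-style kernel takes
out_mask = torch.softmax(logits.masked_fill(mask == 0, float("-inf")), dim=-1) @ v

assert torch.allclose(out_bias, out_mask, atol=1e-6)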
@@ -336,6 +357,7 @@ class MSAColumnAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
+        use_flash: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -353,7 +375,13 @@ class MSAColumnAttention(nn.Module):
         if mask is not None:
             mask = mask.transpose(-1, -2)
 
-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size, use_lma=use_lma)
+        m = self._msa_att(
+            m,
+            mask=mask,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            use_flash=use_flash,
+        )
 
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
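MSAColumnAttention attends over the sequence dimension, so the MSA and its mask are transposed before the _msa_att call and the MSA is transposed back afterwards; when use_flash is set, the transposed mask is the mask the underlying attention receives. A shape-only sketch with illustrative dimensions:

import torch

n_seq, n_res, c_in = 6, 10, 32
m = torch.randn(2, n_seq, n_res, c_in)   # [*, N_seq, N_res, C_in]
mask = torch.ones(2, n_seq, n_res)       # [*, N_seq, N_res]

# Transpose so that attention runs over the sequence dimension.
m_col = m.transpose(-2, -3)              # [*, N_res, N_seq, C_in]
mask_col = mask.transpose(-1, -2)        # [*, N_res, N_seq]
assert m_col.shape == (2, n_res, n_seq, c_in)
assert mask_col.shape == (2, n_res, n_seq)

# ... attention over N_seq happens here (the _msa_att call above) ...

# Transpose back to the original layout.
m_out = m_col.transpose(-2, -3)          # [*, N_seq, N_res, C_in]
assert m_out.shape == m.shape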