"pcdet/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "a38a580f9ccf96c116269280fa5f6f721aa8bbcd"
Commit 0d148a7d authored by Gustaf Ahdritz

Add FlashAttention support to msa.py

parent bcb0b70f
msa.py

@@ -89,12 +89,14 @@ class MSAAttention(nn.Module):
     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
-        biases: List[torch.Tensor],
+        biases: Optional[List[torch.Tensor]],
         chunk_size: int,
         use_memory_efficient_kernel: bool,
         use_lma: bool,
+        use_flash: bool,
+        flash_mask: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        def fn(m, biases):
+        def fn(m, biases, flash_mask):
             m = self.layer_norm_m(m)
             return self.mha(
                 q_x=m,
@@ -102,14 +104,23 @@ class MSAAttention(nn.Module):
                 biases=biases,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
                 use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=flash_mask,
             )
 
+        inputs = {"m": m}
+        if(biases is not None):
+            inputs["biases"] = biases
+        else:
+            fn = partial(fn, biases=None)
+        if(use_flash and flash_mask is not None):
+            inputs["flash_mask"] = flash_mask
+        else:
+            fn = partial(fn, flash_mask=None)
+
         return chunk_layer(
             fn,
-            {
-                "m": m,
-                "biases": biases,
-            },
+            inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2])
         )
@@ -175,6 +186,7 @@ class MSAAttention(nn.Module):
         m, mask_bias, z = self._prep_inputs(
             m, z, mask, inplace_safe=inplace_safe
         )
+        m = self.layer_norm_m(m)
         q, k, v = self.mha._prep_qkv(m, m)
 
         return m, q, k, v, mask_bias, z
@@ -210,6 +222,7 @@ class MSAAttention(nn.Module):
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
+        use_flash: bool = False,
         inplace_safe: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
@@ -237,6 +250,10 @@ class MSAAttention(nn.Module):
                 inplace_safe=inplace_safe,
             )
 
-        m, mask_bias, z = self._prep_inputs(
-            m, z, mask, inplace_safe=inplace_safe
-        )
+        if(use_flash):
+            assert z is None
+            biases = None
+        else:
+            m, mask_bias, z = self._prep_inputs(
+                m, z, mask, inplace_safe=inplace_safe
+            )
@@ -252,6 +269,8 @@ class MSAAttention(nn.Module):
                 chunk_size,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
                 use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=mask,
             )
         else:
             m = self.layer_norm_m(m)
@@ -261,6 +280,8 @@ class MSAAttention(nn.Module):
                 biases=biases,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
                 use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=mask,
             )
 
         return m
@@ -336,6 +357,7 @@ class MSAColumnAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
+        use_flash: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -353,7 +375,13 @@ class MSAColumnAttention(nn.Module):
         if mask is not None:
             mask = mask.transpose(-1, -2)
 
-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size, use_lma=use_lma)
+        m = self._msa_att(
+            m,
+            mask=mask,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            use_flash=use_flash,
+        )
 
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
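For orientation, a minimal sketch of how the new flag might be exercised from calling code. It assumes an OpenFold-style MSAColumnAttention module importable from openfold.model.msa and a FlashAttention-enabled environment; the constructor arguments and tensor sizes below are illustrative assumptions, not part of this commit.

import torch
from openfold.model.msa import MSAColumnAttention

# Illustrative sizes only; real values come from the model config.
c_in, c_hidden, no_heads = 256, 32, 8
n_seq, n_res = 128, 64

msa_att = MSAColumnAttention(c_in, c_hidden, no_heads)

m = torch.randn(1, n_seq, n_res, c_in)   # MSA embedding, [*, N_seq, N_res, C_in]
mask = torch.ones(1, n_seq, n_res)       # per-position mask, [*, N_seq, N_res]

# With use_flash=True no pair bias is supplied (the diff asserts z is None in
# MSAAttention.forward) and the mask is forwarded to the kernel as flash_mask.
out = msa_att(m, mask=mask, use_flash=True)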