Commit 0d99ae1f authored by silencealiang

add

parent c271aaae
Pipeline #2498 canceled with stages
@@ -456,6 +456,34 @@ class CoreAttention(MegatronModule):
         return context_layer
 
+class FlashSelfAttentionTorch(torch.nn.Module):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_func is not None, ('Triton version of FlashAttention is not installed.')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.attention_dropout = attention_dropout
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        if os.environ.get('USE_BSHD', None):
+            q, k, v = [rearrange(x, 's b h d -> b s h d').contiguous()
+                       for x in (q, k, v)]
+        else:
+            q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous()
+                       for x in (q, k, v)]
+        output = SDPA(q, k, v, is_causal=self.causal, dropout_p=self.attention_dropout, scale=self.softmax_scale)
+        if os.environ.get('USE_BSHD', None):
+            output = rearrange(output, 'b s h d -> s b (h d)').contiguous()
+        else:
+            output = rearrange(output, 'b h s d -> s b (h d)').contiguous()
+        return output
 
 class FlashSelfAttention(torch.nn.Module):
     """Implement the scaled dot product attention with softmax.
@@ -582,10 +610,11 @@ class ParallelAttention(MegatronModule):
         else:
             kv_projection_size = args.kv_channels * args.num_attention_heads
 
-        self.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton) \
+        self.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton or args.use_flash_attn_torch) \
             and attention_type == AttnType.self_attn \
             and self.attn_mask_type == AttnMaskType.causal
         self.use_flash_attn_triton = args.use_flash_attn_triton
+        self.use_flash_attn_torch = args.use_flash_attn_torch
         if self.use_flash_attn:
             if args.use_flash_attn_cutlass:
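The constructor now reads args.use_flash_attn_torch, so the fork presumably registers a matching command-line flag elsewhere; that change is not part of this commit's visible hunks. Purely as an illustration of how such a flag is usually wired up in Megatron-style argument parsing, a hypothetical registration could look like the sketch below. The flag spelling, group title, and function name are assumptions, not the actual code.

```python
# Hypothetical sketch only: the real flag definition is not shown in this commit.
def _add_flash_attn_backend_args(parser):
    group = parser.add_argument_group(title='attention backend')  # illustrative group title
    group.add_argument('--use-flash-attn-torch', action='store_true',
                       help='Select the torch scaled_dot_product_attention backend '
                            '(argparse exposes this as args.use_flash_attn_torch).')
    return parser
```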
@@ -658,6 +687,8 @@ class ParallelAttention(MegatronModule):
                 self.core_attention_flash = FlashSelfAttentionTriton(
                     causal=True, attention_dropout=args.attention_dropout
                 )
+            elif self.use_flash_attn_torch:
+                self.core_attention_flash = FlashSelfAttentionTorch(causal=True, attention_dropout=config.attention_dropout)
             elif self.use_flash_attn:
                 self.core_attention_flash = FlashSelfAttention(
                     causal=True, attention_dropout=config.attention_dropout
@@ -871,7 +902,7 @@ class ParallelAttention(MegatronModule):
             context_layer = self.core_attention(
                 query_layer, key_layer, value_layer, attention_mask)
         else:
-            if not self.use_flash_attn_triton:
+            if not self.use_flash_attn_triton and not self.use_flash_attn_torch:
                 query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                     for x in (query_layer, key_layer, value_layer)]
@@ -881,7 +912,7 @@ class ParallelAttention(MegatronModule):
             else:
                 context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
 
-            if not self.use_flash_attn_triton:
+            if not self.use_flash_attn_triton and not self.use_flash_attn_torch:
                 context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
 
         # =================
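These two guards widen because, like the Triton wrapper, FlashSelfAttentionTorch consumes (S, B, H, D) inputs and already returns the (S, B, H*D) hidden-state layout, so the caller-side rearranges would be redundant for it. A quick check of that shape contract, assuming the patched module and its dependencies (os, SDPA, rearrange, flash_attn_func) are importable; the import path is illustrative, not confirmed by this commit:

```python
import torch
from megatron.model.transformer import FlashSelfAttentionTorch  # hypothetical path

s, b, h, d = 128, 2, 16, 64
q = torch.randn(s, b, h, d, dtype=torch.float16, device='cuda')

attn = FlashSelfAttentionTorch(causal=True, attention_dropout=0.0)
out = attn(q, q.clone(), q.clone())
assert out.shape == (s, b, h * d)  # (S, B, H*D): no extra rearrange needed at the call site
```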
File mode changed from 100755 to 100644 (19 files)