Commit 2ed12105 authored by wuxk1

add torch fa support

parent 031a4157
@@ -456,6 +456,34 @@ class CoreAttention(MegatronModule):
        return context_layer

+class FlashSelfAttentionTorch(torch.nn.Module):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+        super().__init__()
+        assert SDPA is not None, ('torch scaled_dot_product_attention is not available.')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.attention_dropout = attention_dropout
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+        q, k, v: The tensors containing the query, key, and value. (S, B, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        # USE_BSHD feeds SDPA tensors in (B, S, H, D) layout (for backends that expect bshd);
+        # the default path uses SDPA's standard (B, H, S, D) layout.
+        if os.environ.get('USE_BSHD', None):
+            q, k, v = [rearrange(x, 's b h d -> b s h d').contiguous()
+                       for x in (q, k, v)]
+        else:
+            q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous()
+                       for x in (q, k, v)]
+        output = SDPA(q, k, v, is_causal=self.causal, dropout_p=self.attention_dropout, scale=self.softmax_scale)
+        # Restore Megatron's (S, B, H*D) hidden layout.
+        if os.environ.get('USE_BSHD', None):
+            output = rearrange(output, 'b s h d -> s b (h d)').contiguous()
+        else:
+            output = rearrange(output, 'b h s d -> s b (h d)').contiguous()
+        return output
+
class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
@@ -582,10 +610,11 @@ class ParallelAttention(MegatronModule):
        else:
            kv_projection_size = args.kv_channels * args.num_attention_heads

-        self.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton) \
+        self.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton or args.use_flash_attn_torch) \
            and attention_type == AttnType.self_attn \
            and self.attn_mask_type == AttnMaskType.causal
        self.use_flash_attn_triton = args.use_flash_attn_triton
+        self.use_flash_attn_torch = args.use_flash_attn_torch
        if self.use_flash_attn:
            if args.use_flash_attn_cutlass:
@@ -658,6 +687,8 @@ class ParallelAttention(MegatronModule):
                self.core_attention_flash = FlashSelfAttentionTriton(
                    causal=True, attention_dropout=args.attention_dropout
                )
+            elif self.use_flash_attn_torch:
+                self.core_attention_flash = FlashSelfAttentionTorch(causal=True, attention_dropout=config.attention_dropout)
            elif self.use_flash_attn:
                self.core_attention_flash = FlashSelfAttention(
                    causal=True, attention_dropout=config.attention_dropout
@@ -871,7 +902,7 @@ class ParallelAttention(MegatronModule):
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
        else:
-            if not self.use_flash_attn_triton:
+            if not self.use_flash_attn_triton and not self.use_flash_attn_torch:
                query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                                                       for x in (query_layer, key_layer, value_layer)]

@@ -881,7 +912,7 @@ class ParallelAttention(MegatronModule):
            else:
                context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)

-            if not self.use_flash_attn_triton:
+            if not self.use_flash_attn_triton and not self.use_flash_attn_torch:
                context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

        # =================
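Note (hypothetical check, not in the repo): the new path relies on SDPA's is_causal=True matching the explicitly masked softmax attention that the non-flash path computes. A small CPU/float32 cross-check of that equivalence:

# Compare torch SDPA with is_causal=True against an explicit causal masked-softmax
# reference; tolerances and shapes are illustrative.
import math
import torch

def reference_causal_attention(q, k, v):
    # q, k, v: (B, H, S, D)
    s = q.size(-2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    mask = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float('-inf'))
    return torch.softmax(scores, dim=-1) @ v

q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))
ref = reference_causal_attention(q, k, v)
sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(sdpa, ref, rtol=1e-4, atol=1e-5)
print('causal SDPA matches the reference attention')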
@@ -654,7 +654,7 @@ def validate_args(args, defaults={}):
        '--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'

    # FlashAttention
-    args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton
+    args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton or args.use_flash_attn_torch

    # Legacy RoPE arguments
    if args.use_rotary_position_embeddings:
@@ -1377,6 +1377,8 @@ def _add_training_args(parser):
    group.add_argument('--use-flash-attn-cutlass', action='store_true',
                       help='use FlashAttention implementation of attention. '
                       'https://arxiv.org/abs/2205.14135')
+    group.add_argument('--use-flash-attn-torch', action='store_true',
+                       help='use FlashAttention implementation of attention using torch.')
    group.add_argument('--use-flash-attn-triton', action='store_true',
                       help='use FlashAttention implementation of attention using Triton.')
    group.add_argument('--disable-bias-linear', action='store_false',
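Note (illustrative only): outside Megatron, the way the three backend flags fold into args.use_flash_attn in validate_args above can be reproduced with plain argparse; only the argument names are taken from the diff.

# Minimal, self-contained illustration of the flag handling added above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-flash-attn-cutlass', action='store_true')
parser.add_argument('--use-flash-attn-triton', action='store_true')
parser.add_argument('--use-flash-attn-torch', action='store_true')

args = parser.parse_args(['--use-flash-attn-torch'])
args.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton
                       or args.use_flash_attn_torch)
print(args.use_flash_attn)  # True: the torch backend also enables the flash path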
export TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1
export TORCHINDUCTOR_BENCHMARK_FUSION=1
export TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES=1
# export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_MAX_AUTOTUNE=1
#export FLASH_ATTENTION_PRINT_PARAM=1
export TORCHINDUCTOR_CACHE_DIR=./cache
# export USE_AOTRITON_FA=1
# export USE_BSHD=1  # use the FA bshd layout
# for unique kernel names
# export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
mpirun --allow-run-as-root -np 8 ./Llama_pretraining.sh localhost
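Note (optional pre-flight check, not part of the commit): before launching, one can verify that torch's SDPA can actually select a flash kernel on the target GPU. This assumes PyTorch 2.x, where torch.backends.cuda.sdp_kernel is a context manager (newer releases expose torch.nn.attention.sdpa_kernel instead); the call raises a RuntimeError if no flash kernel is available.

# Force the flash backend only and run one causal SDPA call as a smoke test.
import torch

q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device='cuda')  # (B, H, S, D)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                    enable_mem_efficient=False):
    out = torch.nn.functional.scaled_dot_product_attention(q, q, q, is_causal=True)
print('flash SDPA backend is available:', out.shape)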