Commit 899c20e8 authored by wangxj

Optimize the fixed-length flash attention (fa) interface in legacy

parent 9dabea91
Pipeline #2564 passed with stage
@@ -56,7 +56,7 @@ export cache_size_limit=64
# CHECKPOINT_PATH=./Llama-2-7b-hf-to-meg-tp1-pp2 #CHECKPOINT_PATH=./tmp_7b #
SAVE_PATH=./tmp_7b
TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path>
DATA_PATH="/public/home/wangxj/Downloads/datasets/oscar-1GB-head/oscar-1GB_head-llama2_text_document" #<Specify path and file prefix>_text_document
DATA_PATH="/public/home/wangxj/Downloads/datasets/oscar-1GB/oscar-1GB-llama2_text_document" #<Specify path and file prefix>_text_document
# DATA_PATH="/data/datasets/oscar-1GB-head/oscar-1GB_head-llama2_text_document" #<Specify path and file prefix>_text_document
GPT_MODEL_ARGS=(
@@ -83,6 +83,8 @@ TRAINING_ARGS=(
--micro-batch-size 1
--global-batch-size 256 #256 #240 #60 #512 #64
--train-iters 50
--eval-interval 10
--eval-iters 3
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.95
@@ -125,6 +127,7 @@ MODEL_PARALLEL_ARGS=(
--sequence-parallel
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 2
# --context-parallel-size 2
# --num-layers-per-virtual-pipeline-stage 4
# --microbatch-group-size-per-virtual-pipeline-stage 1
# --no-overlap-p2p-communication # after enabling
@@ -143,10 +146,8 @@ EVAL_AND_LOGGING_ARGS=(
--log-interval 1
--log-throughput
--save-interval 1000
--eval-interval 1000
--save $SAVE_PATH
--load $SAVE_PATH
--eval-iters 10
--tensorboard-dir $TENSORBOARD_LOGS_PATH
)
@@ -47,6 +47,11 @@ try:
except ImportError:
    rearrange = None

try:  # use the fixed-length flash attention interface
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
@@ -510,6 +515,41 @@ class FlashSelfAttention(torch.nn.Module):
        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output


class FlashFixedSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax, using the
    fixed-length (non-varlen) FlashAttention kernel.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_func is not None, ('Please install FlashAttention first, '
                                             'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.flash_attn_func = flash_attn_func

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensors containing the query, key, and value,
                each of shape [b, s, num_heads, head_dim].
        """
        assert all(i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))
        assert all(i.is_cuda for i in (q, k, v))
        # Fixed-length path: no unpadding or cu_seqlens bookkeeping; the kernel
        # consumes the batched [b, s, num_heads, head_dim] layout directly.
        # Match FlashSelfAttention: only apply dropout in training mode.
        output = self.flash_attn_func(q, k, v,
                                      dropout_p=self.dropout_p if self.training else 0.0,
                                      softmax_scale=self.softmax_scale, causal=self.causal)
        # output: [b, s, num_heads, head_dim]
        return output
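
For reference, a minimal usage sketch of the new fixed-length path. The tensor sizes are illustrative and not taken from the commit; it assumes flash-attn 2.x, a CUDA device, and that the FlashFixedSelfAttention class above is in scope.

```python
# Minimal usage sketch of FlashFixedSelfAttention (illustrative sizes, not from the commit).
import torch

attn = FlashFixedSelfAttention(causal=True, attention_dropout=0.0)

b, s, num_heads, head_dim = 2, 4096, 32, 128   # hypothetical shapes
q = torch.randn(b, s, num_heads, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn(q, k, v)                            # same layout as the inputs
assert out.shape == (b, s, num_heads, head_dim)
```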
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
@@ -605,7 +645,10 @@ class ParallelAttention(MegatronModule):
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        if self.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(
            # self.core_attention_flash = FlashSelfAttention(
            #     causal=True, attention_dropout=config.attention_dropout
            # )
            self.core_attention_flash = FlashFixedSelfAttention(
                causal=True, attention_dropout=config.attention_dropout
            )
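
The practical difference from the varlen path used by FlashSelfAttention: flash_attn_unpadded_func works on a flattened [total_tokens, num_heads, head_dim] tensor plus cumulative sequence lengths, while flash_attn_func consumes the batched [b, s, num_heads, head_dim] tensors directly. A contrast sketch of that bookkeeping (illustrative shapes; the varlen call itself is omitted because its exact signature varies across flash-attn releases):

```python
# Contrast sketch: data layout for the varlen vs. fixed-length interfaces.
import torch
from flash_attn import flash_attn_func

b, s, num_heads, head_dim = 2, 4096, 32, 128
q = torch.randn(b, s, num_heads, head_dim, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Varlen layout (illustration only): flatten batch and sequence, track per-sample offsets.
q_unpadded = q.reshape(b * s, num_heads, head_dim)            # [total_tokens, h, d]
cu_seqlens = torch.arange(0, (b + 1) * s, step=s,
                          dtype=torch.int32, device='cuda')   # [0, s, 2*s]

# Fixed-length layout: pass the batched tensors as-is.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)    # [b, s, h, d]
```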
@@ -137,6 +137,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
rope_scaling=args.use_rope_scaling
)
print_rank_0(model)
# model = torch.compile(model, mode="max-autotune-no-cudagraphs")
return model
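
The torch.compile hook is left commented out. If it were made opt-in, a minimal sketch could look like the following; the USE_TORCH_COMPILE environment flag is an illustrative assumption, not part of the commit. The "max-autotune-no-cudagraphs" mode keeps kernel autotuning but skips CUDA graph capture, which is generally the safer choice for a training loop that manages its own streams and buffers.

```python
# Hedged sketch: opt-in torch.compile for the returned model (requires PyTorch 2.x).
# The USE_TORCH_COMPILE flag is hypothetical and not defined by this commit.
import os
import torch

def maybe_compile(model):
    if os.getenv("USE_TORCH_COMPILE", "0") == "1":
        model = torch.compile(model, mode="max-autotune-no-cudagraphs")
    return model
```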