Commit 8a13fabe authored by wxj

Merge branch 'main' into 'main'

Add optimizations; add Qwen and Llama3

See merge request !5
parents f5ca0d94 425a2473
Pipeline #2076 failed with stages in 0 seconds
......@@ -28,12 +28,12 @@ TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path>
DATA_PATH="/datasets/oscar-1GB-llama_text_document" #<Specify path and file prefix>_text_document
GPT_MODEL_ARGS=(
--num-layers 6
--hidden-size 1024
--ffn-hidden-size 2048
--num-attention-heads 16
--num-layers 36
--hidden-size 4096
--ffn-hidden-size 11008
--num-attention-heads 32
--seq-length 4096 #4096
--max-position-embeddings 32768
--max-position-embeddings 4096
)
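For orientation, the new GPT_MODEL_ARGS describe a 7B-class, LLaMA-style model. A back-of-the-envelope parameter count (a sketch only: it assumes no biases, a SwiGLU MLP with three projections, untied embeddings, and a placeholder vocabulary size of 32000 that this script does not actually set):

h, layers, ffn, vocab = 4096, 36, 11008, 32000
per_layer = 4 * h * h + 3 * h * ffn + 2 * h        # attention projections + SwiGLU MLP + two RMSNorms
total = layers * per_layer + 2 * vocab * h + h     # plus input/output embeddings and the final norm
print(f"~{total / 1e9:.2f}B parameters")           # ~7.55B under these assumptions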
# export NVTE_FLASH_ATTN=1 # use the cutlass path
......@@ -69,7 +69,10 @@ TRAINING_ARGS=(
--lr-decay-style cosine
--min-lr 3.0e-6
--lr-warmup-iters 1
--use-flash-attn-triton
)
# --use-flash-attn-ck
# --use-flash-attn-triton
MODEL_PARALLEL_ARGS=(
--sequence-parallel
......
......@@ -2,6 +2,8 @@
import torch
from torch import nn
import torch._dynamo
torch._dynamo.config.suppress_errors = True
class RMSNorm(torch.nn.Module):
......@@ -24,9 +26,11 @@ class RMSNorm(torch.nn.Module):
setattr(self.weight, 'sequence_parallel', sequence_parallel)
@torch.compile(mode="max-autotune-no-cudagraphs")
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
@torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
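As a sanity check on the math being compiled here, a minimal standalone sketch of RMSNorm, y = x * rsqrt(mean(x^2, dim=-1) + eps) * w, compared against its torch.compile'd version (names and shapes are illustrative, not from this MR):

import torch

eps = 1e-6
x = torch.randn(4, 1024)
w = torch.ones(1024)

def rms_norm_ref(t):
    return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps) * w

compiled = torch.compile(rms_norm_ref, mode="max-autotune-no-cudagraphs")
torch.testing.assert_close(compiled(x), rms_norm_ref(x))   # compiled and eager results agree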
......@@ -40,6 +40,9 @@ from megatron.legacy.model.utils import (
)
from megatron.training import get_args, get_timers
import torch._dynamo
torch._dynamo.config.suppress_errors = True
from .module import MegatronModule
try:
......@@ -57,6 +60,10 @@ except ImportError:
except ImportError:
flash_attn_unpadded_func = None
try:
from flash_attn.flash_attn_triton import flash_attn_func
except ImportError:
flash_attn_func = None
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
......@@ -133,6 +140,7 @@ class ParallelMLP(MegatronModule):
elif args.onnx_safe:
self.activation_func = erf_gelu
elif args.swiglu:
@torch.compile(mode="max-autotune-no-cudagraphs")
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
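For reference, a small shape sketch of the swiglu defined above (sizes are illustrative and assume the preceding dense layer produces a fused [.., 2 * ffn_hidden_size] activation):

import torch
import torch.nn.functional as F

x = torch.randn(8, 1, 2 * 11008)      # [s, b, 2 * ffn_hidden_size]
a, b = torch.chunk(x, 2, dim=-1)
y = F.silu(a) * b                     # gate * value, as in the swiglu above
print(y.shape)                        # torch.Size([8, 1, 11008])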
......@@ -157,6 +165,7 @@ class ParallelMLP(MegatronModule):
is_expert=is_expert,
)
@torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, hidden_states):
# [s, b, 4hp]
......@@ -468,6 +477,10 @@ class FlashSelfAttention(torch.nn.Module):
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
# Use FlashAttention-2 when args.use_flash_attn_ck is True
args = get_args()
self.flash_attn_func = flash_attn_unpadded_func
def forward(self, q, k, v):
"""Implements the multihead softmax attention.
Arguments
......@@ -509,6 +522,38 @@ class FlashSelfAttention(torch.nn.Module):
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
class FlashSelfAttentionTriton(torch.nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
device=None, dtype=None):
super().__init__()
assert flash_attn_func is not None, ('Triton version of FlashAttention is not installed.')
assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, q, k, v):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensors containing the query, key, and value, laid out as (S, B, H, D)
"""
assert q.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda
q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous()
for x in (q, k, v)]
output = flash_attn_func(q, k, v, self.causal)
output = rearrange(output, 'b s h d -> h b (s d)').contiguous()
return output
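The axis names in the two rearranges above are easy to misread, so here is a standalone layout check. It assumes the Triton kernel returns its output in the same (b, h, s, d) layout it was given; under that assumption the final pattern, although written as 'b s h d -> h b (s d)', maps (b, h, s, d) to [sq, b, h*d], i.e. the same layout the non-Triton path produces:

import torch
from einops import rearrange

s, b, h, d = 8, 2, 4, 16
x = torch.randn(s, b, h, d)                          # layout handed to the Triton path by ParallelAttention
y = rearrange(x, 's b h d -> b h s d').contiguous()  # what forward() feeds the kernel
# Stand-in for the kernel output, assumed to keep the (b, h, s, d) layout of its inputs.
out = rearrange(y, 'b s h d -> h b (s d)')           # the pattern used in forward() above
print(out.shape)                                     # torch.Size([8, 2, 64]) == [sq, b, h * d]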
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
......@@ -537,13 +582,19 @@ class ParallelAttention(MegatronModule):
else:
kv_projection_size = args.kv_channels * args.num_attention_heads
self.use_flash_attn = args.use_flash_attn \
self.use_flash_attn = (args.use_flash_attn_ck or args.use_flash_attn_triton) \
and attention_type == AttnType.self_attn \
and self.attn_mask_type == AttnMaskType.causal
self.use_flash_attn_triton = args.use_flash_attn_triton
if self.use_flash_attn:
if flash_attn_unpadded_func is None:
raise ImportError('FlashAttention is not installed, please install with '
'pip install flash-attn')
if args.use_flash_attn_ck:
if flash_attn_unpadded_func is None:
raise ImportError('FlashAttention is not installed, please install with '
'pip install flash-attn')
if args.use_flash_attn_triton:
assert flash_attn_func is not None, 'Cannot import the Triton version of FlashAttention'
assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
'self-attention for now')
assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
......@@ -603,7 +654,11 @@ class ParallelAttention(MegatronModule):
self.attn_mask_type)
self.checkpoint_core_attention = config.recompute_granularity == 'selective'
if self.use_flash_attn:
if self.use_flash_attn_triton:
self.core_attention_flash = FlashSelfAttentionTriton(
causal=True, attention_dropout=args.attention_dropout
)
elif self.use_flash_attn:
self.core_attention_flash = FlashSelfAttention(
causal=True, attention_dropout=config.attention_dropout
)
......@@ -711,7 +766,7 @@ class ParallelAttention(MegatronModule):
dim=3)
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -
query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
query_layer = query_layer.contiguous().view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
......@@ -816,14 +871,18 @@ class ParallelAttention(MegatronModule):
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask)
else:
q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
if not self.use_flash_attn_triton:
query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
for x in (query_layer, key_layer, value_layer)]
if not self.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
context_layer = self.core_attention_flash(q, k, v)
context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
else:
context_layer = self.core_attention_flash(q, k, v)
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
if not self.use_flash_attn_triton:
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
# =================
# Output. [sq, b, h]
......
......@@ -9,6 +9,8 @@ import torch
from megatron.training import get_args
from megatron.legacy.model import LayerNorm, RMSNorm
from megatron.core.jit import jit_fuser
import torch._dynamo
torch._dynamo.config.suppress_errors = True
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
......@@ -58,7 +60,7 @@ def openai_gelu(x):
def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))
@torch.compile(mode="max-autotune-no-cudagraphs")
def get_norm(config):
args = get_args()
if args.normalization == "LayerNorm":
......
......@@ -642,6 +642,9 @@ def validate_args(args, defaults={}):
assert not args.use_legacy_models, \
'--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'
# FlashAttention
args.use_flash_attn = args.use_flash_attn_ck or args.use_flash_attn_triton
# Legacy RoPE arguments
if args.use_rotary_position_embeddings:
args.position_embedding_type = 'rope'
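A small sketch of how the two new switches combine (mirroring the validate_args change above and the two add_argument calls below; the parser here is a toy stand-in, not Megatron's real one):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-flash-attn-ck', action='store_true')
parser.add_argument('--use-flash-attn-triton', action='store_true')

args = parser.parse_args(['--use-flash-attn-triton'])
args.use_flash_attn = args.use_flash_attn_ck or args.use_flash_attn_triton      # as in validate_args
print(args.use_flash_attn_ck, args.use_flash_attn_triton, args.use_flash_attn)  # False True True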
......@@ -1355,9 +1358,11 @@ def _add_training_args(parser):
group.add_argument('--cross-entropy-loss-fusion', action='store_true',
help='Enabled fusion of cross entropy loss calculation.',
dest='cross_entropy_loss_fusion')
group.add_argument('--use-flash-attn', action='store_true',
group.add_argument('--use-flash-attn-ck', action='store_true',
help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
group.add_argument('--use-flash-attn-triton', action='store_true',
help='use FlashAttention implementation of attention using Triton.')
group.add_argument('--disable-bias-linear', action='store_false',
help='Disable bias in the linear layers',
dest='add_bias_linear')
......@@ -1824,6 +1829,8 @@ def _add_tokenizer_args(parser):
'GPTSentencePieceTokenizer',
'HuggingFaceTokenizer',
'Llama2Tokenizer',
'Llama3Tokenizer',
'QwenTokenizer',
'TikTokenizer',
'MultimodalTokenizer',
'NullTokenizer'],
......
......@@ -15,6 +15,7 @@ from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
from transformers import Qwen2Tokenizer
def build_tokenizer(args, **kwargs):
......@@ -50,6 +51,11 @@ def build_tokenizer(args, **kwargs):
elif args.tokenizer_type == 'Llama2Tokenizer':
assert args.tokenizer_model is not None
tokenizer = _Llama2Tokenizer(args.tokenizer_model)
elif args.tokenizer_type == 'Llama3Tokenizer':
assert args.tokenizer_model is not None
tokenizer = _Llama3Tokenizer(args.tokenizer_model)
elif args.tokenizer_type == 'QwenTokenizer':
tokenizer = _Qwen2Tokenizer(args.vocab_file, args.merge_file)
elif args.tokenizer_type == 'TikTokenizer':
assert args.tokenizer_model is not None
assert args.tiktoken_pattern is not None
......@@ -606,6 +612,96 @@ class _Llama2Tokenizer(_SentencePieceTokenizer):
return None
class _Llama3Tokenizer(MegatronTokenizer):
"""tiktokenTokenizer-Megatron llama3 改写"""
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py
def __init__(self, model_file):
super().__init__(model_file)
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
tokenizer_path=model_file
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range (5, 256 - 5)]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
self.tokenizer = tiktoken.Encoding(tokenizer_path,
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
mergeable_ranks=mergeable_ranks,
special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
)
self.eod_id = self.tokenizer.encode("<|end_of_text|>", allowed_special="all")[0]
@property
def vocab_size(self):
return self.tokenizer.n_vocab
@property
def vocab(self):
return self.tokenizer.encode
@property
def inv_vocab(self):
return self.tokenizer.encode
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.eod_id
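A minimal usage sketch of the new tokenizer (the path is a placeholder and must point at a Llama 3 tiktoken model file, which is not part of this MR):

tok = _Llama3Tokenizer('/path/to/llama3/tokenizer.model')
ids = tok.tokenize("hello world")
print(tok.detokenize(ids))          # "hello world", with detokenize mapped to tiktoken decode
print(tok.vocab_size, tok.eod)      # 128256 for the Llama 3 model file, plus the id of <|end_of_text|>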
class _Qwen2Tokenizer(MegatronTokenizer):
def __init__(self, vocab_file, merge_file, extra_vocab_size=0):
super().__init__(vocab_file, merge_file)
self.tokenizer = Qwen2Tokenizer(vocab_file, merge_file)
self.extra_vocab_size = extra_vocab_size
self.tokenizer.add_special_tokens(special_tokens_dict=dict(pad_token="<|extra_0|>"))
@property
def vocab_size(self):
return len(self.tokenizer.encoder) + self.extra_vocab_size
@property
def vocab(self):
return self.tokenizer.encoder
@property
def inv_vocab(self):
return self.tokenizer.decoder
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.tokenizer.eos_token_id
@property
def eos_token(self):
return self.tokenizer.eos_token
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
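And a corresponding sketch for the Qwen path (vocab.json and merges.txt are the usual Hugging Face Qwen2 tokenizer artifacts; the paths are placeholders, not from this MR):

tok = _Qwen2Tokenizer('/path/to/qwen2/vocab.json', '/path/to/qwen2/merges.txt')
ids = tok.tokenize("hello world")
print(tok.detokenize(ids))           # "hello world"
print(tok.eod, tok.pad_token_id)     # eos id and the id of the added "<|extra_0|>" pad token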
def reload_mergeable_ranks(path: str, max_vocab: Optional[int] = None) -> Dict[bytes, int]:
"""
Reload our tokenizer JSON file and convert it to Tiktoken format.
......