Commit 1f376414 authored by dongcl

merge megatron_v0.11.0

parents eefee67c c140c914
@@ -165,13 +165,14 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
-        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
+        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper

         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
                                     vocab_parallel_embedding_forward)
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
-                                    vocab_parallel_embedding_init)
+                                    vocab_parallel_embedding_init_wrapper,
+                                    apply_wrapper=True)

         # VocabParallelCrossEntropy
         MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',
@@ -233,13 +234,20 @@ class LegacyAdaptation(MegatronAdaptationABC):
         self.patch_legacy_models()

     def patch_legacy_models(self):
-        from ..legacy.model.transformer import ParallelMLPPatch, ParallelAttentionPatch
+        from ..legacy.model.transformer import (
+            parallel_mlp_init_wrapper,
+            ParallelAttentionPatch,
+            parallel_attention_init_wrapper
+        )
         from ..legacy.model.utils import get_norm

         # ParallecMLP
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
-                                    ParallelMLPPatch.__init__)
+                                    parallel_mlp_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
+                                    parallel_attention_init_wrapper,
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.forward',
                                     ParallelAttentionPatch.forward)
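Note on the registrations above: the merge moves several patches from registering full replacement methods to registering wrapper factories together with apply_wrapper=True. The MegatronAdaptation class itself is not part of this diff, so the following is only a minimal sketch of what wrapper-style patching could look like, under the assumption that apply_wrapper=True means "call the registered factory with the original attribute and install its return value"; the helper name apply_patch is hypothetical.

import importlib


def apply_patch(dotted_path, replacement, apply_wrapper=False):
    # Hypothetical helper, not the real MegatronAdaptation.register:
    # resolve 'pkg.module.Class.attr' and patch that attribute.
    module_path, owner_name, attr_name = dotted_path.rsplit(".", 2)
    owner = getattr(importlib.import_module(module_path), owner_name)
    if apply_wrapper:
        # The registered object is a decorator such as
        # vocab_parallel_embedding_init_wrapper: it receives the original
        # method and returns the wrapped version.
        setattr(owner, attr_name, replacement(getattr(owner, attr_name)))
    else:
        # Plain replacement, as with vocab_parallel_embedding_forward.
        setattr(owner, attr_name, replacement)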
...
 import os
+import copy
 import socket
 import warnings
 from functools import wraps
 from typing import Callable, List, Optional

-try:
-    import flux
-except ImportError:
-    raise ImportError("flux is NOT installed")
+if int(os.getenv("USE_FLUX_OVERLAP", "0")):
+    try:
+        import flux
+        from dcu_megatron.core.utils import is_flux_min_version
+    except ImportError:
+        raise ImportError("flux is NOT installed")

 import torch
 import torch.nn.functional as F
@@ -53,6 +56,7 @@ from megatron.core.tensor_parallel.layers import (
 from dcu_megatron.core.utils import is_flux_min_version

 _grad_accum_fusion_available = True
 try:
     import fused_weight_gradient_mlp_cuda
@@ -60,64 +64,29 @@ except ImportError:
     _grad_accum_fusion_available = False


-def vocab_parallel_embedding_init(
-    self,
-    num_embeddings: int,
-    embedding_dim: int,
-    *,
-    init_method: Callable,
-    reduce_scatter_embeddings: bool = False,
-    config: ModelParallelConfig,
-    skip_weight_param_allocation: bool = False
-):
-    super(VocabParallelEmbedding, self).__init__()
-    # Keep the input dimensions.
-    self.num_embeddings = num_embeddings
-    self.embedding_dim = embedding_dim
-    self.reduce_scatter_embeddings = reduce_scatter_embeddings
-    self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
-    # Divide the weight matrix along the vocaburaly dimension.
-    (self.vocab_start_index, self.vocab_end_index) = (
-        VocabUtility.vocab_range_from_global_vocab_size(
-            self.num_embeddings,
-            get_tensor_model_parallel_rank(),
-            self.tensor_model_parallel_size,
-        )
-    )
-    self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
-    self.deterministic_mode = config.deterministic_mode
-
-    # Allocate weights and initialize.
-    if not skip_weight_param_allocation:
-        if config.use_cpu_initialization:
-            self.weight = Parameter(
-                torch.empty(
-                    self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
-                )
-            )
-            if config.perform_initialization:
-                _initialize_affine_weight_cpu(
-                    self.weight,
-                    self.num_embeddings,
-                    self.embedding_dim,
-                    self.num_embeddings_per_partition,
-                    0,
-                    init_method,
-                    params_dtype=config.params_dtype,
-                )
-        else:
-            self.weight = Parameter(
-                torch.empty(
-                    self.num_embeddings_per_partition,
-                    self.embedding_dim,
-                    device=torch.cuda.current_device(),
-                    dtype=config.params_dtype,
-                )
-            )
-            if config.perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
-    else:
-        self.weight = None
+def vocab_parallel_embedding_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self,
+                *args,
+                skip_weight_param_allocation: bool = False,
+                **kwargs
+                ):
+        if (
+            skip_weight_param_allocation
+            and "config" in kwargs
+            and hasattr(kwargs["config"], "perform_initialization")
+        ):
+            config = copy.deepcopy(kwargs["config"])
+            config.perform_initialization = False
+            kwargs["config"] = config
+
+        fn(self, *args, **kwargs)

+        if skip_weight_param_allocation:
+            self.weight = None
+
+    return wrapper


 @torch.compile(mode='max-autotune-no-cudagraphs')
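For context, vocab_parallel_embedding_init_wrapper above only adjusts behaviour around the stock VocabParallelEmbedding.__init__: when skip_weight_param_allocation is requested it deep-copies the config with perform_initialization disabled (so the caller's config object is not mutated) and drops the locally allocated weight afterwards. The following is a self-contained toy illustration of the same pattern with no Megatron dependencies; ToyEmbedding is invented for the example.

import copy
from functools import wraps
from types import SimpleNamespace


def init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, skip_weight_param_allocation=False, **kwargs):
        if skip_weight_param_allocation and "config" in kwargs:
            config = copy.deepcopy(kwargs["config"])  # do not mutate the caller's config
            config.perform_initialization = False
            kwargs["config"] = config
        fn(self, *args, **kwargs)
        if skip_weight_param_allocation:
            self.weight = None  # the weight is expected to be supplied externally
    return wrapper


class ToyEmbedding:
    def __init__(self, dim, *, config):
        self.weight = [0.0] * dim if config.perform_initialization else None


ToyEmbedding.__init__ = init_wrapper(ToyEmbedding.__init__)
emb = ToyEmbedding(4, config=SimpleNamespace(perform_initialization=True),
                   skip_weight_param_allocation=True)
assert emb.weight is None  # no weight allocated; the original config is unchanged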
...
@@ -16,6 +16,7 @@ from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross
 from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module

 from ...tensor_parallel.random import CheckpointWithoutOutput
+from ...tensor_parallel import FluxColumnParallelLinear


 @dataclass
...
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 import torch
-from torch import nn
-import lightop
-
-class RMSNorm(torch.nn.Module):
-
-    def __init__(self,
-                 dim: int,
-                 eps: float = 1e-6,
-                 sequence_parallel: bool = False,
-                 config: dict = None):
-        """RMS Normaliation module
-
-        Args:
-            dim (int): The width of input, i.e. hidden size
-            eps (float): epsilon to use for the norm, default to 1e-6
-            sequence_parallel (bool): Set to true if sequence parallelism is being used,
-              this marks the weights as needing to be allreduced.
-        """
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-        setattr(self.weight, 'sequence_parallel', sequence_parallel)
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    @torch.compile(mode="max-autotune-no-cudagraphs")
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-import torch
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Optional
+import lightop  # rmsnorm_forward, rmsnorm_backward
 from functools import partial
 from megatron.core.utils import is_torch_min_version
...
 import torch
 import torch.nn.functional as F
+from functools import wraps

 from megatron.training import get_args
 from megatron.core import tensor_parallel
 from megatron.legacy.model.enums import AttnType
 from megatron.core.models.common.embeddings import apply_rotary_pos_emb
 from megatron.legacy.model.module import MegatronModule
-from megatron.legacy.model.transformer import ParallelMLP
-from megatron.legacy.model.utils import (
-    erf_gelu,
-    openai_gelu,
-)

 try:
     from einops import rearrange
 except ImportError:
     rearrange = None

-class ParallelMLPPatch(MegatronModule):
-    """MLP.
-
-    MLP will take the input with h hidden state, project it to 4*h
-    hidden dimension, perform nonlinear transformation, and project the
-    state back into h hidden dimension.
-    """
-
-    def __init__(self, config, is_expert=False):
-        super(ParallelMLP, self).__init__()
-        args = get_args()
-
-        self.add_bias = config.add_bias_linear
-
-        ffn_hidden_size = config.ffn_hidden_size
-        if config.gated_linear_unit:
-            ffn_hidden_size *= 2
-
-        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
-        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
-            config.hidden_size,
-            ffn_hidden_size,
-            config=config,
-            init_method=config.init_method,
-            bias=self.add_bias,
-            gather_output=False,
-            skip_bias_add=True,
-            is_expert=is_expert,
-        )
-
-        self.bias_gelu_fusion = False
-        self.activation_func = None
-        self.swiglu = args.swiglu
-
-        if args.openai_gelu:
-            self.activation_func = openai_gelu
-        elif args.onnx_safe:
-            self.activation_func = erf_gelu
-        elif args.swiglu:
+try:  # use fixed-length flash attention
+    from flash_attn import flash_attn_func
+except ImportError:
+    flash_attn_func = None
+
+try:
+    from einops import rearrange
+except ImportError:
+    rearrange = None
+
+
+def parallel_mlp_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+        args = get_args()
+        if args.swiglu:
             @torch.compile(mode="max-autotune-no-cudagraphs")
             def swiglu(x):
                 x = torch.chunk(x, 2, dim=-1)
                 return F.silu(x[0]) * x[1]
             self.activation_func = swiglu
-        elif args.squared_relu:
-            def squared_relu(x):
-                return torch.pow(F.relu(x), 2)
-            self.activation_func = squared_relu
-        else:
-            self.bias_gelu_fusion = args.bias_gelu_fusion
-            self.activation_func = F.gelu
-
-        # Project back to h.
-        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
-            config.ffn_hidden_size,
-            config.hidden_size,
-            config=config,
-            init_method=config.output_layer_init_method,
-            bias=self.add_bias,
-            skip_bias_add=True,
-            input_is_parallel=True,
-            is_expert=is_expert,
-        )
+    return wrapper
+
+
+class FlashFixedSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_func is not None, ('Please install FlashAttention first, '
+                                             'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+        self.flash_attn_func = flash_attn_func
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
+        assert all((i.is_cuda for i in (q, k, v)))
+        output = self.flash_attn_func(q, k, v, dropout_p=self.dropout_p, softmax_scale=self.softmax_scale, causal=self.causal)
+        # [b,s,a,dim]
+        return output
+
+
+def parallel_attention_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+        if self.use_flash_attn:
+            self.core_attention_flash = FlashFixedSelfAttention(
+                causal=True, attention_dropout=self.config.attention_dropout
+            )
+    return wrapper


 class ParallelAttentionPatch(MegatronModule):
@@ -87,6 +97,7 @@ class ParallelAttentionPatch(MegatronModule):
     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """

     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, inference_params=None,
                 rotary_pos_emb=None):
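The FlashFixedSelfAttention module introduced in this file is a thin shim over flash_attn_func for fixed-length (non-varlen) batches. A small smoke test of the kind below exercises it; it assumes a CUDA device with flash-attn and einops installed, and the import path dcu_megatron.legacy.model.transformer is inferred from the relative imports elsewhere in this diff.

import torch
from dcu_megatron.legacy.model.transformer import FlashFixedSelfAttention  # path inferred from this diff


def smoke_test():
    # flash-attn expects (batch, seqlen, num_heads, head_dim) tensors in fp16/bf16 on the GPU.
    b, s, h, d = 2, 128, 8, 64
    q = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    attn = FlashFixedSelfAttention(causal=True, attention_dropout=0.0)
    out = attn(q, k, v)  # output keeps the (b, s, h, d) layout
    assert out.shape == (b, s, h, d)


if __name__ == "__main__":
    smoke_test()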
...
 from megatron.training import get_args
-from megatron.legacy.model import LayerNorm
-from .rms_norm import RMSNorm, LightopRMSNorm
+from megatron.legacy.model import LayerNorm, RMSNorm
+from .rms_norm import LightopRMSNorm


 def get_norm(config):
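After this change get_norm resolves the eager RMSNorm from megatron.legacy.model and keeps only the lightop-backed LightopRMSNorm from the local module. The body of get_norm is outside the shown hunk; the sketch below only illustrates the usual dispatch shape of Megatron's legacy utils. args.normalization follows upstream convention, while the use_lightop_rmsnorm flag, the LightopRMSNorm constructor signature, and the absolute import path are assumptions made for illustration.

from megatron.training import get_args
from megatron.legacy.model import LayerNorm, RMSNorm
from dcu_megatron.legacy.model.rms_norm import LightopRMSNorm  # assumed absolute form of the relative import above


def get_norm_sketch(config):
    # Hedged sketch only; the real get_norm in this repo may differ.
    args = get_args()
    if args.normalization == "LayerNorm":
        return LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
    if args.normalization == "RMSNorm":
        # Hypothetical switch between the eager and the lightop-backed kernel.
        norm_cls = LightopRMSNorm if getattr(args, "use_lightop_rmsnorm", False) else RMSNorm
        return norm_cls(config.hidden_size,
                        eps=config.layernorm_epsilon,
                        sequence_parallel=config.sequence_parallel)
    raise ValueError(f"Unsupported normalization: {args.normalization}")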
...