"utils/tests_fetcher.py" did not exist on "ee519cfef5d274f7d0c67270674523833083640d"
Commit 763941b5 authored by dongcl

modify ParallelAttention

parent bf323343
@@ -141,9 +141,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
-        #                             apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
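
Note: MegatronAdaptation.register is project-specific, so the sketch below is only an illustration of what an apply_wrapper=True registration with a torch.compile wrapper could amount to, assuming the registered callable is applied to the original function at patch time. The helper name apply_patch is hypothetical, not from the repo.

# Hypothetical sketch, not the real MegatronAdaptation: with apply_wrapper=True the
# registered callable is treated as a wrapper factory and applied to the original object.
import importlib

def apply_patch(dotted_path, patch, apply_wrapper=False):
    # Handles only module-level functions; class attributes like ParallelMLP.__init__
    # would need extra resolution steps.
    module_path, attr = dotted_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    original = getattr(module, attr)
    # With apply_wrapper=True the patch wraps the original, e.g. torch.compile(original).
    setattr(module, attr, patch(original) if apply_wrapper else patch)

# The permute registration above would then boil down to roughly:
# apply_patch('megatron.core.transformer.moe.moe_utils.permute',
#             torch.compile(mode='max-autotune-no-cudagraphs'),
#             apply_wrapper=True)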
@@ -233,13 +233,21 @@ class LegacyAdaptation(MegatronAdaptationABC):
         self.patch_legacy_models()
 
     def patch_legacy_models(self):
-        from ..legacy.model.transformer import ParallelMLPPatch, ParallelAttentionPatch
+        from ..legacy.model.transformer import (
+            parallel_mlp_init_wrapper,
+            ParallelAttentionPatch,
+            parallel_attention_init_wrapper
+        )
         from ..legacy.model.utils import get_norm
 
         # ParallecMLP
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
-                                    ParallelMLPPatch.__init__)
+                                    parallel_mlp_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
+                                    parallel_attention_init_wrapper,
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.forward',
                                     ParallelAttentionPatch.forward)
...
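
The two new __init__ registrations rely on the wrapper-factory pattern defined in the transformer patch below: with apply_wrapper=True the original method is presumably passed into the registered factory, and the returned closure becomes the patched __init__. A minimal, self-contained illustration of that pattern (post_init_wrapper and Demo are illustrative names, not from the repo):

from functools import wraps

def post_init_wrapper(fn):
    """Run the original __init__ first, then apply extra setup to the instance."""
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)   # original __init__
        self.patched = True         # extra setup added by the patch
    return wrapper

class Demo:
    def __init__(self):
        self.patched = False

# What register(..., post_init_wrapper, apply_wrapper=True) presumably boils down to:
Demo.__init__ = post_init_wrapper(Demo.__init__)
assert Demo().patched is True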
@@ -6,80 +6,83 @@ from megatron.core import tensor_parallel
 from megatron.legacy.model.enums import AttnType
 from megatron.core.models.common.embeddings import apply_rotary_pos_emb
 from megatron.legacy.model.module import MegatronModule
-from megatron.legacy.model.transformer import ParallelMLP
-from megatron.legacy.model.utils import (
-    erf_gelu,
-    openai_gelu,
-)
 
 try:
     from einops import rearrange
 except ImportError:
     rearrange = None
 
-class ParallelMLPPatch(MegatronModule):
-    """MLP.
-
-    MLP will take the input with h hidden state, project it to 4*h
-    hidden dimension, perform nonlinear transformation, and project the
-    state back into h hidden dimension.
-    """
-
-    def __init__(self, config, is_expert=False):
-        super(ParallelMLP, self).__init__()
-        args = get_args()
-
-        self.add_bias = config.add_bias_linear
-
-        ffn_hidden_size = config.ffn_hidden_size
-        if config.gated_linear_unit:
-            ffn_hidden_size *= 2
-
-        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
-        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
-            config.hidden_size,
-            ffn_hidden_size,
-            config=config,
-            init_method=config.init_method,
-            bias=self.add_bias,
-            gather_output=False,
-            skip_bias_add=True,
-            is_expert=is_expert,
-        )
-
-        self.bias_gelu_fusion = False
-        self.activation_func = None
-        self.swiglu = args.swiglu
-
-        if args.openai_gelu:
-            self.activation_func = openai_gelu
-        elif args.onnx_safe:
-            self.activation_func = erf_gelu
-        elif args.swiglu:
-            @torch.compile(mode="max-autotune-no-cudagraphs")
-            def swiglu(x):
-                x = torch.chunk(x, 2, dim=-1)
-                return F.silu(x[0]) * x[1]
-            self.activation_func = swiglu
-        elif args.squared_relu:
-            def squared_relu(x):
-                return torch.pow(F.relu(x), 2)
-            self.activation_func = squared_relu
-        else:
-            self.bias_gelu_fusion = args.bias_gelu_fusion
-            self.activation_func = F.gelu
-
-        # Project back to h.
-        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
-            config.ffn_hidden_size,
-            config.hidden_size,
-            config=config,
-            init_method=config.output_layer_init_method,
-            bias=self.add_bias,
-            skip_bias_add=True,
-            input_is_parallel=True,
-            is_expert=is_expert,
-        )
+try:  # use fixed-length flash attention
+    from flash_attn import flash_attn_func
+except ImportError:
+    flash_attn_func = None
+
+
+def parallel_mlp_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+        args = get_args()
+        if args.swiglu:
+            @torch.compile(mode="max-autotune-no-cudagraphs")
+            def swiglu(x):
+                x = torch.chunk(x, 2, dim=-1)
+                return F.silu(x[0]) * x[1]
+            self.activation_func = swiglu
+
+    return wrapper
+
+
+class FlashFixedSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_func is not None, ('Please install FlashAttention first, '
+                                             'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+        self.flash_attn_func = flash_attn_func
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert all(i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))
+        assert all(i.is_cuda for i in (q, k, v))
+        output = self.flash_attn_func(q, k, v, dropout_p=self.dropout_p,
+                                      softmax_scale=self.softmax_scale, causal=self.causal)
+        # [b, s, a, dim]
+        return output
+
+
+def parallel_attention_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+        if self.use_flash_attn:
+            self.core_attention_flash = FlashFixedSelfAttention(
+                causal=True, attention_dropout=self.config.attention_dropout
+            )
+    return wrapper
 
 
 class ParallelAttentionPatch(MegatronModule):
     """Parallel self-attention layer abstract class.
@@ -87,6 +90,7 @@ class ParallelAttentionPatch(MegatronModule):
     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, inference_params=None,
                 rotary_pos_emb=None):
...
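
For reference, a usage sketch of the new FlashFixedSelfAttention module. It assumes flash-attn and a CUDA device are available, and that q, k, v follow the (batch, seqlen, num_heads, head_dim) fp16/bf16 layout the forward asserts require; demo_flash_attention and the concrete shapes are illustrative, and the import path is left out because it depends on where the repo places the legacy transformer patch.

import torch
# from <repo>.legacy.model.transformer import FlashFixedSelfAttention  # path elided

def demo_flash_attention(FlashFixedSelfAttention):
    # Causal self-attention over fixed-length sequences via flash_attn_func.
    attn = FlashFixedSelfAttention(causal=True, attention_dropout=0.0)
    b, s, h, d = 2, 1024, 16, 128   # batch, seqlen, num_heads, head_dim
    q = torch.randn(b, s, h, d, dtype=torch.float16, device='cuda')
    k = torch.randn(b, s, h, d, dtype=torch.float16, device='cuda')
    v = torch.randn(b, s, h, d, dtype=torch.float16, device='cuda')
    out = attn(q, k, v)             # fixed-length flash attention with causal mask
    assert out.shape == (b, s, h, d)  # output keeps the [b, s, heads, dim] layout
    return out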