Commit 763941b5 authored by dongcl

modify ParallelAttention

parent bf323343
@@ -141,9 +141,9 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
apply_wrapper=True)
# MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
# torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
# apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
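For context, a standalone sketch of what registering torch.compile with apply_wrapper=True presumably amounts to for one of these targets; the loss-function body below is an illustrative stand-in, not the actual megatron.core implementation, and only the compile options mirror the registration above.

import torch

# Illustrative stand-in for a MoE load-balancing loss; the real
# switch_load_balancing_loss_func lives in megatron.core.transformer.moe.moe_utils.
def toy_load_balancing_loss(probs, tokens_per_expert):
    num_experts = probs.shape[-1]
    mean_probs = probs.mean(dim=0)
    fraction = tokens_per_expert.float() / tokens_per_expert.sum()
    return num_experts * torch.sum(mean_probs * fraction)

# Same compile options as the registration above: CUDA graphs enabled,
# graph trees disabled, input mutation allowed inside captured graphs.
compiled_loss = torch.compile(
    toy_load_balancing_loss,
    options={
        "triton.cudagraphs": True,
        "triton.cudagraph_trees": False,
        "triton.cudagraph_support_input_mutation": True,
    },
)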
@@ -233,13 +233,21 @@ class LegacyAdaptation(MegatronAdaptationABC):
self.patch_legacy_models()
def patch_legacy_models(self):
from ..legacy.model.transformer import ParallelMLPPatch, ParallelAttentionPatch
from ..legacy.model.transformer import (
parallel_mlp_init_wrapper,
ParallelAttentionPatch,
parallel_attention_init_wrapper
)
from ..legacy.model.utils import get_norm
# ParallelMLP
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
ParallelMLPPatch.__init__)
parallel_mlp_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
parallel_attention_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.forward',
ParallelAttentionPatch.forward)
@@ -6,79 +6,82 @@ from megatron.core import tensor_parallel
from megatron.legacy.model.enums import AttnType
from megatron.core.models.common.embeddings import apply_rotary_pos_emb
from megatron.legacy.model.module import MegatronModule
from megatron.legacy.model.transformer import ParallelMLP
from megatron.legacy.model.utils import (
erf_gelu,
openai_gelu,
)
try:
from einops import rearrange
except ImportError:
rearrange = None
class ParallelMLPPatch(MegatronModule):
"""MLP.
try:  # use fixed-length flash attention
from flash_attn import flash_attn_func
except ImportError:
flash_attn_func = None
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self, config, is_expert=False):
super(ParallelMLP, self).__init__()
args = get_args()
self.add_bias = config.add_bias_linear
ffn_hidden_size = config.ffn_hidden_size
if config.gated_linear_unit:
ffn_hidden_size *= 2
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
ffn_hidden_size,
config=config,
init_method=config.init_method,
bias=self.add_bias,
gather_output=False,
skip_bias_add=True,
is_expert=is_expert,
)
self.bias_gelu_fusion = False
self.activation_func = None
self.swiglu = args.swiglu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
elif args.swiglu:
def parallel_mlp_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
args = get_args()
if args.swiglu:
@torch.compile(mode="max-autotune-no-cudagraphs")
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.activation_func = swiglu
elif args.squared_relu:
def squared_relu(x):
return torch.pow(F.relu(x), 2)
self.activation_func = squared_relu
else:
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
# Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
config=config,
init_method=config.output_layer_init_method,
bias=self.add_bias,
skip_bias_add=True,
input_is_parallel=True,
is_expert=is_expert,
)
return wrapper
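As a quick sanity check of the swiglu path installed by the wrapper above (tensor sizes below are illustrative only):

import torch
import torch.nn.functional as F

@torch.compile(mode="max-autotune-no-cudagraphs")
def swiglu(x):
    # Split the doubled ffn projection in half and gate one half with SiLU.
    x = torch.chunk(x, 2, dim=-1)
    return F.silu(x[0]) * x[1]

# With gated_linear_unit, dense_h_to_4h outputs 2 * ffn_hidden_size;
# swiglu halves that back to ffn_hidden_size.
h = torch.randn(4, 2 * 4096)  # illustrative (tokens, 2 * ffn_hidden_size)
out = swiglu(h)               # -> shape (4, 4096)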
class FlashFixedSelfAttention(torch.nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
device=None, dtype=None):
super().__init__()
assert flash_attn_func is not None, ('Please install FlashAttention first, '
'e.g., with pip install flash-attn')
assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
self.flash_attn_func = flash_attn_func
def forward(self, q, k, v):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
"""
assert all(i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))
assert all(i.is_cuda for i in (q, k, v))
output = self.flash_attn_func(q, k, v, dropout_p=self.dropout_p, softmax_scale=self.softmax_scale, causal=self.causal)
# [b,s,a,dim]
return output
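A minimal usage sketch for the FlashFixedSelfAttention class defined above, assuming flash-attn and a CUDA device are available; tensor sizes follow the (B, S, H, D) convention in the docstring and are illustrative:

import torch

# batch 2, sequence 1024, 16 heads, head dim 64 (illustrative sizes)
q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = FlashFixedSelfAttention(causal=True, attention_dropout=0.0)
out = attn(q, k, v)  # same (B, S, H, D) shape as q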
def parallel_attention_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
if self.use_flash_attn:
self.core_attention_flash = FlashFixedSelfAttention(
causal=True, attention_dropout=self.config.attention_dropout
)
return wrapper
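The wrapper above is applied through the MegatronAdaptation.register call in the earlier hunk; as a rough by-hand equivalent (a sketch of the intent, not how the adaptation framework actually applies patches):

from megatron.legacy.model.transformer import ParallelAttention

# Wrap the original __init__ so instances created with use_flash_attn
# get the fixed-length FlashFixedSelfAttention core.
ParallelAttention.__init__ = parallel_attention_init_wrapper(ParallelAttention.__init__)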
class ParallelAttentionPatch(MegatronModule):
@@ -87,6 +90,7 @@ class ParallelAttentionPatch(MegatronModule):
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None,
rotary_pos_emb=None):