Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -2,6 +2,7 @@
import torch
from torch import nn
import torch._dynamo
torch._dynamo.config.suppress_errors = True
......
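The suppress_errors flag set above changes how TorchDynamo reacts when torch.compile hits code it cannot trace: the failure is logged and the region falls back to eager execution instead of aborting the run. A minimal sketch of that behavior, assuming a PyTorch 2.x build (the function below is illustrative, not part of this commit):

import torch
import torch._dynamo

# With suppress_errors=True, a compilation failure or graph break inside a
# compiled region is logged and the function simply runs eagerly, so training
# continues uncompiled for that function rather than raising.
torch._dynamo.config.suppress_errors = True

@torch.compile
def bias_gelu(x, bias):
    # Anything Dynamo cannot handle here triggers the eager fallback.
    return torch.nn.functional.gelu(x + bias)

y = bias_gelu(torch.randn(4, 16), torch.randn(16))

The same flag is set again in the fused-kernels hunk further down, so both compiled entry points get the same eager fallback.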
File mode changed from 100755 to 100644
......@@ -60,10 +60,6 @@ except ImportError:
except ImportError:
flash_attn_unpadded_func = None
try:
from flash_attn.flash_attn_triton import flash_attn_func
except ImportError:
flash_attn_func = None
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
......@@ -165,7 +161,6 @@ class ParallelMLP(MegatronModule):
is_expert=is_expert,
)
# @torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, hidden_states):
# [s, b, 4hp]
......@@ -477,10 +472,6 @@ class FlashSelfAttention(torch.nn.Module):
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
# Use FlashAttention-2 when args.use_flash_attn_ck is True
args = get_args()
self.flash_attn_func = flash_attn_unpadded_func
def forward(self, q, k, v):
"""Implements the multihead softmax attention.
Arguments
......@@ -522,38 +513,6 @@ class FlashSelfAttention(torch.nn.Module):
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
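The flash-attention path above leans on einops to move between Megatron's [sq, b, np, hn] layout, the [b, s, ...] layout the module expects, and the packed [(b s), ...] layout used by the unpadded kernel. A small self-contained sketch of those rearrange patterns (tensor sizes are made up; only the patterns come from the surrounding hunks):

import torch
from einops import rearrange

s, b, h, d = 5, 2, 4, 8                     # seq, batch, heads, head_dim
q = torch.randn(s, b, h, d)                 # Megatron layout: [sq, b, np, hn]

# ParallelAttention: [s, b, ...] -> [b, s, ...] before calling FlashSelfAttention.
q_bshd = rearrange(q, 's b ... -> b s ...').contiguous()

# FlashSelfAttention packs batch and sequence for the unpadded kernel ...
q_packed = rearrange(q_bshd, 'b s ... -> (b s) ...')
# ... and unpacks the kernel output the same way afterwards.
out = rearrange(q_packed, '(b s) ... -> b s ...', b=b)

# Back to Megatron's [sq, b, h] layout for the rest of the transformer.
context = rearrange(out, 'b s h d -> s b (h d)').contiguous()
assert context.shape == (s, b, h * d)

The Triton variant of the class that follows used a different 'b h s d' layout for the same tensors.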
class FlashSelfAttentionTriton(torch.nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
device=None, dtype=None):
super().__init__()
assert flash_attn_func is not None, ('Triton version of FlashAttention is not installed.')
assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, q, k, v):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
"""
assert q.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda
q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous()
for x in (q, k, v)]
output = flash_attn_func(q, k, v, self.causal)
output = rearrange(output, 'b s h d -> h b (s d)').contiguous()
return output
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
......@@ -582,19 +541,13 @@ class ParallelAttention(MegatronModule):
else:
kv_projection_size = args.kv_channels * args.num_attention_heads
self.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton) \
self.use_flash_attn = args.use_flash_attn \
and attention_type == AttnType.self_attn \
and self.attn_mask_type == AttnMaskType.causal
self.use_flash_attn_triton = args.use_flash_attn_triton
if self.use_flash_attn:
if args.use_flash_attn_cutlass:
if flash_attn_unpadded_func is None:
raise ImportError('FlashAttention is not installed, please install with '
'pip install flash-attn')
if args.use_flash_attn_triton:
assert flash_attn_func != None, "Cannot import FlashAttention triton "
if flash_attn_unpadded_func is None:
raise ImportError('FlashAttention is not installed, please install with '
'pip install flash-attn')
assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
'self-attention for now')
assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
......@@ -654,11 +607,7 @@ class ParallelAttention(MegatronModule):
self.attn_mask_type)
self.checkpoint_core_attention = config.recompute_granularity == 'selective'
if self.use_flash_attn_triton:
self.core_attention_flash = FlashSelfAttentionTriton(
causal=True, attention_dropout=args.attention_dropout
)
elif self.use_flash_attn:
if self.use_flash_attn:
self.core_attention_flash = FlashSelfAttention(
causal=True, attention_dropout=config.attention_dropout
)
......@@ -766,7 +715,7 @@ class ParallelAttention(MegatronModule):
dim=3)
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -
query_layer = query_layer.contiguous().view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
......@@ -806,7 +755,8 @@ class ParallelAttention(MegatronModule):
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
assert sequence_end <= inference_key_memory.size(0), ("Current sequence length is "
"longer than expected maximum sequence length! Increase inference_max_seq_length.")
# Copy key and values.
inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = key_layer
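The strengthened assert above guards the copy into the preallocated inference key/value cache. A small sketch of the indexing it protects, with illustrative sizes (the cache tensor and numbers here are assumptions, not taken from the commit):

import torch

max_seq_len, max_batch, n_heads, head_dim = 16, 2, 4, 8
inference_key_memory = torch.zeros(max_seq_len, max_batch, n_heads, head_dim)

sequence_start = 8                                          # tokens already cached
key_layer = torch.randn(8, max_batch, n_heads, head_dim)    # 8 new tokens
sequence_end = sequence_start + key_layer.size(0)           # 16, still in bounds

# If sequence_end exceeded max_seq_len, the slice assignment below would fail
# with an opaque shape-mismatch error; the assert turns that into an explicit
# "increase inference_max_seq_length" message instead.
assert sequence_end <= inference_key_memory.size(0)
inference_key_memory[sequence_start:sequence_end] = key_layer

Failing early with the explicit message makes it clear the remedy is to raise inference_max_seq_length rather than to debug a tensor-shape error.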
......@@ -871,18 +821,14 @@ class ParallelAttention(MegatronModule):
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask)
else:
if not self.use_flash_attn_triton:
query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
for x in (query_layer, key_layer, value_layer)]
if not self.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
context_layer = self.core_attention_flash(q, k, v)
else:
context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
if not self.use_flash_attn_triton:
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
context_layer = self.core_attention_flash(q, k, v)
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
# =================
# Output. [sq, b, h]
......@@ -1213,7 +1159,6 @@ class ParallelTransformerLayer(MegatronModule):
return retriever_output, norm_input, norm_output
# @torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
retriever_input=None,
......@@ -1360,7 +1305,7 @@ class NoopTransformerLayer(MegatronModule):
"""A single 'no-op' transformer layer.
The sole purpose of this layer is for when a standalone embedding layer
is used (i.e., args.standalone_embedding_stage == True). In this case,
is used (i.e., args.account_for_embedding_in_pipeline_split == True). In this case,
zero transformer layers are assigned when pipeline rank == 0. Additionally,
when virtual pipeline rank >= 1, zero total model parameters are created
(virtual rank 0 contains the input embedding). This results in the model's
......@@ -1399,7 +1344,7 @@ def _get_num_layers(args, model_type, is_decoder=False):
# or no layers at all (virtual pp rank >= 1).
num_layers = (
0
if args.standalone_embedding_stage
if args.account_for_embedding_in_pipeline_split
and mpu.get_pipeline_model_parallel_rank() == 0 else
args.num_layers // args.transformer_pipeline_model_parallel_size
)
......@@ -1616,7 +1561,7 @@ class ParallelTransformer(MegatronModule):
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
# args.standalone_embedding_stage == True), virtual pipeline ranks
# args.account_for_embedding_in_pipeline_split == True), virtual pipeline ranks
# on pipeline rank 0 will have zero transformer layers assigned to
# them. This results in the model's input and output tensors to be
# the same, which will cause failure for certain output tensor
......
......@@ -9,6 +9,7 @@ import torch
from megatron.training import get_args
from megatron.legacy.model import LayerNorm, RMSNorm
from megatron.core.jit import jit_fuser
import torch._dynamo
torch._dynamo.config.suppress_errors = True
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644