Commit d77d95c5 authored by dongcl

bug fix

parent 8da353a4
......@@ -84,6 +84,7 @@ class CoreAdaptation(MegatronAdaptationABC):
self.patch_core_distributed()
self.patch_core_models()
self.patch_core_transformers()
self.patch_core_extensions()
self.patch_tensor_parallel()
self.patch_training()
self.patch_miscellaneous()
......@@ -137,6 +138,12 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.core.transformer.transformer_config.MLATransformerConfig',
MLATransformerConfig)
def patch_core_extensions(self):
from ..core.extensions.transformer_engine import te_dot_product_attention_init
MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
te_dot_product_attention_init)
def patch_tensor_parallel(self):
from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
......@@ -147,12 +154,15 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_training(self):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import initialize_megatron
from ..training.initialize import _initialize_distributed
from ..training.initialize import _compile_dependencies
MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
build_tokenizer)
MegatronAdaptation.register('megatron.training.initialize.initialize_megatron',
initialize_megatron)
MegatronAdaptation.register('megatron.training.initialize._initialize_distributed',
_initialize_distributed)
MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
_compile_dependencies)
def patch_miscellaneous(self):
from ..training.arguments import parse_args
......
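For context, the string-keyed registrations above imply a dotted-path monkey-patching mechanism: the registry resolves a module path plus an attribute chain and rebinds the final attribute to the replacement function. Below is a minimal sketch of how such a resolver could work, assuming an importlib-based lookup; the apply_patch name is hypothetical and this is not the project's actual MegatronAdaptation implementation.

import importlib


def apply_patch(dotted_path: str, replacement) -> None:
    """Illustrative only: rebind the attribute named by ``dotted_path``.

    Tries the longest importable module prefix, then walks the remaining
    names as attributes and swaps the final one for ``replacement``.
    """
    parts = dotted_path.split(".")
    for split in range(len(parts) - 1, 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:split]))
        except ModuleNotFoundError:
            continue
        for name in parts[split:-1]:
            obj = getattr(obj, name)
        setattr(obj, parts[-1], replacement)
        return
    raise ImportError(f"could not resolve {dotted_path!r}")


# Hypothetical usage mirroring the registration above:
# apply_patch(
#     "megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__",
#     te_dot_product_attention_init,
# )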
import os
import dataclasses
from typing import Any, Optional
import torch
from packaging.version import Version as PkgVersion
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.utils import get_te_version, is_te_min_version
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_hierarchical_context_parallel_groups,
get_tensor_model_parallel_group,
)
def te_dot_product_attention_init(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: Optional[float] = None,
softmax_scale: Optional[float] = None,
k_channels: Optional[int] = None,
v_channels: Optional[int] = None,
cp_comm_type: str = "p2p",
):
self.config = config
self.te_forward_mask_type = False
self.qkv_format: str = 'sbhd'
if self.config.apply_query_key_layer_scaling != bool(
int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
):
raise ValueError(
f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
f"setting query key layer scaling via argument, so these two must match."
)
extra_kwargs: dict[str, Any] = {}
if is_te_min_version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if is_te_min_version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older versions don't need attention_type
if is_te_min_version("0.12.0", check_equality=False):
self.te_forward_mask_type = True
# This check is important as CP config can be disabled while having a valid CP group
# Example - Disabling CP for encoder while a valid CP group exists for decoder
if self.config.context_parallel_size > 1:
assert is_te_min_version(
"1.0.0"
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
if is_te_min_version("1.10.0"):
if cp_comm_type is None:
extra_kwargs["cp_comm_type"] = "p2p"
elif cp_comm_type == "a2a+p2p":
assert is_te_min_version("1.12.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support"
"hierarchical cp commucation."
)
extra_kwargs["cp_comm_type"] = "a2a+p2p"
extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups(
check_initialized=False
)
else:
extra_kwargs["cp_comm_type"] = cp_comm_type
if self.config.deterministic_mode:
if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
raise RuntimeError(
"deterministic_mode is on and we are using DotProductAttention from "
"Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
)
if config.window_size is not None:
# Check version
assert is_te_min_version("1.2.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support"
"sliding window attention."
)
extra_kwargs['window_size'] = config.window_size
if is_te_min_version("1.9.0"):
# TE 1.10.0 introduces the ability to set the different k and v channels
kv_channels = (
(k_channels, v_channels)
if k_channels is not None and v_channels is not None
else self.config.kv_channels
)
extra_kwargs['softmax_scale'] = softmax_scale
else:
kv_channels = self.config.kv_channels
self.kept_packed_seq_params = set(
field.name for field in dataclasses.fields(PackedSeqParams)
)
if get_te_version() < PkgVersion("1.3.0"):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
# copies (#555)
# These two arguments did not exist prior to 1.3.0
self.kept_packed_seq_params.discard("max_seqlen_q")
self.kept_packed_seq_params.discard("max_seqlen_kv")
if get_te_version() < PkgVersion("1.10.0"):
# TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted
# in each individual sequence in THD format dataset
# These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012)
self.kept_packed_seq_params.discard("cu_seqlens_q_padded")
self.kept_packed_seq_params.discard("cu_seqlens_kv_padded")
super(TEDotProductAttention, self).__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=kv_channels,
attention_dropout=(
self.config.attention_dropout if attention_dropout is None else attention_dropout
),
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
**extra_kwargs,
)
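A note on kept_packed_seq_params: the set computed at the end of the constructor is presumably consumed in the forward pass to drop PackedSeqParams fields that the installed Transformer Engine version cannot accept. A hedged sketch of that filtering step follows; the filter_packed_seq_kwargs helper is hypothetical, not part of the patch above.

import dataclasses

from megatron.core.packed_seq_params import PackedSeqParams


def filter_packed_seq_kwargs(attn, packed_seq_params):
    """Hypothetical helper: keep only the PackedSeqParams fields that the
    installed Transformer Engine supports, as recorded in
    ``attn.kept_packed_seq_params`` by the constructor above."""
    if packed_seq_params is None:
        return {}
    return {
        field.name: getattr(packed_seq_params, field.name)
        for field in dataclasses.fields(PackedSeqParams)
        if field.name in attn.kept_packed_seq_params
    }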
import os
import argparse
from megatron.training.arguments import (
......
"""Megatron initialization."""
import time
import torch
from datetime import timedelta
......@@ -76,4 +77,77 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
print(
f"> initialized pipeline model parallel with size "
f"{mpu.get_pipeline_model_parallel_world_size()}"
)

def _compile_dependencies():
args = get_args()
# =========================
# Compile dataset C++ code.
# =========================
# TODO: move this to ninja
if torch.distributed.get_rank() == 0:
start_time = time.time()
print("> compiling dataset index builder ...")
from megatron.core.datasets.utils import compile_helpers
compile_helpers()
print(
">>> done with dataset index builder. Compilation time: {:.3f} "
"seconds".format(time.time() - start_time),
flush=True,
)
# ==================
# Load fused kernels
# ==================
# Custom kernel constraints check.
seq_len = args.seq_length
attn_batch_size = (
args.num_attention_heads / args.tensor_model_parallel_size
) * args.micro_batch_size
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = (
seq_len > 16
and seq_len <= 16384
and seq_len % 4 == 0
and attn_batch_size % 4 == 0
)
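# e.g. with 32 attention heads, tensor-parallel size 8 and micro batch size 4:
# attn_batch_size = (32 / 8) * 4 = 16, so attn_batch_size % 4 == 0 holds, and
# any 16 < seq_len <= 16384 with seq_len % 4 == 0 satisfies the constraint
# (values are illustrative only).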
# Print a warning.
if not (
(args.fp16 or args.bf16)
and custom_kernel_constraint
and args.masked_softmax_fusion
):
if args.rank == 0:
print(
"WARNING: constraints for invoking optimized"
" fused softmax kernel are not met. We default"
" back to unfused kernel invocations.",
flush=True,
)
# Always build on rank zero first.
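# Note: the fused_kernels.load call below is commented out in this adaptation;
# only the barriers that enforce rank-0-first ordering are kept.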
if torch.distributed.get_rank() == 0:
start_time = time.time()
print("> compiling and loading fused kernels ...", flush=True)
#fused_kernels.load(args)
torch.distributed.barrier()
else:
torch.distributed.barrier()
#fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that
# the lock is released.
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(
">>> done with compiling and loading fused kernels. "
"Compilation time: {:.3f} seconds".format(time.time() - start_time),
flush=True,
)
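The rank-0-first pattern above (compile on rank zero, barrier, then let the remaining ranks proceed, with a final barrier before continuing) generalizes to any one-time build step in a multi-rank job. A minimal standalone sketch with an illustrative helper name:

import torch


def build_on_rank_zero_first(build_fn):
    """Illustrative helper: run ``build_fn`` on rank 0 while the other ranks
    wait, then release every rank together (same ordering as above)."""
    if torch.distributed.get_rank() == 0:
        build_fn()
        torch.distributed.barrier()
    else:
        # Wait for rank 0 to finish building before loading/continuing.
        torch.distributed.barrier()
    # Final barrier so all ranks leave the build phase together.
    torch.distributed.barrier()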