Commit c140c914 authored by dongcl

bug fix

parent 26940c4c
@@ -165,13 +165,14 @@ class CoreAdaptation(MegatronAdaptationABC):
    def patch_tensor_parallel(self):
        from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
-        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
+        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper

        # VocabParallelEmbedding
        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
                                    vocab_parallel_embedding_forward)
        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
-                                    vocab_parallel_embedding_init)
+                                    vocab_parallel_embedding_init_wrapper,
+                                    apply_wrapper=True)

        # VocabParallelCrossEntropy
        MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',
......
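Note on the registration hunk above: the `__init__` patch is now registered as a wrapper factory together with `apply_wrapper=True`, which presumably tells `MegatronAdaptation` to feed the original `VocabParallelEmbedding.__init__` through `vocab_parallel_embedding_init_wrapper` rather than replacing it outright. A minimal sketch of that registration pattern, assuming those semantics; the `apply_patch` helper below is hypothetical, not the actual `MegatronAdaptation` implementation:

import importlib

def apply_patch(target_path, replacement, apply_wrapper=False):
    # Resolve "pkg.module.Class.attr" into the owning class and attribute name.
    module_path, cls_name, attr_name = target_path.rsplit('.', 2)
    owner = getattr(importlib.import_module(module_path), cls_name)
    original = getattr(owner, attr_name)
    # With apply_wrapper=True the registered object is treated as a factory
    # that receives the original callable and returns the patched version.
    patched = replacement(original) if apply_wrapper else replacement
    setattr(owner, attr_name, patched)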
import os
import copy
import socket
import warnings
from functools import wraps
from typing import Callable, List, Optional

-try:
-    import flux
-except ImportError:
-    raise ImportError("flux is NOT installed")
+if int(os.getenv("USE_FLUX_OVERLAP", "0")):
+    try:
+        import flux
+        from dcu_megatron.core.utils import is_flux_min_version
+    except ImportError:
+        raise ImportError("flux is NOT installed")

import torch
import torch.nn.functional as F
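Note on the import hunk above: flux is now imported only when the `USE_FLUX_OVERLAP` environment variable parses to a non-zero integer, so the package no longer has to be installed for the default code path. A minimal sketch of the same gating pattern; the cached module-level flag and the `flux = None` fallback are illustrative additions, not part of the diff:

import os

# Read the feature flag once; any non-zero integer string enables the flux path.
USE_FLUX_OVERLAP = bool(int(os.getenv("USE_FLUX_OVERLAP", "0")))

if USE_FLUX_OVERLAP:
    try:
        import flux  # required only when the overlap path is enabled
    except ImportError:
        raise ImportError("flux is NOT installed")
else:
    flux = None  # downstream code can test `flux is None` before using it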
@@ -50,7 +53,6 @@ from megatron.core.tensor_parallel.layers import (
    linear_with_frozen_weight,
    linear_with_grad_accumulation_and_async_allreduce
)
-from dcu_megatron.core.utils import is_flux_min_version

_grad_accum_fusion_available = True
@@ -60,64 +62,29 @@ except ImportError:
    _grad_accum_fusion_available = False
-def vocab_parallel_embedding_init(
-    self,
-    num_embeddings: int,
-    embedding_dim: int,
-    *,
-    init_method: Callable,
-    reduce_scatter_embeddings: bool = False,
-    config: ModelParallelConfig,
-    skip_weight_param_allocation: bool = False
-):
-    super(VocabParallelEmbedding, self).__init__()
-
-    # Keep the input dimensions.
-    self.num_embeddings = num_embeddings
-    self.embedding_dim = embedding_dim
-    self.reduce_scatter_embeddings = reduce_scatter_embeddings
-    self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
-    # Divide the weight matrix along the vocabulary dimension.
-    (self.vocab_start_index, self.vocab_end_index) = (
-        VocabUtility.vocab_range_from_global_vocab_size(
-            self.num_embeddings,
-            get_tensor_model_parallel_rank(),
-            self.tensor_model_parallel_size,
-        )
-    )
-    self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
-    self.deterministic_mode = config.deterministic_mode
-
-    # Allocate weights and initialize.
-    if not skip_weight_param_allocation:
-        if config.use_cpu_initialization:
-            self.weight = Parameter(
-                torch.empty(
-                    self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
-                )
-            )
-            if config.perform_initialization:
-                _initialize_affine_weight_cpu(
-                    self.weight,
-                    self.num_embeddings,
-                    self.embedding_dim,
-                    self.num_embeddings_per_partition,
-                    0,
-                    init_method,
-                    params_dtype=config.params_dtype,
-                )
-        else:
-            self.weight = Parameter(
-                torch.empty(
-                    self.num_embeddings_per_partition,
-                    self.embedding_dim,
-                    device=torch.cuda.current_device(),
-                    dtype=config.params_dtype,
-                )
-            )
-            if config.perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
-    else:
-        self.weight = None
+def vocab_parallel_embedding_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self,
+                *args,
+                skip_weight_param_allocation: bool = False,
+                **kwargs
+                ):
+        if (
+            skip_weight_param_allocation
+            and "config" in kwargs
+            and hasattr(kwargs["config"], "perform_initialization")
+        ):
+            config = copy.deepcopy(kwargs["config"])
+            config.perform_initialization = False
+            kwargs["config"] = config
+        fn(self, *args, **kwargs)
+        if skip_weight_param_allocation:
+            self.weight = None
+    return wrapper
@torch.compile(mode='max-autotune-no-cudagraphs')
......
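Note on the wrapper added above: when `skip_weight_param_allocation=True` is passed, the wrapper deep-copies the config with `perform_initialization=False`, lets the original `__init__` run (the weight tensor is still allocated, just left uninitialized), and then sets `self.weight = None`. A rough usage sketch under the assumption that `apply_wrapper=True` installs the wrapper around the stock `__init__`; `vocab_size`, `hidden_size`, `init_method`, and `config` are placeholder names:

# What the registration is presumed to do, shown manually for illustration.
VocabParallelEmbedding.__init__ = vocab_parallel_embedding_init_wrapper(
    VocabParallelEmbedding.__init__
)

embedding = VocabParallelEmbedding(
    vocab_size,
    hidden_size,
    init_method=init_method,
    config=config,                      # passed as a keyword so the wrapper can clone it
    skip_weight_param_allocation=True,  # consumed by the wrapper, never reaches __init__
)
assert embedding.weight is None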
@@ -16,6 +16,7 @@ from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module
from ...tensor_parallel.random import CheckpointWithoutOutput
from ...tensor_parallel import FluxColumnParallelLinear
@dataclass
......
import torch
import torch.nn.functional as F
from functools import wraps
from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.legacy.model.enums import AttnType
......