"vscode:/vscode.git/clone" did not exist on "3d415d0ea395a08a5a2230881cee4d22d21f7c76"
Commit c140c914 authored by dongcl

bug fix

parent 26940c4c
@@ -165,13 +165,14 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
-        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
+        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper

         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
                                     vocab_parallel_embedding_forward)
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
-                                    vocab_parallel_embedding_init)
+                                    vocab_parallel_embedding_init_wrapper,
+                                    apply_wrapper=True)

         # VocabParallelCrossEntropy
         MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',
...
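The __init__ patch is now registered as a wrapper (apply_wrapper=True) rather than a full replacement: the adaptation layer is expected to decorate the original VocabParallelEmbedding.__init__ with vocab_parallel_embedding_init_wrapper instead of swapping the whole method body. MegatronAdaptation.register itself is not part of this commit, so the sketch below is only a plausible reading of that flag; the helper name apply_patch and its resolution logic are assumptions.

# Illustrative sketch only -- MegatronAdaptation's internals are not shown in
# this commit; the helper below is an assumption about what apply_wrapper does.
import importlib

def apply_patch(target: str, patch, apply_wrapper: bool = False):
    """Resolve 'pkg.module.Class.attr' and install `patch` on the class."""
    module_path, cls_name, attr = target.rsplit('.', 2)
    cls = getattr(importlib.import_module(module_path), cls_name)
    if apply_wrapper:
        # The registered callable is a decorator: wrap the original attribute.
        setattr(cls, attr, patch(getattr(cls, attr)))
    else:
        # Otherwise it is a drop-in replacement for the attribute.
        setattr(cls, attr, patch)

Wrapping instead of replacing keeps the patch smaller and less likely to drift from upstream changes to the original __init__ body.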
 import os
+import copy
 import socket
 import warnings
 from functools import wraps
 from typing import Callable, List, Optional

-try:
-    import flux
-except ImportError:
-    raise ImportError("flux is NOT installed")
+if int(os.getenv("USE_FLUX_OVERLAP", "0")):
+    try:
+        import flux
+        from dcu_megatron.core.utils import is_flux_min_version
+    except ImportError:
+        raise ImportError("flux is NOT installed")

 import torch
 import torch.nn.functional as F
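Gating the flux import behind USE_FLUX_OVERLAP means the module now imports cleanly on machines without flux unless the overlap path is explicitly requested, and is_flux_min_version is only pulled in alongside it. A minimal usage sketch, assuming the flag is read once at module import time (the imported package path is illustrative, not taken from this diff):

# Hypothetical usage: opt in to the flux overlap path before importing the
# patched module; with the flag unset (default "0") flux is never imported.
import os

os.environ["USE_FLUX_OVERLAP"] = "1"   # must be set before the import below

from dcu_megatron.core import tensor_parallel   # assumed package layout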
@@ -50,7 +53,6 @@ from megatron.core.tensor_parallel.layers import (
     linear_with_frozen_weight,
     linear_with_grad_accumulation_and_async_allreduce
 )
-from dcu_megatron.core.utils import is_flux_min_version

 _grad_accum_fusion_available = True
@@ -60,64 +62,29 @@ except ImportError:
     _grad_accum_fusion_available = False


-def vocab_parallel_embedding_init(
-    self,
-    num_embeddings: int,
-    embedding_dim: int,
-    *,
-    init_method: Callable,
-    reduce_scatter_embeddings: bool = False,
-    config: ModelParallelConfig,
-    skip_weight_param_allocation: bool = False
-):
-    super(VocabParallelEmbedding, self).__init__()
-
-    # Keep the input dimensions.
-    self.num_embeddings = num_embeddings
-    self.embedding_dim = embedding_dim
-    self.reduce_scatter_embeddings = reduce_scatter_embeddings
-    self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
-    # Divide the weight matrix along the vocaburaly dimension.
-    (self.vocab_start_index, self.vocab_end_index) = (
-        VocabUtility.vocab_range_from_global_vocab_size(
-            self.num_embeddings,
-            get_tensor_model_parallel_rank(),
-            self.tensor_model_parallel_size,
-        )
-    )
-    self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
-    self.deterministic_mode = config.deterministic_mode
-
-    # Allocate weights and initialize.
-    if not skip_weight_param_allocation:
-        if config.use_cpu_initialization:
-            self.weight = Parameter(
-                torch.empty(
-                    self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
-                )
-            )
-            if config.perform_initialization:
-                _initialize_affine_weight_cpu(
-                    self.weight,
-                    self.num_embeddings,
-                    self.embedding_dim,
-                    self.num_embeddings_per_partition,
-                    0,
-                    init_method,
-                    params_dtype=config.params_dtype,
-                )
-        else:
-            self.weight = Parameter(
-                torch.empty(
-                    self.num_embeddings_per_partition,
-                    self.embedding_dim,
-                    device=torch.cuda.current_device(),
-                    dtype=config.params_dtype,
-                )
-            )
-            if config.perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
-    else:
-        self.weight = None
+def vocab_parallel_embedding_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self,
+                *args,
+                skip_weight_param_allocation: bool = False,
+                **kwargs
+                ):
+        if (
+            skip_weight_param_allocation
+            and "config" in kwargs
+            and hasattr(kwargs["config"], "perform_initialization")
+        ):
+            config = copy.deepcopy(kwargs["config"])
+            config.perform_initialization = False
+            kwargs["config"] = config
+
+        fn(self, *args, **kwargs)
+
+        if skip_weight_param_allocation:
+            self.weight = None
+
+    return wrapper


 @torch.compile(mode='max-autotune-no-cudagraphs')
...
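The effect of the new wrapper, in isolation: when skip_weight_param_allocation is requested, the original __init__ still runs (so all of its bookkeeping is preserved), but with perform_initialization forced off on a copied config, and the allocated weight is dropped afterwards. A self-contained toy sketch of that behaviour follows; DummyConfig, DummyEmbedding and init_wrapper are stand-ins invented for illustration, not part of the commit.

import copy
from dataclasses import dataclass
from functools import wraps

@dataclass
class DummyConfig:
    perform_initialization: bool = True

def init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, skip_weight_param_allocation=False, **kwargs):
        if skip_weight_param_allocation and "config" in kwargs:
            config = copy.deepcopy(kwargs["config"])
            config.perform_initialization = False   # skip the expensive init
            kwargs["config"] = config
        fn(self, *args, **kwargs)                   # run the original __init__
        if skip_weight_param_allocation:
            self.weight = None                      # drop the placeholder weight
    return wrapper

class DummyEmbedding:
    def __init__(self, config: DummyConfig):
        # Stand-in for the real class, which allocates and optionally
        # initializes its weight here.
        self.weight = "initialized" if config.perform_initialization else "allocated only"

DummyEmbedding.__init__ = init_wrapper(DummyEmbedding.__init__)

emb = DummyEmbedding(config=DummyConfig(), skip_weight_param_allocation=True)
assert emb.weight is None   # weight allocation and initialization were skipped

Copying the config before flipping perform_initialization keeps the caller's config object untouched, which matters when the same config is shared across modules.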
@@ -16,6 +16,7 @@ from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross
 from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module

 from ...tensor_parallel.random import CheckpointWithoutOutput
+from ...tensor_parallel import FluxColumnParallelLinear


 @dataclass
...
 import torch
 import torch.nn.functional as F
+from functools import wraps

 from megatron.training import get_args
 from megatron.core import tensor_parallel
 from megatron.legacy.model.enums import AttnType
...