Commit bf212e29 authored by dongcl

bug fix

parent 595e428a
@@ -45,7 +45,7 @@ def get_gpt_layer_with_flux_spec(
     fp8: Optional[str] = None, # pylint: disable=unused-arguments
     moe_use_legacy_grouped_gemm: Optional[bool] = False,
 ) -> ModuleSpec:
-    """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
+    """Use this spec to use flux modules (required for fp8 training).

     Args:
@@ -57,7 +57,7 @@ def get_gpt_layer_with_flux_spec(
             Defaults to False.

     Returns:
-        ModuleSpec: Module specification with TE modules
+        ModuleSpec: Module specification with flux modules
     """
     if fp8 is not None:
         warnings.warn(
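Review note (not part of the diff): a spec factory like get_gpt_layer_with_flux_spec() is normally consumed as the transformer_layer_spec of a Megatron-Core GPTModel. A minimal usage sketch, assuming torch.distributed and Megatron's model-parallel state are already initialized, and assuming the module path of the flux spec (both are assumptions, not shown in this commit):

from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.gpt import GPTModel
from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec  # assumed path

config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=8,
)

# The fp8 argument only emits a deprecation warning (see the hunk above), so it is omitted here.
layer_spec = get_gpt_layer_with_flux_spec(moe_use_legacy_grouped_gemm=False)

model = GPTModel(
    config=config,
    transformer_layer_spec=layer_spec,
    vocab_size=32000,
    max_sequence_length=2048,
)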
...
+import os
 import logging
 from typing import Literal, Optional
 from functools import wraps
@@ -21,6 +22,7 @@ from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
 from dcu_megatron.core.utils import tensor_slide
 from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
 from dcu_megatron.core.transformer.transformer_config import TransformerConfig
+from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear


 def gpt_model_init_wrapper(fn):
@@ -28,6 +30,26 @@ def gpt_model_init_wrapper(fn):
     def wrapper(self, *args, **kwargs):
         fn(self, *args, **kwargs)

+        if (
+            self.post_process
+            and int(os.getenv("USE_FLUX_OVERLAP", "0"))
+        ):
+            self.output_layer = FluxColumnParallelLinear(
+                self.config.hidden_size,
+                self.vocab_size,
+                config=self.config,
+                init_method=self.config.init_method,
+                bias=False,
+                skip_bias_add=False,
+                gather_output=not self.parallel_output,
+                skip_weight_param_allocation=self.pre_process
+                and self.share_embeddings_and_output_weights,
+                embedding_activation_buffer=self.embedding_activation_buffer,
+                grad_output_buffer=self.grad_output_buffer,
+            )
+            self.setup_embeddings_and_output_layer()
+
         # add mtp
         self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
         if self.num_nextn_predict_layers:
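Review note (not part of the diff): the new block only swaps the output projection on the last pipeline stage (self.post_process is true there) and only when the USE_FLUX_OVERLAP environment variable is set; calling self.setup_embeddings_and_output_layer() afterwards re-runs the embedding/output weight setup against the new FluxColumnParallelLinear. A minimal sketch of the wrapping pattern, assuming the wrapper is applied by monkey-patching GPTModel.__init__ (the patch site itself is not shown in this commit):

import os
from functools import wraps


def init_wrapper(fn):
    """Illustrative stand-in for gpt_model_init_wrapper."""
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)  # run the original GPTModel.__init__ first
        # Post-construction patching: only the last pipeline stage owns the
        # output projection, and the env flag opts into the flux overlap path.
        if self.post_process and int(os.getenv("USE_FLUX_OVERLAP", "0")):
            # self.output_layer = FluxColumnParallelLinear(...)  # as in the hunk above
            self.setup_embeddings_and_output_layer()  # refresh shared embedding/output weight setup
    return wrapper


# Hypothetical patch site (assumption, not shown in this commit):
# from megatron.core.models.gpt import GPTModel
# GPTModel.__init__ = init_wrapper(GPTModel.__init__)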
...
@@ -230,7 +230,7 @@ class AGLinear(torch.autograd.Function):
                 output_scale=None,
                 fast_accum=False
             )
+            torch.distributed.barrier()
             torch.cuda.current_stream().synchronize()
             output = output.view(sequence_len * world_size, batch_size, -1)
         else:
@@ -1172,7 +1172,7 @@ class FluxRowParallelLinear(RowParallelLinear):
         output_parallel = self._forward_impl(
             input=input_parallel,
             weight=self.weight,
-            bias=self.bias if not self.skip_bias_add and self.sequence_parallel else None,
+            bias=None,
             gradient_accumulation_fusion=self.gradient_accumulation_fusion,
             allreduce_dgrad=False,
             sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
@@ -1192,7 +1192,6 @@ class FluxRowParallelLinear(RowParallelLinear):
         if not self.skip_bias_add:
             output_bias = None
-            if not self.sequence_parallel:
-                output = (output_ + self.bias) if self.bias is not None else output_
-            else:
-                output = output_
+            output = (output_ + self.bias) if self.bias is not None else output_
+        else:
+            output = output_
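Review note (not part of the diff): in a row-parallel linear, each tensor-parallel rank produces a partial result that is summed by an all-reduce (or reduce-scatter when sequence parallelism is on). Passing self.bias into _forward_impl adds it to each rank's partial output before that reduction, so the bias ends up scaled by the tensor-parallel world size; the change passes bias=None to the parallel matmul and adds self.bias exactly once to the reduced output. A toy sketch of the arithmetic (plain tensors, no distributed setup):

import torch

tp_size = 4
out_features = 3
bias = torch.ones(out_features)

# Each tensor-parallel rank holds a partial output of the row-parallel matmul.
partials = [torch.randn(out_features) for _ in range(tp_size)]

# Bias added per rank before the reduction: it is counted tp_size times.
per_rank_bias = sum(p + bias for p in partials)

# Bias added once after the reduction (what the fixed code does).
post_reduce_bias = sum(partials) + bias

# The per-rank version over-counts the bias by (tp_size - 1) copies.
print(torch.allclose(per_rank_bias, post_reduce_bias + (tp_size - 1) * bias))  # True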
...
@@ -61,7 +61,11 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
         Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
     """
     args = get_args()
-    use_te = args.transformer_impl == "transformer_engine" or bool(int(os.getenv("USE_FLUX_OVERLAP", "0")))
+    if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))):
+        assert args.transformer_impl == "transformer_engine"
+
+    use_te = args.transformer_impl == "transformer_engine"

     if args.record_memory_history:
         torch.cuda.memory._record_memory_history(True,
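Review note (not part of the diff): before this change, setting USE_FLUX_OVERLAP=1 silently turned use_te on even when --transformer-impl was "local"; now it asserts that transformer_engine was selected explicitly. A minimal sketch of the env-flag convention used throughout this commit (the helper name is hypothetical):

import os


def flux_overlap_enabled() -> bool:
    # USE_FLUX_OVERLAP is read as an integer-valued switch: unset/"0" -> False, "1" -> True.
    return bool(int(os.getenv("USE_FLUX_OVERLAP", "0")))

# e.g. launch with:  USE_FLUX_OVERLAP=1 ... --transformer-impl transformer_engine ...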
...