Commit 1dc8bc8a authored by dongcl

fix the bug related to parameter sharing

parent 5b1e05ab
@@ -100,7 +100,11 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..core.models.gpt.gpt_model import (
             gpt_model_forward,
             gpt_model_init_wrapper,
-            shared_embedding_or_mtp_embedding_weight
+            shared_embedding_or_output_weight,
         )
+        from ..core.models.common.language_module.language_module import (
+            setup_embeddings_and_output_layer,
+            tie_embeddings_and_output_weights_state_dict
+        )
         from ..training.utils import get_batch_on_this_tp_rank
@@ -115,14 +119,21 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)

         # GPT Model
+        MegatronAdaptation.register(
+            'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
+            setup_embeddings_and_output_layer)
+        MegatronAdaptation.register(
+            'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
+            tie_embeddings_and_output_weights_state_dict)
+        MegatronAdaptation.register(
+            'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
+            shared_embedding_or_output_weight)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                     gpt_model_init_wrapper,
                                     apply_wrapper=True)
-        from megatron.core.models.gpt.gpt_model import GPTModel
-        setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
         from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
...
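Note: the hunks above swap the earlier ad-hoc `setattr(GPTModel, ...)` patch for registrations through `MegatronAdaptation.register`, alongside two `LanguageModule` overrides. A minimal sketch of what a register-style monkey patch boils down to, using a hypothetical `register` helper rather than the repository's actual implementation (which may defer application, stack wrappers, etc.):

```python
import importlib


def register(target_path: str, replacement) -> None:
    """Hypothetical stand-in for MegatronAdaptation.register: replace the
    attribute named by a dotted ``module.Class.attr`` path with ``replacement``."""
    module_path, cls_name, attr_name = target_path.rsplit(".", 2)
    cls = getattr(importlib.import_module(module_path), cls_name)
    setattr(cls, attr_name, replacement)


# Illustrative use, mirroring one of the registrations added above:
# register('megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
#          shared_embedding_or_output_weight)
```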
@@ -28,22 +28,10 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
     model_module = model[0]
     model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
-    if model_module.share_embeddings_and_output_weights:
+    if model_module.share_embeddings_and_output_weights or getattr(config, 'num_nextn_predict_layers', 0):
         weight = model_module.shared_embedding_or_output_weight()
         grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
         orig_grad = getattr(weight, grad_attr)
         grad = _unshard_if_dtensor(orig_grad)
         torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
         setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
-    if (
-        hasattr(model_module, "share_mtp_embedding_and_output_weight")
-        and model_module.share_mtp_embedding_and_output_weight
-        and config.num_nextn_predict_layers > 0
-    ):
-        weight = model_module.shared_embedding_or_mtp_embedding_weight()
-        grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
-        orig_grad = getattr(weight, grad_attr)
-        grad = _unshard_if_dtensor(orig_grad)
-        torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
-        setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
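Note: with the MTP-specific branch removed, the gradient all-reduce for the tied weight goes through the single `shared_embedding_or_output_weight()` path whenever output weights are shared or `num_nextn_predict_layers` is set. A minimal sketch of that tied-weight gradient all-reduce, with the process-group plumbing left as an assumption:

```python
import torch
import torch.distributed as dist


def allreduce_tied_weight_grad(weight: torch.nn.Parameter, embedding_group) -> None:
    """Sketch of the tied-weight gradient all-reduce kept by the hunk above.

    Assumes ``embedding_group`` is a process group containing exactly the ranks
    that hold a copy of the tied weight (e.g. the first and last pipeline
    stages); Megatron obtains it from parallel_state.get_embedding_group().
    """
    grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
    grad = getattr(weight, grad_attr)
    if grad is not None:
        # Sum the partial gradients so every copy of the tied weight
        # receives the same update.
        dist.all_reduce(grad, group=embedding_group)
```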
@@ -30,6 +30,15 @@ def gpt_model_init_wrapper(fn):
     def wrapper(self, *args, **kwargs):
         fn(self, *args, **kwargs)

+        if self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0):
+            self.embedding = LanguageModelEmbedding(
+                config=self.config,
+                vocab_size=self.vocab_size,
+                max_sequence_length=self.max_sequence_length,
+                position_embedding_type=kwargs.get("position_embedding_type"),
+                scatter_to_sequence_parallel=kwargs.get("scatter_embedding_sequence_parallel"),
+            )
+
         if (
             self.post_process
             and int(os.getenv("USE_FLUX_OVERLAP", "0"))
@@ -55,7 +64,6 @@ def gpt_model_init_wrapper(fn):
         if self.num_nextn_predict_layers:
             assert hasattr(self.config, "mtp_spec")
             self.mtp_spec: ModuleSpec = self.config.mtp_spec
-            self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
             self.recompute_mtp_norm = self.config.recompute_mtp_norm
             self.recompute_mtp_layer = self.config.recompute_mtp_layer
             self.mtp_loss_scale = self.config.mtp_loss_scale
@@ -74,7 +82,6 @@ def gpt_model_init_wrapper(fn):
                     position_embedding_type=self.position_embedding_type,
                     rotary_percent=self.rotary_percent,
                     seq_len_interpolation_factor=kwargs.get("seq_len_interpolation_factor", None),
-                    share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
                     recompute_mtp_norm=self.recompute_mtp_norm,
                     recompute_mtp_layer=self.recompute_mtp_layer,
                     add_output_layer_bias=False
@@ -83,95 +90,22 @@ def gpt_model_init_wrapper(fn):
             ]
         )

-        if self.pre_process or self.post_process:
-            setup_mtp_embeddings(self)
-
     return wrapper

-def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
-    """Gets the embedding weight when share embedding and mtp embedding weights set to True.
+def shared_embedding_or_output_weight(self) -> Tensor:
+    """Gets the embedding weight or output logit weights when share embedding and output weights set to True.

     Returns:
-        Tensor: During pre processing it returns the input embeddings weight while during post processing it returns
-            mtp embedding layers weight
+        Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight
     """
-    assert self.num_nextn_predict_layers > 0
-    if self.pre_process:
+    if self.pre_process or (self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0)):
         return self.embedding.word_embeddings.weight
     elif self.post_process:
-        return self.mtp_layers[0].embedding.word_embeddings.weight
+        return self.output_layer.weight
     return None

-def setup_mtp_embeddings(self):
-    """
-    Share embedding layer in mtp layer.
-    """
-    if self.pre_process:
-        self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
-
-    # Set `is_embedding_or_output_parameter` attribute.
-    for i in range(self.num_nextn_predict_layers):
-        if self.post_process and self.mtp_layers[i].embedding.word_embeddings.weight is not None:
-            self.mtp_layers[i].embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
-
-    if not self.share_mtp_embedding_and_output_weight:
-        return
-
-    if self.pre_process and self.post_process:
-        # Zero out wgrad if sharing embeddings between two layers on same
-        # pipeline stage to make sure grad accumulation into main_grad is
-        # correct and does not include garbage values (e.g., from torch.empty).
-        self.shared_embedding_or_mtp_embedding_weight().zero_out_wgrad = True
-        return
-
-    if self.pre_process and not self.post_process:
-        assert parallel_state.is_pipeline_first_stage()
-        self.shared_embedding_or_mtp_embedding_weight().shared_embedding = True
-
-    if self.post_process and not self.pre_process:
-        assert not parallel_state.is_pipeline_first_stage()
-        for i in range(self.num_nextn_predict_layers):
-            # set word_embeddings weights to 0 here, then copy first
-            # stage's weights using all_reduce below.
-            self.mtp_layers[i].embedding.word_embeddings.weight.data.fill_(0)
-            self.mtp_layers[i].embedding.word_embeddings.weight.shared = True
-            self.mtp_layers[i].embedding.word_embeddings.weight.shared_embedding = True
-
-    # Parameters are shared between the word embeddings layers, and the
-    # heads at the end of the model. In a pipelined setup with more than
-    # one stage, the initial embedding layer and the head are on different
-    # workers, so we do the following:
-    # 1. Create a second copy of word_embeddings on the last stage, with
-    #    initial parameters of 0.0.
-    # 2. Do an all-reduce between the first and last stage to ensure that
-    #    the two copies of word_embeddings start off with the same
-    #    parameter values.
-    # 3. In the training loop, before an all-reduce between the grads of
-    #    the two word_embeddings layers to ensure that every applied weight
-    #    update is the same on both stages.
-    # Ensure that first and last stages have the same initial parameter
-    # values.
-    if torch.distributed.is_initialized():
-        if parallel_state.is_rank_in_embedding_group():
-            weight = self.shared_embedding_or_mtp_embedding_weight()
-            weight.data = weight.data.cuda()
-            torch.distributed.all_reduce(
-                weight.data, group=parallel_state.get_embedding_group()
-            )
-    elif not getattr(LanguageModule, "embedding_warning_printed", False):
-        logging.getLogger(__name__).warning(
-            "Distributed processes aren't initialized, so the output layer "
-            "is not initialized with weights from the word embeddings. "
-            "If you are just manipulating a model this is fine, but "
-            "this needs to be handled manually. If you are training "
-            "something is definitely wrong."
-        )
-        LanguageModule.embedding_warning_printed = True
-
 def slice_inputs(self, input_ids, labels, position_ids, attention_mask):
     if self.num_nextn_predict_layers == 0:
         return (
@@ -317,11 +251,6 @@ def gpt_model_forward(
     loss = 0

     # Multi token prediction module
     if self.num_nextn_predict_layers and self.training:
-        if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight:
-            output_weight = self.output_layer.weight
-            output_weight.zero_out_wgrad = True
-        embedding_weight = self.shared_embedding_or_mtp_embedding_weight() if self.share_mtp_embedding_and_output_weight else None
-
         mtp_hidden_states = hidden_states
         for i in range(self.num_nextn_predict_layers):
             mtp_hidden_states, mtp_loss = self.mtp_layers[i](
@@ -333,7 +262,8 @@ def gpt_model_forward(
                 inference_params,
                 packed_seq_params,
                 extra_block_kwargs,
-                embeding_weight=embedding_weight,
+                embedding_layer=self.embedding,
+                output_layer=self.output_layer,
                 output_weight=output_weight,
             )
...
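Note: the net effect of the hunks above is that `shared_embedding_or_output_weight` now also covers the MTP case (the init wrapper builds a `LanguageModelEmbedding` on the last stage when `num_nextn_predict_layers` is set), so the bespoke `setup_mtp_embeddings` tying code can be dropped in favour of the standard `setup_embeddings_and_output_layer` and grad all-reduce machinery. A standalone restatement of the new resolution order, as a sketch for reasoning and testing rather than the patched method itself, with `model` standing in for a patched GPTModel instance:

```python
from typing import Optional

import torch


def resolve_shared_weight(model) -> Optional[torch.Tensor]:
    """Mirror of the patched resolution order (sketch only).

    First-stage ranks, and last-stage ranks that own an input embedding
    because MTP layers are enabled, return the embedding weight; a last
    stage without one falls back to the output projection weight.
    """
    has_mtp = getattr(model.config, "num_nextn_predict_layers", 0)
    if model.pre_process or (model.post_process and has_mtp):
        return model.embedding.word_embeddings.weight
    if model.post_process:
        return model.output_layer.weight
    return None
```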
@@ -46,7 +46,6 @@ class MultiTokenPredictor(MegatronModule):
         rotary_percent: float = 1.0,
         rotary_base: int = 10000,
         seq_len_interpolation_factor: Optional[float] = None,
-        share_mtp_embedding_and_output_weight=True,
         recompute_mtp_norm=False,
         recompute_mtp_layer=False,
         add_output_layer_bias=False
@@ -65,20 +64,10 @@ class MultiTokenPredictor(MegatronModule):
         self.parallel_output = parallel_output
         self.position_embedding_type = position_embedding_type

-        # share with main model
-        self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight
         self.recompute_layer_norm = recompute_mtp_norm
         self.recompute_mtp_layer = recompute_mtp_layer
         self.add_output_layer_bias = add_output_layer_bias

-        self.embedding = LanguageModelEmbedding(
-            config=self.config,
-            vocab_size=self.vocab_size,
-            max_sequence_length=self.max_sequence_length,
-            position_embedding_type=self.position_embedding_type,
-            skip_weight_param_allocation=self.pre_process and self.share_mtp_embedding_and_output_weight
-        )
-
         if self.position_embedding_type == 'rope':
             self.rotary_pos_emb = RotaryEmbedding(
                 kv_channels=self.config.kv_channels,
@@ -138,23 +127,6 @@ class MultiTokenPredictor(MegatronModule):
             self.embedding_activation_buffer = None
             self.grad_output_buffer = None

-        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
-            column_parallel_linear_impl = FluxColumnParallelLinear
-        else:
-            column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
-
-        self.output_layer = column_parallel_linear_impl(
-            self.config.hidden_size,
-            self.vocab_size,
-            config=self.config,
-            init_method=self.config.init_method,
-            bias=False,
-            skip_bias_add=False,
-            gather_output=not self.parallel_output,
-            skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
-            embedding_activation_buffer=self.embedding_activation_buffer,
-            grad_output_buffer=self.grad_output_buffer,
-        )
-
     def forward(
         self,
         hidden_input_ids: Tensor,
@@ -165,16 +137,16 @@ class MultiTokenPredictor(MegatronModule):
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
         extra_block_kwargs: dict = None,
-        embeding_weight: Optional[torch.Tensor] = None,
-        output_weight: Optional[torch.Tensor] = None,
+        embedding_layer=None,
+        output_layer=None,
+        output_weight=None
     ):
         """Forward function of the MTP module"""

         # Decoder embedding.
-        decoder_input = self.embedding(
+        decoder_input = embedding_layer(
             input_ids=embed_input_ids,
             position_ids=position_ids,
-            weight=embeding_weight,
         )

         # Rotary positional embeddings (embedding is None for PP intermediate devices)
@@ -251,7 +223,7 @@ class MultiTokenPredictor(MegatronModule):
         else:
             finalnorm_output = hidden_states

-        logits, _ = self.output_layer(finalnorm_output, weight=output_weight)
+        logits, _ = output_layer(finalnorm_output, weight=output_weight)

         if self.recompute_layer_norm:
             self.finalnorm_ckpt.discard_output()
...
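Note: after this change `MultiTokenPredictor` owns neither an embedding nor an output head; both are handed in at call time as `embedding_layer` and `output_layer`, so each pipeline stage holds exactly one copy of those parameters. A small hypothetical sanity check along those lines, to run after building a patched model:

```python
def check_mtp_parameter_sharing(model) -> None:
    """Sketch of a post-build check: after this commit the MTP layers should
    register no embedding or output-projection parameters of their own,
    since they reuse the main model's modules via the new
    ``embedding_layer``/``output_layer`` forward arguments."""
    for name, _ in model.mtp_layers.named_parameters():
        assert "word_embeddings" not in name, f"unexpected MTP embedding parameter: {name}"
        assert "output_layer" not in name, f"unexpected MTP output-head parameter: {name}"
```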