Commit 1dc8bc8a authored by dongcl

fix the bug related to parameter sharing

parent 5b1e05ab
@@ -100,7 +100,11 @@ class CoreAdaptation(MegatronAdaptationABC):
from ..core.models.gpt.gpt_model import (
gpt_model_forward,
gpt_model_init_wrapper,
shared_embedding_or_mtp_embedding_weight
shared_embedding_or_output_weight,
)
from ..core.models.common.language_module.language_module import (
setup_embeddings_and_output_layer,
tie_embeddings_and_output_weights_state_dict
)
from ..training.utils import get_batch_on_this_tp_rank
@@ -115,14 +119,21 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
# GPT Model
MegatronAdaptation.register(
'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
setup_embeddings_and_output_layer)
MegatronAdaptation.register(
'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
tie_embeddings_and_output_weights_state_dict)
MegatronAdaptation.register(
'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
shared_embedding_or_output_weight)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
from megatron.core.models.gpt.gpt_model import GPTModel
setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
......
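For context, the registrations above swap methods on Megatron classes at import time by dotted path. A minimal sketch of how such a registry can apply a patch, with assumed names and illustrative only (not this repository's MegatronAdaptation implementation):

# Illustrative sketch (assumed implementation): resolve "pkg.module.Class.method"
# and replace the attribute; a wrapper-style patch receives the original callable.
import importlib

def register_patch(dotted_path: str, replacement, apply_wrapper: bool = False):
    module_path, owner_name, attr_name = dotted_path.rsplit('.', 2)
    module = importlib.import_module(module_path)
    owner = getattr(module, owner_name)
    original = getattr(owner, attr_name, None)
    patched = replacement(original) if apply_wrapper else replacement
    setattr(owner, attr_name, patched)

Under this reading, register_patch('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward) would behave like the registration calls above.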
@@ -28,22 +28,10 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
if model_module.share_embeddings_and_output_weights or getattr(config, 'num_nextn_predict_layers', 0):
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
if (
hasattr(model_module, "share_mtp_embedding_and_output_weight")
and model_module.share_mtp_embedding_and_output_weight
and config.num_nextn_predict_layers > 0
):
weight = model_module.shared_embedding_or_mtp_embedding_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
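The condition change above extends the existing tied-weight synchronization to the MTP case: whenever the first and last pipeline stages each hold a copy of the embedding weight, their gradients must be summed before the optimizer step. A standalone sketch of that idea with plain torch.distributed (assumed helper name, outside Megatron):

# Illustrative sketch: all ranks in the embedding group sum the gradient of the
# tied weight so every copy receives an identical optimizer update.
import torch
import torch.distributed as dist

def sync_tied_weight_grad(weight: torch.nn.Parameter, group=None):
    grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
    grad = getattr(weight, grad_attr)
    if grad is not None:
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)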
@@ -30,6 +30,15 @@ def gpt_model_init_wrapper(fn):
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
if self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0):
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=kwargs.get("position_embedding_type"),
scatter_to_sequence_parallel=kwargs.get("scatter_embedding_sequence_parallel"),
)
if (
self.post_process
and int(os.getenv("USE_FLUX_OVERLAP", "0"))
@@ -55,7 +64,6 @@ def gpt_model_init_wrapper(fn):
if self.num_nextn_predict_layers:
assert hasattr(self.config, "mtp_spec")
self.mtp_spec: ModuleSpec = self.config.mtp_spec
self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
self.recompute_mtp_norm = self.config.recompute_mtp_norm
self.recompute_mtp_layer = self.config.recompute_mtp_layer
self.mtp_loss_scale = self.config.mtp_loss_scale
@@ -74,7 +82,6 @@ def gpt_model_init_wrapper(fn):
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
seq_len_interpolation_factor=kwargs.get("seq_len_interpolation_factor", None),
share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
recompute_mtp_norm=self.recompute_mtp_norm,
recompute_mtp_layer=self.recompute_mtp_layer,
add_output_layer_bias=False
@@ -83,95 +90,22 @@ def gpt_model_init_wrapper(fn):
]
)
if self.pre_process or self.post_process:
setup_mtp_embeddings(self)
return wrapper
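The wrapper above follows the usual apply_wrapper pattern: run the original __init__ first, then attach MTP-related state when the config enables it. A generic sketch of that decorator shape (hypothetical attribute names, for illustration only):

# Illustrative sketch of an __init__ wrapper: original construction first,
# then optional extra state driven by a config flag.
import functools

def init_wrapper(fn):
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)  # original __init__
        if getattr(self.config, 'num_nextn_predict_layers', 0):
            self.mtp_enabled = True  # hypothetical extra attribute
    return wrapper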
def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
"""Gets the embedding weight when share embedding and mtp embedding weights set to True.
def shared_embedding_or_output_weight(self) -> Tensor:
"""Gets the emedding weight or output logit weights when share embedding and output weights set to True.
Returns:
Tensor: During pre processing it returns the input embeddings weight while during post processing it returns
mtp embedding layers weight
Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layer's weight
"""
assert self.num_nextn_predict_layers > 0
if self.pre_process:
if self.pre_process or (self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0)):
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.mtp_layers[0].embedding.word_embeddings.weight
return self.output_layer.weight
return None
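For reference, upstream GPTModel.forward consumes this accessor roughly as follows when weights are tied, which is what makes the output_weight argument passed to the MTP layers below meaningful (paraphrased for illustration, not part of this commit):

# Roughly how the shared weight reaches the output projection (paraphrased).
output_weight = None
if self.share_embeddings_and_output_weights:
    output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(hidden_states, weight=output_weight)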
def setup_mtp_embeddings(self):
"""
Share embedding layer in mtp layer.
"""
if self.pre_process:
self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
# Set `is_embedding_or_output_parameter` attribute.
for i in range(self.num_nextn_predict_layers):
if self.post_process and self.mtp_layers[i].embedding.word_embeddings.weight is not None:
self.mtp_layers[i].embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
if not self.share_mtp_embedding_and_output_weight:
return
if self.pre_process and self.post_process:
# Zero out wgrad if sharing embeddings between two layers on same
# pipeline stage to make sure grad accumulation into main_grad is
# correct and does not include garbage values (e.g., from torch.empty).
self.shared_embedding_or_mtp_embedding_weight().zero_out_wgrad = True
return
if self.pre_process and not self.post_process:
assert parallel_state.is_pipeline_first_stage()
self.shared_embedding_or_mtp_embedding_weight().shared_embedding = True
if self.post_process and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
for i in range(self.num_nextn_predict_layers):
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.mtp_layers[i].embedding.word_embeddings.weight.data.fill_(0)
self.mtp_layers[i].embedding.word_embeddings.weight.shared = True
self.mtp_layers[i].embedding.word_embeddings.weight.shared_embedding = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_mtp_embedding_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
LanguageModule.embedding_warning_printed = True
def slice_inputs(self, input_ids, labels, position_ids, attention_mask):
if self.num_nextn_predict_layers == 0:
return (
@@ -317,11 +251,6 @@ def gpt_model_forward(
loss = 0
# Multi token prediction module
if self.num_nextn_predict_layers and self.training:
if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight:
output_weight = self.output_layer.weight
output_weight.zero_out_wgrad = True
embedding_weight = self.shared_embedding_or_mtp_embedding_weight() if self.share_mtp_embedding_and_output_weight else None
mtp_hidden_states = hidden_states
for i in range(self.num_nextn_predict_layers):
mtp_hidden_states, mtp_loss = self.mtp_layers[i](
@@ -333,7 +262,8 @@ def gpt_model_forward(
inference_params,
packed_seq_params,
extra_block_kwargs,
embeding_weight=embedding_weight,
embedding_layer=self.embedding,
output_layer=self.output_layer,
output_weight=output_weight,
)
......
@@ -46,7 +46,6 @@ class MultiTokenPredictor(MegatronModule):
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
share_mtp_embedding_and_output_weight=True,
recompute_mtp_norm=False,
recompute_mtp_layer=False,
add_output_layer_bias=False
@@ -65,20 +64,10 @@ class MultiTokenPredictor(MegatronModule):
self.parallel_output = parallel_output
self.position_embedding_type = position_embedding_type
# share with main model
self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight
self.recompute_layer_norm = recompute_mtp_norm
self.recompute_mtp_layer = recompute_mtp_layer
self.add_output_layer_bias = add_output_layer_bias
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=self.position_embedding_type,
skip_weight_param_allocation=self.pre_process and self.share_mtp_embedding_and_output_weight
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
@@ -138,23 +127,6 @@ class MultiTokenPredictor(MegatronModule):
self.embedding_activation_buffer = None
self.grad_output_buffer = None
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
column_parallel_linear_impl = FluxColumnParallelLinear
else:
column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
self.output_layer = column_parallel_linear_impl(
self.config.hidden_size,
self.vocab_size,
config=self.config,
init_method=self.config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
def forward(
self,
hidden_input_ids: Tensor,
@@ -165,16 +137,16 @@ class MultiTokenPredictor(MegatronModule):
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
embeding_weight: Optional[torch.Tensor] = None,
output_weight: Optional[torch.Tensor] = None,
embedding_layer=None,
output_layer=None,
output_weight=None
):
"""Forward function of the MTP module"""
# Decoder embedding.
decoder_input = self.embedding(
decoder_input = embedding_layer(
input_ids=embed_input_ids,
position_ids=position_ids,
weight=embeding_weight,
)
# Rotary positional embeddings (embedding is None for PP intermediate devices)
@@ -251,7 +223,7 @@ class MultiTokenPredictor(MegatronModule):
else:
finalnorm_output = hidden_states
logits, _ = self.output_layer(finalnorm_output, weight=output_weight)
logits, _ = output_layer(finalnorm_output, weight=output_weight)
if self.recompute_layer_norm:
self.finalnorm_ckpt.discard_output()
......
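The weight=output_weight call above works because Megatron's ColumnParallelLinear.forward accepts an optional weight tensor for layers built with skip_weight_param_allocation. A simplified sketch of that contract (not the real implementation):

# Simplified sketch of the optional-weight contract; not Megatron's actual code.
import torch

class SketchColumnParallelLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, skip_weight_param_allocation=False):
        super().__init__()
        self.weight = None if skip_weight_param_allocation else torch.nn.Parameter(
            torch.empty(out_features, in_features))

    def forward(self, x, weight=None):
        w = weight if weight is not None else self.weight
        if w is None:
            raise RuntimeError("weight must be supplied when allocation was skipped")
        return torch.nn.functional.linear(x, w), None  # (output, bias) tuple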