from argparse import ArgumentParser

from ..base_feature import BaseFeature


class MTPFeature(BaseFeature):
    """Feature plugin that registers replacement implementations for Megatron symbols.

    NOTE(review): the class is named ``MTPFeature`` but its feature name and
    CLI flag are ``schedules-method`` / ``dualpipev``; this looks copy-pasted
    from another feature -- confirm the intended name and flag.
    """

    def __init__(self):
        """Initialise the base feature with the name ``'schedules-method'``."""
        super().__init__('schedules-method')

    def register_args(self, parser: ArgumentParser) -> None:
        """Add this feature's command-line options to ``parser``.

        Creates an argument group titled with ``self.feature_name`` and adds
        the string option ``--schedules-method`` (default ``None``, only
        accepted value ``'dualpipev'``).

        Args:
            parser: The argument parser to extend.
        """
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--schedules-method', type=str,
                           default=None, choices=['dualpipev'])

    def register_patches(self, patch_manager, args) -> None:
        """Register every replacement of this feature with ``patch_manager``.

        The replacements are imported lazily inside the method, then
        registered in a fixed order against the dotted path of the Megatron
        symbol they replace.

        Args:
            patch_manager: Registry receiving the patches. Assumed to expose
                ``register(orig_func_name, new_func, apply_wrapper=...)`` --
                TODO confirm against the caller.
            args: Parsed command-line arguments. Currently unused: all
                patches are registered unconditionally.
        """
        from ...core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
        from ...core.models.common.language_module.language_module import (
            setup_embeddings_and_output_layer,
            tie_embeddings_and_output_weights_state_dict,
        )
        from ...core.models.gpt.gpt_model import GPTModel
        from ...training.utils import get_batch_on_this_tp_rank
        from ...core.pipeline_parallel.schedules import forward_step_wrapper
        from ...core import transformer_block_init_wrapper

        language_module = 'megatron.core.models.common.language_module.language_module.LanguageModule'

        # (target dotted path, replacement, extra keyword arguments), kept in
        # the original registration order.
        patches = (
            ('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
             _allreduce_word_embedding_grads, {}),
            # LanguageModule
            (language_module + '.setup_embeddings_and_output_layer',
             setup_embeddings_and_output_layer, {}),
            (language_module + '.tie_embeddings_and_output_weights_state_dict',
             tie_embeddings_and_output_weights_state_dict, {}),
            ('megatron.training.utils.get_batch_on_this_tp_rank',
             get_batch_on_this_tp_rank, {}),
            # GPT Model
            ('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel, {}),
            # Transformer block
            ('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
             transformer_block_init_wrapper, {}),
            # pipeline_parallel.schedules.forward_step: the only entry
            # registered with apply_wrapper=True.
            ('megatron.core.pipeline_parallel.schedules.forward_step',
             forward_step_wrapper, {'apply_wrapper': True}),
        )

        # Register through the manager handed in by the caller. The previous
        # code called ``MegatronAdaptation.register`` although that name is
        # not imported in this module, which raises NameError on first use.
        for target, replacement, options in patches:
            patch_manager.register(target, replacement, **options)
