# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.

import contextlib
from functools import wraps
from typing import Iterator, List, Union

import torch

from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.training import get_args
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
    get_attr_wrapped_model,
    get_model_config,
    get_model_type,
)
from megatron.core.pipeline_parallel.schedules import clear_embedding_activation_buffer, deallocate_output_tensor
from megatron.core import ModelParallelConfig
from megatron.core.pipeline_parallel.p2p_communication import _communicate
from megatron.core.pipeline_parallel.schedules import (
    backward_step,
    set_current_microbatch,
    check_first_val_step,
    finish_embedding_wgrad_compute
)
# from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore


# Types
# Shape of a tensor exchanged between pipeline stages: a plain list of ints or a torch.Size.
Shape = Union[List[int], torch.Size]
# Scalar 1.0 tensor. NOTE(review): not referenced in the visible part of this file;
# presumably used as a backward seed/scale elsewhere -- confirm before removing.
LOSS_BACKWARD_SCALE = torch.tensor(1.0)


# Id of the model chunk whose forward pass is currently running; None until first set.
_DUALPIPE_CHUNK = None


def set_dualpipe_chunk(chunk_id):
    """Record ``chunk_id`` as the active DualPipe model chunk (consumed by the fp16 forward patch)."""
    global _DUALPIPE_CHUNK
    _DUALPIPE_CHUNK = chunk_id


def get_dualpipe_chunk():
    """Return the model chunk id last stored by ``set_dualpipe_chunk``.

    Raises:
        AssertionError: if no chunk id has been recorded yet.
    """
    current = _DUALPIPE_CHUNK
    if current is None:
        raise AssertionError("_DUALPIPE_CHUNK is None")
    return current


def is_dualpipev_last_stage(model_chunk_id):
    """Return True when this rank runs the final stage of the DualPipeV "V".

    Chunk 1 flows back from the last pipeline rank to the first one, so the end of the
    whole model is chunk 1 on the first pipeline rank.
    """
    on_first_rank = parallel_state.is_pipeline_first_stage(ignore_virtual=True)
    return on_first_rank and model_chunk_id == 1


def send_forward(output_tensor: torch.Tensor, tensor_shape, config: ModelParallelConfig, model_chunk_id, async_op=False):
    """Send a forward activation to the neighbouring rank in ``model_chunk_id``'s direction.

    Chunk 0 flows from the first to the last pipeline rank, so its output goes to the next
    rank. Any other chunk id (chunk 1) flows back, so its output goes to the previous rank.
    Nothing is sent when the activation would leave the pipeline: the last rank for chunk 0,
    the first rank for chunk 1.

    Args:
        output_tensor: activation produced by this rank's forward pass.
        tensor_shape: shape forwarded to ``_communicate``.
        config: model-parallel config (timers and p2p settings).
        model_chunk_id: 0 sends to the next rank; any other value sends to the previous rank.
        async_op: if True, ``_communicate`` does not wait on the requests. Callers in this
            file wait on the returned handles before deallocating ``output_tensor``.

    Returns:
        None when nothing was sent; otherwise the wait handles returned by ``_communicate``.
        Callers in this file iterate them with ``.items()``.
        NOTE(review): with ``async_op=False`` these are presumably already completed/None --
        confirm against the Megatron version in use.
    """
    tensor_send_next, tensor_send_prev = None, None
    if model_chunk_id == 0:
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            return None
        tensor_send_next = output_tensor
    else:
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            return None
        tensor_send_prev = output_tensor

    if config.timers is not None:
        config.timers('forward-send', log_level=2).start()

    _, _, fwd_wait_handles = _communicate(
        tensor_send_next=tensor_send_next,
        tensor_send_prev=tensor_send_prev,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape,
        config=config,
        wait_on_reqs=(not async_op)
    )
    if config.timers is not None:
        config.timers('forward-send').stop()

    return fwd_wait_handles


def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelParallelConfig, model_chunk_id, async_op=False):
    """Send an input gradient to the upstream rank of ``model_chunk_id`` (backward send).

    Gradients travel opposite to activations. For chunk 0 they go to the previous rank;
    for any other chunk id (chunk 1) they go to the next rank. Nothing is sent when the
    gradient would leave the pipeline: the first rank for chunk 0, the last rank for chunk 1.

    Args:
        input_tensor_grad: gradient w.r.t. this stage's input.
        tensor_shape: shape forwarded to ``_communicate``.
        config: model-parallel config (timers and p2p settings).
        model_chunk_id: 0 sends to the previous rank; any other value sends to the next rank.
        async_op: if True, ``_communicate`` does not wait on the requests.

    Returns:
        None when nothing was sent; otherwise the wait handles returned by ``_communicate``.
    """
    tensor_send_next, tensor_send_prev = None, None
    if model_chunk_id == 0:
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            return None
        tensor_send_prev = input_tensor_grad
    else:
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            return None
        tensor_send_next = input_tensor_grad

    if config.timers is not None:
        config.timers('backward-send', log_level=2).start()
    _, _, reqs = _communicate(
        tensor_send_next=tensor_send_next,
        tensor_send_prev=tensor_send_prev,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape,
        config=config,
        wait_on_reqs=(not async_op)
    )
    if config.timers is not None:
        config.timers('backward-send').stop()
    return reqs


def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False) -> tuple:
    """Receive a forward activation from the upstream rank of ``model_chunk_id``.

    Chunk 0 inputs come from the previous rank. Inputs for any other chunk id (chunk 1)
    come from the next rank. Nothing is received on the rank where the chunk starts:
    the first rank for chunk 0, the last rank for chunk 1.

    Args:
        tensor_shape: shape of the tensor to receive.
        config: model-parallel config (timers and p2p settings).
        model_chunk_id: 0 receives from the previous rank; any other value from the next rank.
        async_op: if True, ``_communicate`` does not wait; the caller must wait on the
            returned handles before using the tensor.

    Returns:
        ``(input_tensor, wait_handles)``, or ``(None, None)`` when nothing is received.
    """
    # Activations of chunk 0 arrive from the previous rank, those of chunk 1 from the next rank.
    recv_prev = model_chunk_id == 0
    recv_next = not recv_prev

    if (
        (parallel_state.is_pipeline_first_stage(ignore_virtual=True) and recv_prev)
        or (parallel_state.is_pipeline_last_stage(ignore_virtual=True) and recv_next)
    ):
        return None, None

    if config.timers is not None:
        config.timers('forward-recv', log_level=2).start()
    tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        config=config,
        wait_on_reqs=(not async_op),
    )
    if config.timers is not None:
        config.timers('forward-recv').stop()

    input_tensor = tensor_recv_prev if recv_prev else tensor_recv_next
    return input_tensor, fwd_wait_handles


def recv_backward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False) -> tuple:
    """Receive an output gradient from the downstream rank of ``model_chunk_id``.

    Chunk 0 gradients come from the next rank. Gradients for any other chunk id (chunk 1)
    come from the previous rank. Nothing is received on the rank where the chunk ends:
    the last rank for chunk 0, the first rank for chunk 1.

    Args:
        tensor_shape: shape of the tensor to receive.
        config: model-parallel config (timers and p2p settings).
        model_chunk_id: 0 receives from the next rank; any other value from the previous rank.
        async_op: if True, ``_communicate`` does not wait; the caller must wait on the
            returned handles before using the tensor.

    Returns:
        ``(output_tensor_grad, wait_handles)``, or ``(None, None)`` when nothing is received.
    """
    # Gradients flow against the activations: chunk 0 from the next rank, chunk 1 from the previous.
    recv_next = model_chunk_id == 0
    recv_prev = not recv_next

    if (
        (parallel_state.is_pipeline_first_stage(ignore_virtual=True) and recv_prev)
        or (parallel_state.is_pipeline_last_stage(ignore_virtual=True) and recv_next)
    ):
        return None, None

    if config.timers is not None:
        config.timers('backward-recv', log_level=2).start()
    tensor_recv_prev, tensor_recv_next, bwd_wait_handles = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        config=config,
        wait_on_reqs=(not async_op)
    )
    if config.timers is not None:
        config.timers('backward-recv').stop()

    output_tensor_grad = tensor_recv_prev if recv_prev else tensor_recv_next
    return output_tensor_grad, bwd_wait_handles


def send_forward_recv_forward(
    output_tensor: torch.Tensor,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    model_chunk_id,
    async_op=False
) -> tuple:
    """Send this rank's activation downstream and receive the next input from upstream.

    For chunk 0 the output goes to the next rank and the input comes from the previous rank.
    For chunk 1 both directions are reversed. At a pipeline boundary the corresponding half
    is skipped: no send past the rank where the chunk ends, no receive on the rank where it
    starts.

    Args:
        output_tensor: activation to send.
        tensor_shape: shape forwarded to ``_communicate``.
        config: model-parallel config (timers and p2p settings).
        model_chunk_id: 0 or 1; selects the direction.
        async_op: if True, ``_communicate`` does not wait on the requests.

    Returns:
        ``(input_tensor, wait_handles)``. ``input_tensor`` is None on the rank where the
        chunk starts (first rank for chunk 0, last rank for chunk 1).
    """
    recv_prev, recv_next = False, False
    tensor_send_next, tensor_send_prev = None, None
    if model_chunk_id == 0:
        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            tensor_send_next = output_tensor
        if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            recv_prev = True
    if model_chunk_id == 1:
        if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            tensor_send_prev = output_tensor
        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            recv_next = True

    if config.timers is not None:
        config.timers('forward-send-forward-recv', log_level=2).start()
    tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
        tensor_send_next=tensor_send_next,
        tensor_send_prev=tensor_send_prev,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not async_op),
        config=config
    )
    if config.timers is not None:
        config.timers('forward-send-forward-recv').stop()

    if model_chunk_id == 0:
        if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            return tensor_recv_prev, fwd_wait_handles
        else:
            return None, fwd_wait_handles
    else:
        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            return tensor_recv_next, fwd_wait_handles
        else:
            return None, fwd_wait_handles


# TODO (dongcl)
def send_forward_recv_slave_forward(
    output_tensor: torch.Tensor,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    fwd_model_chunk_id,
    async_op=False,
) -> tuple:
    """Send an activation along ``fwd_model_chunk_id``'s direction and receive from that same neighbour.

    For chunk 0 the output goes to the next rank and a tensor is received back from the
    next rank. For chunk 1 both use the previous rank. The two chunks flow in opposite
    directions, so the received tensor is presumably the other ("slave") chunk's forward
    input -- confirm against the schedule.

    On the rank where ``fwd_model_chunk_id`` exits the pipeline (last rank for chunk 0,
    first rank for chunk 1), nothing is sent or received and ``(None, None)`` is returned.

    Returns:
        ``(received_tensor, wait_handles)``.
    """
    recv_prev, recv_next = False, False
    tensor_send_next, tensor_send_prev = None, None
    if fwd_model_chunk_id == 0:
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            return None, None
        tensor_send_next = output_tensor
        recv_next = True
    if fwd_model_chunk_id == 1:
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            return None, None
        tensor_send_prev = output_tensor
        recv_prev = True
    if config.timers is not None:
        config.timers('forward-send-slave-forward-recv', log_level=2).start()
    tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
        tensor_send_next=tensor_send_next,
        tensor_send_prev=tensor_send_prev,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not async_op),
        config=config,
    )
    if config.timers is not None:
        config.timers('forward-send-slave-forward-recv').stop()

    if fwd_model_chunk_id == 0:
        return tensor_recv_next, fwd_wait_handles
    else:
        return tensor_recv_prev, fwd_wait_handles


def send_backward_recv_slave_backward(
    input_tensor_grad: torch.Tensor,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    fwd_model_chunk_id,
    async_op=False,
) -> tuple:
    """Send a gradient and receive one back from the same neighbour (backward counterpart).

    The direction is chosen from ``fwd_model_chunk_id`` exactly as in
    ``send_forward_recv_slave_forward``:
    - 0: send ``input_tensor_grad`` to the next rank and receive from the next rank;
    - 1: send to and receive from the previous rank.

    On the last rank (for 0) or the first rank (for 1), nothing is communicated and
    ``(None, None)`` is returned.

    Returns:
        ``(received_tensor, wait_handles)``.
    """
    recv_prev, recv_next = False, False
    tensor_send_next, tensor_send_prev = None, None
    if fwd_model_chunk_id == 0:
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            return None, None
        tensor_send_next = input_tensor_grad
        recv_next = True
    if fwd_model_chunk_id == 1:
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            return None, None
        tensor_send_prev = input_tensor_grad
        recv_prev = True
    # Use a dedicated timer so backward traffic is not merged into the forward
    # 'forward-send-slave-forward-recv' timer.
    if config.timers is not None:
        config.timers('backward-send-slave-backward-recv', log_level=2).start()
    tensor_recv_prev, tensor_recv_next, bwd_wait_handles = _communicate(
        tensor_send_next=tensor_send_next,
        tensor_send_prev=tensor_send_prev,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not async_op),
        config=config,
    )
    if config.timers is not None:
        config.timers('backward-send-slave-backward-recv').stop()

    if fwd_model_chunk_id == 0:
        return tensor_recv_next, bwd_wait_handles
    else:
        return tensor_recv_prev, bwd_wait_handles


def generate_dualpipev_schedule(pp_size, num_microbatches):
    """Compute, for every pipeline rank, how many steps each DualPipeV phase runs.

    Args:
        pp_size: number of pipeline-parallel ranks.
        num_microbatches: microbatches per model chunk. The schedule counts
            ``2 * num_microbatches`` in total, because there are two chunks.

    Returns:
        A dict mapping each phase name to a list indexed by pipeline rank. Insertion order
        is: 'warmup', 'interleaved_forward', '1b1w1f', 'overlap', '1b1overlap',
        'interleaved_backward', 'cooldown'. Every 'cooldown' entry is a 3-element list.
        Each unit of 'interleaved_forward' is one pair of forwards (one per chunk).
    """
    total_microbatches = 2 * num_microbatches
    virtual_stages = 2 * pp_size  # each rank hosts two model chunks
    ranks = range(pp_size)

    return {
        'warmup': [virtual_stages - 2 - 2 * r for r in ranks],
        'interleaved_forward': [r + 1 for r in ranks],
        '1b1w1f': [pp_size - r - 1 for r in ranks],
        'overlap': [total_microbatches - 2 * virtual_stages + 2 * r + 2 for r in ranks],
        '1b1overlap': [2 * (pp_size - r - 1) for r in ranks],
        'interleaved_backward': [r + 1 for r in ranks],
        'cooldown': [[r + 1, virtual_stages - 2 * r - 2, r + 1] for r in ranks],
    }


def _aux_loss_scale(config, device):
    """Return the scale for auxiliary (MoE / MTP) losses: ``grad_scale_func(ones(1))`` or ``ones(1)``."""
    ones = torch.ones(1, device=device)
    if config.grad_scale_func is not None:
        return config.grad_scale_func(ones)
    return ones


def forward_step_no_model_graph(
    forward_step_func,
    model_chunk_id,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
):
    """Run the forward pass of one microbatch on one DualPipeV model chunk.

    This appears to be adapted from Megatron-Core's ``forward_step``. The loss is applied
    when ``is_dualpipev_last_stage(model_chunk_id)`` holds (chunk 1 on the first pipeline
    rank), not on the last pipeline rank.

    Args:
        forward_step_func: ``(data_iterator, model[, checkpoint_activations_microbatch])``
            -> ``(output_tensor, loss_func)``.
        model_chunk_id: id of the chunk being run (0 or 1).
        data_iterator: data iterator for this chunk.
        model: the chunk's model.
        num_microbatches: used to average the loss and the auxiliary-loss scales.
        input_tensor: stage input, a tensor or a list of tensors; a bare value is wrapped
            in a list before ``set_input_tensor``.
        forward_data_store: list to which the reduced loss (or non-loss data) is appended
            on the final stage.
        config: model config (timers, autocast, grad scaling, loss-averaging flags).
        collect_non_loss_data: call ``loss_func(output, non_loss_data=True)`` instead of
            computing the loss.
        checkpoint_activations_microbatch: forwarded to ``forward_step_func`` when not None.
        is_first_microbatch: calls ``model.set_is_first_microbatch()`` if available.
        current_microbatch: forwarded to ``set_current_microbatch`` when not None.

    Returns:
        ``(output, num_tokens)``:
        - ``output`` is the bare tensor if ``input_tensor`` was not a list, otherwise a
          one-element list. For encoder-and-decoder models after the split, it is
          ``[output_tensor, input_tensor[-1]]``.
        - ``num_tokens`` is an int tensor: 0 unless ``loss_func`` returned a token count.
    """
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        else:
            output_tensor, loss_func = forward_step_func(
                data_iterator, model, checkpoint_activations_microbatch
            )

    num_tokens = torch.tensor(0, dtype=torch.int)
    if is_dualpipev_last_stage(model_chunk_id):
        if not collect_non_loss_data:
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    output_tensor /= num_tokens
                    output_tensor *= parallel_state.get_context_parallel_world_size()
                    output_tensor /= num_microbatches
            else:
                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor *= parallel_state.get_context_parallel_world_size()
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)
        else:
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the loss scale for the auxiliary loss of the MoE layer.
    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
    # explicitly.
    if getattr(config, 'num_moe_experts', None) is not None:
        loss_scale = _aux_loss_scale(config, output_tensor.device)
        if config.calculate_per_token_loss:
            MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
        else:
            MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # Set the loss scale for Multi-Token Prediction (MTP) loss.
    if getattr(config, 'mtp_num_layers', None) is not None:
        # Imported here so the MTP module is only required when MTP is enabled.
        from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

        loss_scale = _aux_loss_scale(config, output_tensor.device)
        if config.calculate_per_token_loss:
            MTPLossAutoScaler.set_loss_scale(loss_scale)
        else:
            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # If T5 model and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    model_type = get_model_type(model)
    if (
        parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        return [output_tensor, input_tensor[-1]], num_tokens

    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens


# Word-embedding weight shared between the two DualPipeV model chunks; None until
# set_shared_embedding_from_dual_chunk finds a chunk with pre_process enabled.
shared_embedding = None


def get_shared_embedding_from_dual_chunk():
    """Return the cached shared word-embedding weight.

    Raises:
        AssertionError: if no embedding has been cached yet. The check is an explicit
            raise, like ``get_dualpipe_chunk``, so it is not stripped under ``python -O``.
    """
    if shared_embedding is None:
        raise AssertionError("shared_embedding is None")
    return shared_embedding


def set_shared_embedding_from_dual_chunk(model1, model2):
    """Cache the word-embedding weight of whichever chunk has ``pre_process`` set.

    ``model1`` is checked before ``model2``. Once a weight is cached, later calls return
    immediately. If neither chunk has ``pre_process``, the cache stays None and a later
    call will try again.

    NOTE(review): the double ``.module`` presumably unwraps DDP / Float16Module wrappers --
    confirm for the wrapping used in training.
    """
    global shared_embedding
    if shared_embedding is not None:
        return
    for candidate in (model1, model2):
        inner = candidate.module.module
        if inner.pre_process:
            shared_embedding = inner.embedding.word_embeddings.weight
            return


def forward_backward_pipelining_with_cutinhalf(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    args = get_args()
    args.moe_fb_overlap = True
    args.dualpipe_no_dw_detach = True

    set_shared_embedding_from_dual_chunk(model[0], model[1])
    assert (
        isinstance(model, list) and len(model) == 2
    ), 'Dualpipe Schedule expects two model chunks'

    assert (
        isinstance(data_iterator, list) and len(data_iterator) == 2
    ), 'Dualpipe Schedule expects two data_iterators'

    config = get_model_config(model[0])
    config.batch_p2p_comm = False

    # Needed only when gradients are finalized in M-Core
    if config.finalize_model_grads_func is not None and not forward_only:
        embedding_module = clear_embedding_activation_buffer(config, model)

    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    # Disable async grad reductions
    no_sync_func = config.no_sync_func
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    def combined_forward_backward_helper(
        fwd_model_chunk_id,
        bwd_model_chunk_id,
        fwd_input_tensor=None,
        bwd_output_tensor_grad=None
    ):
        """Helper method to run combined forward and backward step"""
        # forward prepare
        fwd_microbatch_id = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
        f_context = contextlib.nullcontext()
        set_dualpipe_chunk(fwd_model_chunk_id)

        # backward prepare
        b_context = contextlib.nullcontext()
        bwd_input_tensor = input_tensors[bwd_model_chunk_id].pop(0)[1]
        bwd_output_tensor = output_tensors[bwd_model_chunk_id].pop(0)

        output_tensor, num_tokens, input_tensor_grad = forward_backward_step(
            forward_step_func,
            data_iterator[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
            model[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
            num_microbatches,
            fwd_input_tensor,
            forward_data_store,
            model[bwd_model_chunk_id] if bwd_model_chunk_id is not None else None,
            bwd_input_tensor,
            bwd_output_tensor,
            bwd_output_tensor_grad,
            config,
            f_context=f_context,
            b_context=b_context,
            collect_non_loss_data=collect_non_loss_data,
            checkpoint_activations_microbatch=None,
            is_first_microbatch=False,
            current_microbatch=fwd_microbatch_id,
        )

        # forward post process
        if fwd_model_chunk_id is not None:
            with f_context:
                nonlocal total_num_tokens
                total_num_tokens += num_tokens.item()

                if not forward_only:
                    input_tensors[fwd_model_chunk_id].append((fwd_microbatch_id, fwd_input_tensor))
                    output_tensors[fwd_model_chunk_id].append(output_tensor)

        # backward post process
        if b_model_chunk_id:
            with b_context:
                # launch grad synchronization (custom grad sync)
                # Note: Asynchronous communication tends to slow down compute.
                # To reduce idling from mismatched microbatch times, we launch
                # asynchronous communication at the same time across the
                # pipeline-parallel group.
                if config.grad_sync_func is not None:
                    grad_sync_virtual_microbatch_id = (
                        b_virtual_microbatch_id - pipeline_parallel_rank
                    )
                    if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
                        grad_sync_virtual_microbatch_id
                    ):
                        grad_sync_chunk_id = get_model_chunk_id(
                            grad_sync_virtual_microbatch_id, forward=False
                        )
                        enable_grad_sync()
                        config.grad_sync_func[grad_sync_chunk_id](
                            model[grad_sync_chunk_id].parameters()
                        )
                        synchronized_model_chunks.add(grad_sync_chunk_id)
                disable_grad_sync()
                if input_tensor is not None:
                    assert input_tensor_grad is not None

        return output_tensor, input_tensor_grad

    # Compute number of steps for each stage
    pp_size = parallel_state.get_pipeline_model_parallel_world_size()
    rank = parallel_state.get_pipeline_model_parallel_rank()
    schedule = generate_dualpipev_schedule(pp_size, num_microbatches)

    model_type = get_model_type(model[0])

    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
    tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
    if config.sequence_parallel:
        tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()

    total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
    forward_data_store = []
    if not forward_only:
        input_tensors = [[], []]
        output_tensors = [[], []]

    master_chunk_id = 0
    slave_chunk_id = 1

    master_cur_microbatch = 0
    slave_cur_microbatch = num_microbatches
    master_microbatch_max = num_microbatches
    slave_microbatch_max = num_microbatches * 2

    checkpoint_activations_microbatch = None
    fwd_wait_handles_warmup = None

    def forward_step_helper(input_tensor, model_chunk_id, cur_microbatch, is_first_microbatch=False):
        set_dualpipe_chunk(model_chunk_id)
        output_tensor, num_tokens = forward_step_no_model_graph(
            forward_step_func,
            model_chunk_id,
            data_iterator[model_chunk_id],
            model[model_chunk_id],
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            is_first_microbatch=is_first_microbatch,
            current_microbatch=cur_microbatch
        )

        nonlocal total_num_tokens
        total_num_tokens += num_tokens.item()
        if not forward_only:
            input_tensors[model_chunk_id].append(
                (cur_microbatch, input_tensor))
            output_tensors[model_chunk_id].append(output_tensor)

        return output_tensor

    def backward_step_helper(input_tensor, output_tensor, output_tensor_grad, is_last_microbatch=False):
        # # launch grad synchronization (default)
        # if config.grad_sync_func is None and is_last_microbatch:
        #         enable_grad_sync()

        input_tensor_grad = backward_step(
            input_tensor, output_tensor, output_tensor_grad, model_type, config
        )

        # # launch grad synchronization (custom grad sync)
        # # Note: Asynchronous communication tends to slow down compute.
        # # To reduce idling from mismatched microbatch times, we launch
        # # asynchronous communication at the same time across the
        # # pipeline-parallel group.
        # if config.grad_sync_func is not None:
        #     grad_sync_virtual_microbatch_id = virtual_microbatch_id - pipeline_parallel_rank
        #     if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
        #         grad_sync_virtual_microbatch_id
        #     ):
        #         grad_sync_chunk_id = get_model_chunk_id(
        #             grad_sync_virtual_microbatch_id, forward=False
        #         )
        #         enable_grad_sync()
        #         config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
        #         synchronized_model_chunks.add(grad_sync_chunk_id)
        # disable_grad_sync()

        return input_tensor_grad

    # Run warmup forward passes
    input_tensor, _ = recv_forward(tensor_shape, config, master_chunk_id)
    for i in range(schedule['warmup'][rank]):
        is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
        output_tensor_warmup = forward_step_helper(
            input_tensor,
            master_chunk_id,
            master_cur_microbatch,
            is_first_microbatch=is_first_microbatch
        )

        master_cur_microbatch += 1

        if i != schedule['warmup'][rank] - 1:
            input_tensor, _ = send_forward_recv_forward(
                output_tensor_warmup, tensor_shape, config, master_chunk_id)
            if not forward_only:
                deallocate_output_tensor(
                    output_tensor_warmup, config.deallocate_pipeline_outputs)
        else:
            input_tensor, _ = recv_forward(
                tensor_shape, config, master_chunk_id)
            fwd_wait_handles_warmup = send_forward(
                output_tensor_warmup, tensor_shape, config, master_chunk_id, async_op=True)

    # Run interleaved forward passes for two model chunk
    fwd_wait_handles = None
    fwd_wait_handles_slave_chunk = None
    fwd_wait_handles_send = None
    for i in range(schedule['interleaved_forward'][rank]):
        if fwd_wait_handles is not None:
            for req, req_handle in fwd_wait_handles.items():
                if req_handle is not None:
                    req_handle.wait()
            fwd_wait_handles = None

        is_first_microbatch = parallel_state.is_pipeline_last_stage(ignore_virtual=True) and (i == 0)
        is_first_microbatch = check_first_val_step(first_val_step, forward_only, is_first_microbatch)
        output_tensor = forward_step_helper(
            input_tensor,
            master_chunk_id,
            master_cur_microbatch,
            is_first_microbatch=is_first_microbatch
        )
        master_cur_microbatch += 1

        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and fwd_wait_handles_send is not None:
            for req, req_handle in fwd_wait_handles_send.items():
                if req_handle is not None:
                    req_handle.wait()
            fwd_wait_handles_send = None

            if not forward_only:
                deallocate_output_tensor(
                    output_tensor_send, config.deallocate_pipeline_outputs)

        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            if not forward_only:
                input_tensor_slave = output_tensor.detach()
                input_tensor_slave.requires_grad = True
            else:
                input_tensor_slave = output_tensor
        else:
            input_tensor_slave, _ = recv_forward(
                tensor_shape, config, slave_chunk_id)

        input_tensor, fwd_wait_handles = recv_forward(
            tensor_shape, config, master_chunk_id, async_op=True)

        if fwd_wait_handles_warmup is not None:
            for req, req_handle in fwd_wait_handles_warmup.items():
                if req_handle is not None:
                    req_handle.wait()
            fwd_wait_handles_warmup = None

            if not forward_only:
                deallocate_output_tensor(
                    output_tensor_warmup, config.deallocate_pipeline_outputs)

        if fwd_wait_handles_slave_chunk is not None:
            for req, req_handle in fwd_wait_handles_slave_chunk.items():
                if req_handle is not None:
                    req_handle.wait()
            fwd_wait_handles_slave_chunk = None

            if not forward_only:
                deallocate_output_tensor(
                    output_tensor_slave_chunk, config.deallocate_pipeline_outputs)

        is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
        output_tensor_slave_chunk = forward_step_helper(
            input_tensor_slave,
            slave_chunk_id,
            slave_cur_microbatch,
            is_first_microbatch=is_first_microbatch
        )
        slave_cur_microbatch += 1

        if not forward_only:
            if i == schedule['interleaved_forward'][rank] - 1:
                firstFB_no_overlp_handle = None
                # last rank not overlap first F&B
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    output_tensor_grad_bwd, firstFB_no_overlp_handle = recv_backward(
                        tensor_shape, config, slave_chunk_id, async_op=True)
                else:
                    output_tensor_grad_bwd, _ = recv_backward(
                        tensor_shape, config, slave_chunk_id)

        fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
                                                    tensor_shape, config, slave_chunk_id, async_op=True)

        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            output_tensor_send = output_tensor
            fwd_wait_handles_send = send_forward(
                output_tensor_send, tensor_shape, config, master_chunk_id, async_op=True)
        else:
            # custom_backward requires output_tensor.numel() == 1
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

    if fwd_wait_handles is not None:
        for req, req_handle in fwd_wait_handles.items():
            if req_handle is not None:
                req_handle.wait()
        fwd_wait_handles = None

    # Run 1b1w1f stages for slave chunk
    bwd_wait_handles = None
    for _ in range(schedule['1b1w1f'][rank]):
        if not forward_only:
            input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
            output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)

            input_tensor_grad = backward_step(
                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
            )

        if fwd_wait_handles_slave_chunk is not None:
            for req in fwd_wait_handles_slave_chunk:
                req.wait()
            fwd_wait_handles_slave_chunk = None
            if not forward_only:
                deallocate_output_tensor(
                    output_tensor_slave_chunk, config.deallocate_pipeline_outputs)

        if fwd_wait_handles_send is not None:
            for req, req_handle in fwd_wait_handles_send.items():
                if req_handle is not None:
                    req_handle.wait()
            fwd_wait_handles_send = None
            if not forward_only:
                deallocate_output_tensor(
                    output_tensor, config.deallocate_pipeline_outputs)

        if not forward_only:
            # If asynchronous, the memory will rise.
            bwd_wait_handles = send_backward(input_tensor_grad,
                                             tensor_shape, config, slave_chunk_id)

        # If asynchronous, the memory will rise.
        input_tensor_slave, recv_forward_handle = recv_forward(
            tensor_shape, config, slave_chunk_id)

        if recv_forward_handle is not None:
            for req, handle in recv_forward_handle.items():
                if handle is not None:
                    handle.wait()
            recv_forward_handle = None

        # 1F: Forward pass
        output_tensor = forward_step_helper(
            input_tensor_slave,
            slave_chunk_id,
            slave_cur_microbatch,
            is_first_microbatch=False
        )
        slave_cur_microbatch += 1

        if not forward_only:
            output_tensor_grad_bwd, _ = recv_backward(
                tensor_shape, config, slave_chunk_id)

        fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
                                                    tensor_shape, config, slave_chunk_id, async_op=True)

    # Run overlaping f&bw stages
    fwd_wait_handles = None
    bwd_wait_handles = None
    fwd_wait_handles_recv = None
    fwd_model_chunk_id = master_chunk_id
    bwd_model_chunk_id = slave_chunk_id
    num_overlap_steps = schedule['overlap'][rank] + schedule['1b1overlap'][rank]
    if not forward_only:
        num_overlap_steps += schedule['interleaved_backward'][rank]
    for step_id in range(num_overlap_steps):
        only_bwd = False
        if fwd_model_chunk_id == master_chunk_id and master_cur_microbatch == master_microbatch_max:
            only_bwd = True
        if fwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch == slave_microbatch_max:
            only_bwd = True

        if not only_bwd:
            def pp_pre_forward():
                nonlocal fwd_wait_handles_recv

                if fwd_wait_handles_recv is not None:
                    for req, req_handle in fwd_wait_handles_recv.items():
                        req_handle.wait()
                    fwd_wait_handles_recv = None

            def pp_post_forward(output_tensor):
                """Bookkeeping and forward-direction communication after an overlapped forward step.

                Advances the microbatch counter of the chunk that just ran forward
                (``fwd_model_chunk_id``). It then waits on any outstanding slave-chunk forward
                send left over from the previous phase and frees that output. Finally it issues
                the forward communication for ``output_tensor``.

                Args:
                    output_tensor: result of ``forward_step_helper`` for ``fwd_model_chunk_id``.

                Returns:
                    The input tensor for the next forward step. This is one of:
                    - ``None`` when only a send is issued (master chunk already exhausted);
                    - on the last pipeline stage for the master chunk, a detached,
                      grad-requiring copy of ``output_tensor`` (``output_tensor`` itself when
                      ``forward_only``);
                    - otherwise the tensor returned by ``send_forward_recv_slave_forward``.
                      Its async transfer is tracked in ``fwd_wait_handles`` and waited on in
                      ``pp_post_backward``.
                """
                # These enclosing-scope names are rebound below, so they must be nonlocal.
                nonlocal master_cur_microbatch
                nonlocal slave_cur_microbatch
                nonlocal fwd_wait_handles
                nonlocal fwd_wait_handles_slave_chunk
                nonlocal firstFB_no_overlp_handle

                if fwd_model_chunk_id == master_chunk_id:
                    master_cur_microbatch += 1
                    fwd_send_only = False
                else:
                    slave_cur_microbatch += 1
                    # Once the master chunk has consumed all its microbatches, only a plain
                    # send_forward is issued for the slave output (no paired receive).
                    fwd_send_only = (master_cur_microbatch == master_microbatch_max)

                # Wait for the last slave-chunk forward send issued in the previous phase,
                # then free that output's memory.
                if fwd_wait_handles_slave_chunk is not None:
                    for req, req_handle in fwd_wait_handles_slave_chunk.items():
                        if req_handle is not None:
                            req_handle.wait()
                    fwd_wait_handles_slave_chunk = None
                    if not forward_only:
                        deallocate_output_tensor(
                            output_tensor_slave_chunk, config.deallocate_pipeline_outputs)

                if fwd_send_only:
                    input_tensor = None
                    fwd_wait_handles = send_forward(
                        output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
                else:
                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                        # On the last stage the master chunk's output is returned as the next
                        # forward input without any communication (the loop then swaps to
                        # the slave chunk). Detaching makes it a fresh leaf whose .grad can be
                        # handed back as the master chunk's output grad in pp_post_backward.
                        if not forward_only:
                            input_tensor = output_tensor.detach()
                            input_tensor.requires_grad = True
                            # NOTE(review): output_tensor is deallocated right after detaching;
                            # this relies on the detached tensor keeping valid data -- confirm
                            # against deallocate_output_tensor.
                            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
                        else:
                            input_tensor = output_tensor
                    else:
                        input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
                            output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)

                # On the last stage the first F&B is not overlapped: its recv_backward was
                # issued asynchronously in the warmup phase; make sure it has completed.
                if not forward_only and firstFB_no_overlp_handle is not None:
                    for req, req_handle in firstFB_no_overlp_handle.items():
                        if req_handle is not None:
                            req_handle.wait()
                    firstFB_no_overlp_handle = None

                return input_tensor

            def pp_pre_backward():
                nonlocal bwd_wait_handles

                if not forward_only:
                    if bwd_wait_handles is not None:
                        for _, req_handle in bwd_wait_handles.items():
                            if req_handle is not None:
                                req_handle.wait()
                        bwd_wait_handles = None

            def pp_post_backward(input_tensor_grad):
                nonlocal fwd_wait_handles
                nonlocal bwd_wait_handles

                if fwd_wait_handles is not None:
                    for _, req_handle in fwd_wait_handles.items():
                        if req_handle is not None:
                            req_handle.wait()
                    fwd_wait_handles = None
                    if not forward_only:
                        deallocate_output_tensor(
                            output_tensor, config.deallocate_pipeline_outputs)

                if not forward_only:
                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                        output_tensor_grad = input_tensor_grad
                    else:
                        output_tensor_grad, bwd_wait_handles = send_backward_recv_slave_backward(
                            input_tensor_grad,
                            tensor_shape,
                            config,
                            fwd_model_chunk_id,
                            async_op=True
                        )
                else:
                    output_tensor_grad = None

                return output_tensor_grad

            # forward
            pp_pre_forward()
            fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
            output_tensor = forward_step_helper(
                input_tensor,
                fwd_model_chunk_id,
                fwd_microbatch,
                is_first_microbatch=False
            )
            input_tensor = pp_post_forward(output_tensor)

            # backward
            pp_pre_backward()
            if not forward_only:
                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
                input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
            else:
                input_tensor_grad = None
            output_tensor_grad_bwd = pp_post_backward(input_tensor_grad)

        # only run backward
        else:
            if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
                input_tensor, fwd_wait_handles_recv = recv_forward(
                    tensor_shape, config, slave_chunk_id, async_op=True)
            if not forward_only:
                if bwd_wait_handles is not None:
                    for req, req_handle in bwd_wait_handles.items():
                        if req_handle is not None:
                            req_handle.wait()
                    bwd_wait_handles = None

                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
                input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)

                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                    output_tensor_grad_bwd = input_tensor_grad
                else:
                    #  send_backward_recv_slave_backward
                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
                                                                                               tensor_shape, config, fwd_model_chunk_id)

        # swap fwd & bwd chunks
        fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id

    if not forward_only:
        # Run cooldown phases
        merged_input_tensors = []
        merged_output_tensors = []
        while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
            if len(input_tensors[bwd_model_chunk_id]) > 0:
                merged_input_tensors.append(
                    input_tensors[bwd_model_chunk_id].pop(0))
                merged_output_tensors.append(
                    (output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))

            if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
                merged_input_tensors.append(
                    input_tensors[1 - bwd_model_chunk_id].pop(0))
                merged_output_tensors.append(
                    (output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))

        bwd_wait_handles_recv = None
        for i in range(pp_size):

            if bwd_wait_handles is not None:
                for req, req_handle in bwd_wait_handles.items():
                    if req_handle is not None:
                        req_handle.wait()
                bwd_wait_handles = None
            if bwd_wait_handles_recv is not None:
                for req, req_handle in bwd_wait_handles_recv.items():
                    if req_handle is not None:
                        req_handle.wait()
                bwd_wait_handles_recv = None

            input_tensor_bwd = merged_input_tensors.pop(0)[1]
            output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)

            input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)

            if i == pp_size - 1:
                bwd_wait_handles = send_backward(input_tensor_grad,
                                                 tensor_shape, config, bwd_model_chunk_id, async_op=True)
            elif i >= schedule['cooldown'][rank][0] - 1:
                bwd_wait_handles = send_backward(input_tensor_grad,
                                                 tensor_shape, config, bwd_model_chunk_id, async_op=True)
                output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(
                    tensor_shape, config, bwd_model_chunk_id, async_op=True)
            else:
                if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
                    output_tensor_grad_bwd = input_tensor_grad
                else:
                    #  send_backward_recv_slave_backward
                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
                                                                                               tensor_shape, config, 1 - bwd_model_chunk_id)

        if bwd_wait_handles is not None:
            for req, req_handle in bwd_wait_handles.items():
                if req_handle is not None:
                    req_handle.wait()
            bwd_wait_handles = None

    if config.finalize_model_grads_func is not None and not forward_only:

        # If defer_embedding_wgrad_compute is enabled we need to do the
        # weight gradient GEMM's here.
        finish_embedding_wgrad_compute(config, embedding_module)

        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism, layernorm all-reduce for sequence parallelism, and
        # embedding all-reduce for pipeline parallelism).
        config.finalize_model_grads_func(
            model, total_num_tokens if config.calculate_per_token_loss else None
        )

    return forward_data_store
