Commit 69add73b authored by dongcl

dualpipev: support MoE A2A (all-to-all) overlap

parent 62f16817
__pycache__ __pycache__
*.bak *.bak
*.log
...@@ -69,16 +69,12 @@ class PipelineFeature(AbstractFeature): ...@@ -69,16 +69,12 @@ class PipelineFeature(AbstractFeature):
patch_manager.register_patch( patch_manager.register_patch(
'megatron.training.training.evaluate', evaluate) 'megatron.training.training.evaluate', evaluate)
if ( if args.combined_1f1b:
args.schedule_method == "interleaved_1f1b"
and args.combined_1f1b
):
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from dcu_megatron.core.transformer.transformer_layer import TransformerLayer from dcu_megatron.core.transformer.transformer_layer import TransformerLayer
from dcu_megatron.core.models.gpt.gpt_model import GPTModel from dcu_megatron.core.models.gpt.gpt_model import GPTModel
from dcu_megatron.core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
from dcu_megatron.core.extensions.transformer_engine import ( from dcu_megatron.core.extensions.transformer_engine import (
_get_extra_te_kwargs_wrapper, _get_extra_te_kwargs_wrapper,
TELinear, TELinear,
...@@ -89,53 +85,55 @@ class PipelineFeature(AbstractFeature): ...@@ -89,53 +85,55 @@ class PipelineFeature(AbstractFeature):
from dcu_megatron.core.transformer.moe.experts import TEGroupedMLP from dcu_megatron.core.transformer.moe.experts import TEGroupedMLP
from dcu_megatron.core.transformer.moe.moe_layer import MoELayer from dcu_megatron.core.transformer.moe.moe_layer import MoELayer
# num_warmup_microbatches + 1 patch_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
get_pp_rank_microbatches)
# a2a_overlap
patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
forward_backward_pipelining_with_interleaving)
patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
MoEAlltoAllTokenDispatcher) MoEAlltoAllTokenDispatcher)
patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer', patch_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
TransformerLayer) TransformerLayer)
patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan', patch_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
GPTModel.build_schedule_plan, GPTModel.build_schedule_plan,
create_dummy=True) create_dummy=True)
# backward_dw # backward_dw
patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs', patch_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
_get_extra_te_kwargs_wrapper, _get_extra_te_kwargs_wrapper,
apply_wrapper=True) apply_wrapper=True)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear', patch_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
TELinear) TELinear)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear', patch_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
TELayerNormColumnParallelLinear) TELayerNormColumnParallelLinear)
TEColumnParallelLinear.__bases__ = (TELinear,) TEColumnParallelLinear.__bases__ = (TELinear,)
TERowParallelLinear.__bases__ = (TELinear,) TERowParallelLinear.__bases__ = (TELinear,)
if is_te_min_version("1.9.0.dev0"): if is_te_min_version("1.9.0.dev0"):
from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear, TERowParallelGroupedLinear from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear, TERowParallelGroupedLinear
from ..core.extensions.transformer_engine import TEGroupedLinear from dcu_megatron.core.extensions.transformer_engine import TEGroupedLinear
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear', patch_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
TEGroupedLinear) TEGroupedLinear)
TEColumnParallelGroupedLinear.__bases__ = (TEGroupedLinear,) TEColumnParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
TERowParallelGroupedLinear.__bases__ = (TEGroupedLinear,) TERowParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw', patch_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
MLASelfAttention.backward_dw, MLASelfAttention.backward_dw,
create_dummy=True) create_dummy=True)
patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw', patch_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
MLP.backward_dw, MLP.backward_dw,
create_dummy=True) create_dummy=True)
patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw', patch_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw',
TEGroupedMLP.backward_dw, TEGroupedMLP.backward_dw,
create_dummy=True) create_dummy=True)
patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw', patch_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
MoELayer.backward_dw, MoELayer.backward_dw,
create_dummy=True) create_dummy=True)
if args.schedule_method == "interleaved_1f1b":
from dcu_megatron.core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
# num_warmup_microbatches + 1
patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
get_pp_rank_microbatches)
# a2a_overlap
patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
forward_backward_pipelining_with_interleaving)
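The block above re-points Megatron symbols at their dcu_megatron counterparts through patch_manager.register_patch. As a rough illustration of what such a registration does (a simplified, hypothetical sketch that only handles module-level attributes and ignores the create_dummy / apply_wrapper options shown above), the core mechanism is ordinary attribute rebinding:

import importlib

def register_patch(target_path, replacement):
    # Split "pkg.module.attr" into the module to import and the attribute to rebind.
    module_path, attr_name = target_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    # Later lookups of module.attr now resolve to the replacement object.
    setattr(module, attr_name, replacement)

# Illustrative usage; targets that name class attributes (e.g. GPTModel.build_schedule_plan)
# need an extra resolution step that the real manager performs:
# register_patch('megatron.core.transformer.transformer_layer.TransformerLayer', TransformerLayer)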
...@@ -6,7 +6,6 @@ from typing import Optional ...@@ -6,7 +6,6 @@ from typing import Optional
import torch import torch
from torch import Tensor from torch import Tensor
from megatron.core import parallel_state
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.inference.contexts import BaseInferenceContext
...@@ -720,7 +719,7 @@ def schedule_chunk_1f1b( ...@@ -720,7 +719,7 @@ def schedule_chunk_1f1b(
if f_schedule_plan is not None and post_forward is not None: if f_schedule_plan is not None and post_forward is not None:
with f_context: with f_context:
f_schedule_plan.wait_current_stream() f_schedule_plan.wait_current_stream()
post_forward(None if parallel_state.is_pipeline_last_stage(ignore_virtual=False) else f_input) post_forward(f_input)
# pp grad send / receive, overlapped with attn dw of cur micro-batch and forward attn of next micro-batch # pp grad send / receive, overlapped with attn dw of cur micro-batch and forward attn of next micro-batch
if b_schedule_plan is not None and post_backward is not None: if b_schedule_plan is not None and post_backward is not None:
......
_DUALPIPE_CHUNK = None
def set_dualpipe_chunk(chunk_id):
"""set_dualpipe_chunk for fp16forward patch"""
global _DUALPIPE_CHUNK
_DUALPIPE_CHUNK = chunk_id
def get_dualpipe_chunk():
"""Return the model-chunk id recorded by set_dualpipe_chunk."""
if _DUALPIPE_CHUNK is None:
raise AssertionError("_DUALPIPE_CHUNK is None; call set_dualpipe_chunk() first")
return _DUALPIPE_CHUNK
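A quick, hypothetical usage sketch (relying only on the two helpers above): the dualpipev schedule records the active chunk right before running it, and patched code further down the stack reads it back.

# Illustrative driver loop, not the real schedule:
for model_chunk_id in (0, 1):
    set_dualpipe_chunk(model_chunk_id)       # record which chunk this rank is running
    assert get_dualpipe_chunk() == model_chunk_id
    # ... run the forward/backward work for this chunk ...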
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from torch import Tensor from torch import Tensor
from torch.autograd.variable import Variable from torch.autograd.variable import Variable
from megatron.training import get_args
from megatron.core import parallel_state from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel from megatron.core.distributed import DistributedDataParallel
...@@ -15,6 +16,8 @@ from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler ...@@ -15,6 +16,8 @@ from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
from dcu_megatron.core.parallel_state import get_dualpipe_chunk
def make_viewless(e): def make_viewless(e):
"""make_viewless util func""" """make_viewless util func"""
...@@ -432,7 +435,13 @@ def forward_backward_step( ...@@ -432,7 +435,13 @@ def forward_backward_step(
if f_model: if f_model:
with f_context: with f_context:
num_tokens = torch.tensor(0, dtype=torch.int) num_tokens = torch.tensor(0, dtype=torch.int)
if parallel_state.is_pipeline_last_stage(ignore_virtual=False): args = get_args()
is_last_stage = False
if args.schedule_method == "dualpipev":
is_last_stage = parallel_state.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
else:
is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False)
if is_last_stage:
if not collect_non_loss_data: if not collect_non_loss_data:
loss_node = ScheduleNode( loss_node = ScheduleNode(
loss_func, loss_func,
......
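The is_last_stage branch above reflects the dualpipev (cut-in-half) layout: each pipeline rank holds two model chunks, chunk 0 taken from the front half of the network in rank order and chunk 1 from the back half in reverse rank order, so the final layers end up as chunk 1 on the pipeline-first rank. A small sketch of that layout assumption:

def dualpipev_layout(pp_size):
    """Map (rank, chunk) to the position of its layer block in the full model (illustrative)."""
    layout = {}
    for rank in range(pp_size):
        layout[(rank, 0)] = rank                     # front half, in rank order
        layout[(rank, 1)] = 2 * pp_size - 1 - rank   # back half, reversed
    return layout

# For pp_size = 4 the last block (index 7) lives on (rank 0, chunk 1), matching the
# is_pipeline_first_stage() and get_dualpipe_chunk() == 1 check used for the loss path.
assert dualpipev_layout(4)[(0, 1)] == 7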
...@@ -20,7 +20,7 @@ from megatron.training.utils import ( ...@@ -20,7 +20,7 @@ from megatron.training.utils import (
reduce_max_stat_across_model_parallel_group reduce_max_stat_across_model_parallel_group
) )
from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import get_dualpipe_chunk from dcu_megatron.core.parallel_state import get_dualpipe_chunk
def dualpipev_fp16forward(self, *inputs, **kwargs): def dualpipev_fp16forward(self, *inputs, **kwargs):
......
...@@ -26,7 +26,8 @@ from megatron.core.pipeline_parallel.schedules import ( ...@@ -26,7 +26,8 @@ from megatron.core.pipeline_parallel.schedules import (
finish_embedding_wgrad_compute finish_embedding_wgrad_compute
) )
from dcu_megatron.training.utils import print_rank_message from dcu_megatron.core.pipeline_parallel.combined_1f1b import forward_backward_step, set_streams, wrap_forward_func
from dcu_megatron.core.parallel_state import set_dualpipe_chunk
# from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore # from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
...@@ -35,23 +36,6 @@ Shape = Union[List[int], torch.Size] ...@@ -35,23 +36,6 @@ Shape = Union[List[int], torch.Size]
LOSS_BACKWARD_SCALE = torch.tensor(1.0) LOSS_BACKWARD_SCALE = torch.tensor(1.0)
_DUALPIPE_CHUNK = None
def set_dualpipe_chunk(chunk_id):
"""set_dualpipe_chunk for fp16forward patch"""
global _DUALPIPE_CHUNK
_DUALPIPE_CHUNK = chunk_id
def get_dualpipe_chunk():
global _DUALPIPE_CHUNK
if _DUALPIPE_CHUNK is not None:
return _DUALPIPE_CHUNK
else:
raise AssertionError("_DUALPIPE_CHUNK is None")
def is_dualpipev_last_stage(model_chunk_id): def is_dualpipev_last_stage(model_chunk_id):
return parallel_state.is_pipeline_first_stage(ignore_virtual=True) and model_chunk_id == 1 return parallel_state.is_pipeline_first_stage(ignore_virtual=True) and model_chunk_id == 1
...@@ -530,6 +514,13 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -530,6 +514,13 @@ def forward_backward_pipelining_with_cutinhalf(
config = get_model_config(model[0]) config = get_model_config(model[0])
config.batch_p2p_comm = False config.batch_p2p_comm = False
if (
not forward_only
and config.combined_1f1b
):
set_streams()
forward_step_func = wrap_forward_func(config, forward_step_func)
# Needed only when gradients are finalized in M-Core # Needed only when gradients are finalized in M-Core
if config.finalize_model_grads_func is not None and not forward_only: if config.finalize_model_grads_func is not None and not forward_only:
embedding_module = clear_embedding_activation_buffer(config, model) embedding_module = clear_embedding_activation_buffer(config, model)
...@@ -582,86 +573,6 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -582,86 +573,6 @@ def forward_backward_pipelining_with_cutinhalf(
disable_grad_sync() disable_grad_sync()
def combined_forward_backward_helper(
fwd_model_chunk_id,
bwd_model_chunk_id,
fwd_input_tensor=None,
bwd_output_tensor_grad=None,
pre_forward=None,
pre_backward=None,
post_forward=None,
post_backward=None,
):
"""Helper method to run combined forward and backward step"""
# forward prepare
fwd_microbatch_id = cur_fwd_chunk_microbatch[fwd_model_chunk_id]
f_context = contextlib.nullcontext()
set_dualpipe_chunk(fwd_model_chunk_id)
# backward prepare
b_context = contextlib.nullcontext()
bwd_input_tensor = input_tensors[bwd_model_chunk_id].pop(0)[1]
bwd_output_tensor = output_tensors[bwd_model_chunk_id].pop(0)
output_tensor, num_tokens, input_tensor_grad = forward_backward_step(
forward_step_func,
data_iterator[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
model[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
num_microbatches,
fwd_input_tensor,
forward_data_store,
model[bwd_model_chunk_id] if bwd_model_chunk_id is not None else None,
bwd_input_tensor,
bwd_output_tensor,
bwd_output_tensor_grad,
config,
f_context=f_context,
b_context=b_context,
collect_non_loss_data=collect_non_loss_data,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=fwd_microbatch_id,
)
# forward post process
if fwd_model_chunk_id is not None:
with f_context:
nonlocal total_num_tokens
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[fwd_model_chunk_id].append((fwd_microbatch_id, fwd_input_tensor))
output_tensors[fwd_model_chunk_id].append(output_tensor)
# backward post process
if b_model_chunk_id:
with b_context:
# launch grad synchronization (custom grad sync)
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if config.grad_sync_func is not None:
grad_sync_virtual_microbatch_id = (
b_virtual_microbatch_id - pipeline_parallel_rank
)
if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
grad_sync_virtual_microbatch_id
):
grad_sync_chunk_id = get_model_chunk_id(
grad_sync_virtual_microbatch_id, forward=False
)
enable_grad_sync()
config.grad_sync_func[grad_sync_chunk_id](
model[grad_sync_chunk_id].parameters()
)
synchronized_model_chunks.add(grad_sync_chunk_id)
disable_grad_sync()
if input_tensor is not None:
assert input_tensor_grad is not None
return output_tensor, input_tensor_grad
# Compute number of steps for each stage # Compute number of steps for each stage
pp_size = parallel_state.get_pipeline_model_parallel_world_size() pp_size = parallel_state.get_pipeline_model_parallel_world_size()
rank = parallel_state.get_pipeline_model_parallel_rank() rank = parallel_state.get_pipeline_model_parallel_rank()
...@@ -686,8 +597,6 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -686,8 +597,6 @@ def forward_backward_pipelining_with_cutinhalf(
cur_bwd_chunk_microbatch = [0, num_microbatches] cur_bwd_chunk_microbatch = [0, num_microbatches]
num_chunk_max_microbatch = [num_microbatches, num_microbatches * 2] num_chunk_max_microbatch = [num_microbatches, num_microbatches * 2]
checkpoint_activations_microbatch = None
def wait_comm_handles(comm_handles): def wait_comm_handles(comm_handles):
if comm_handles is None: if comm_handles is None:
return return
...@@ -697,13 +606,19 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -697,13 +606,19 @@ def forward_backward_pipelining_with_cutinhalf(
req_handle.wait() req_handle.wait()
comm_handles = None comm_handles = None
def forward_step_helper(model_chunk_id, cur_microbatch, is_first_microbatch=False): def forward_step_helper(model_chunk_id, cur_microbatch, checkpoint_activations_microbatch=False):
set_dualpipe_chunk(model_chunk_id) set_dualpipe_chunk(model_chunk_id)
if not forward_only: if not forward_only:
offset = cur_bwd_chunk_microbatch[model_chunk_id] offset = cur_bwd_chunk_microbatch[model_chunk_id]
input_tensor = input_tensors[model_chunk_id][cur_microbatch - offset] input_tensor = input_tensors[model_chunk_id][cur_microbatch - offset]
else: else:
input_tensor = input_tensors[model_chunk_id][0] input_tensor = input_tensors[model_chunk_id][0]
# First microbatch of this chunk is index 0; check_first_val_step expects a bool.
is_first_microbatch = check_first_val_step(
first_val_step,
forward_only,
cur_fwd_chunk_microbatch[model_chunk_id] == 0,
)
output_tensor, num_tokens = forward_step_no_model_graph( output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func, forward_step_func,
model_chunk_id, model_chunk_id,
...@@ -759,12 +674,154 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -759,12 +674,154 @@ def forward_backward_pipelining_with_cutinhalf(
return input_tensor_grad return input_tensor_grad
def combined_forward_backward_helper(
fwd_model_chunk_id=None,
bwd_model_chunk_id=None,
pre_forward=None,
pre_backward=None,
post_forward=None,
post_backward=None,
):
"""Helper method to run combined forward and backward step"""
# forward prepare
f_context = contextlib.nullcontext()
fwd_input_tensor = None
fwd_microbatch_id = None
if fwd_model_chunk_id is not None:
fwd_microbatch_id = cur_fwd_chunk_microbatch[fwd_model_chunk_id]
set_dualpipe_chunk(fwd_model_chunk_id)
offset = cur_bwd_chunk_microbatch[fwd_model_chunk_id]
fwd_input_tensor = input_tensors[fwd_model_chunk_id][fwd_microbatch_id - offset]
# backward prepare
b_context = contextlib.nullcontext()
bwd_input_tensor = None
bwd_output_tensor = None
bwd_output_tensor_grad = None
if bwd_model_chunk_id is not None:
# Enable async grad reduction in the last backward pass
# Note: If grad sync function is provided, only enable
# async grad reduction in first pipeline stage. Other
# pipeline stages do grad reduction during pipeline
# bubble.
bwd_microbatch_id = cur_bwd_chunk_microbatch[bwd_model_chunk_id]
if (
bwd_microbatch_id is not None
and bwd_microbatch_id == num_chunk_max_microbatch[bwd_model_chunk_id] - 1
):
if (
config.grad_sync_func is None
or (bwd_model_chunk_id == slave_chunk_id and parallel_state.is_pipeline_last_stage())
or (bwd_model_chunk_id == master_chunk_id and parallel_state.is_pipeline_first_stage())
):
enable_grad_sync()
bwd_input_tensor = input_tensors[bwd_model_chunk_id].pop(0)
bwd_output_tensor = output_tensors[bwd_model_chunk_id].pop(0)
bwd_output_tensor_grad = output_tensor_grads[bwd_model_chunk_id].pop(0)
output_tensor, num_tokens, input_tensor_grad = forward_backward_step(
forward_step_func,
data_iterator[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
model[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
num_microbatches,
fwd_input_tensor,
forward_data_store,
model[bwd_model_chunk_id] if bwd_model_chunk_id is not None else None,
bwd_input_tensor,
bwd_output_tensor,
bwd_output_tensor_grad,
config,
f_context=f_context,
b_context=b_context,
pre_forward=pre_forward,
pre_backward=pre_backward,
post_forward=post_forward,
post_backward=post_backward,
collect_non_loss_data=collect_non_loss_data,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=fwd_microbatch_id,
)
# forward post process
if fwd_model_chunk_id is not None:
cur_fwd_chunk_microbatch[fwd_model_chunk_id] += 1
output_tensors[fwd_model_chunk_id].append(output_tensor)
nonlocal total_num_tokens
total_num_tokens += num_tokens.item()
if forward_only:
input_tensors[fwd_model_chunk_id].pop(0)
output_tensors[fwd_model_chunk_id].pop()
# backward post process
if bwd_model_chunk_id is not None:
cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
return output_tensor, input_tensor_grad
def forward_backward_helper_wrapper(
fwd_model_chunk_id=None,
bwd_model_chunk_id=None,
pre_forward=None,
pre_backward=None,
post_forward=None,
post_backward=None,
checkpoint_activations_microbatch=None,
):
"""
Wrap forward_step_helper, backward_step_helper, and combined_forward_backward_helper behind a single interface.
"""
if config.combined_1f1b and config.combined_1f1b_recipe == "ep_a2a" and not forward_only:
assert (
checkpoint_activations_microbatch is None
), "checkpoint_activations_microbatch not supported when combined_1f1b is true"
return combined_forward_backward_helper(
fwd_model_chunk_id=fwd_model_chunk_id,
bwd_model_chunk_id=bwd_model_chunk_id,
pre_forward=pre_forward,
pre_backward=pre_backward,
post_forward=post_forward,
post_backward=post_backward,
)
else:
output_tensor = None
input_tensor_grad = None
if fwd_model_chunk_id is not None:
# forward pass
if pre_forward is not None:
pre_forward()
output_tensor = forward_step_helper(
fwd_model_chunk_id,
cur_fwd_chunk_microbatch[fwd_model_chunk_id],
checkpoint_activations_microbatch
)
cur_fwd_chunk_microbatch[fwd_model_chunk_id] += 1
if post_forward is not None:
output_tensor = post_forward(output_tensor)
if bwd_model_chunk_id is not None:
# Backward pass.
if pre_backward is not None:
pre_backward()
input_tensor_grad = backward_step_helper(bwd_model_chunk_id, cur_bwd_chunk_microbatch[bwd_model_chunk_id])
cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
if post_backward is not None:
input_tensor_grad = post_backward(input_tensor_grad)
return output_tensor, input_tensor_grad
output_tensor = None
output_tensor_master_send = None output_tensor_master_send = None
output_tensor_slave_send = None output_tensor_slave_send = None
fwd_wait_recv_handles = [None, None] fwd_wait_recv_handles = [None, None]
fwd_wait_send_handles = [None, None] fwd_wait_send_handles = [None, None]
bwd_wait_recv_handles = [None, None] bwd_wait_recv_handles = [None, None]
bwd_wait_send_handles = [None, None] bwd_wait_send_handles = [None, None]
checkpoint_activations_microbatch = None
# Run warmup forward passes # Run warmup forward passes
input_tensor, _ = recv_forward(tensor_shape, config, master_chunk_id) input_tensor, _ = recv_forward(tensor_shape, config, master_chunk_id)
...@@ -776,13 +833,10 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -776,13 +833,10 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensor, fwd_wait_recv_handles[master_chunk_id] = recv_forward(tensor_shape, config, master_chunk_id, async_op=True) input_tensor, fwd_wait_recv_handles[master_chunk_id] = recv_forward(tensor_shape, config, master_chunk_id, async_op=True)
input_tensors[master_chunk_id].append(input_tensor) input_tensors[master_chunk_id].append(input_tensor)
is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0) output_tensor, _ = forward_backward_helper_wrapper(
output_tensor = forward_step_helper( fwd_model_chunk_id=master_chunk_id,
master_chunk_id, checkpoint_activations_microbatch=checkpoint_activations_microbatch,
cur_fwd_chunk_microbatch[master_chunk_id],
is_first_microbatch=is_first_microbatch
) )
cur_fwd_chunk_microbatch[master_chunk_id] += 1
if fwd_wait_send_handles[master_chunk_id] is not None: if fwd_wait_send_handles[master_chunk_id] is not None:
for req, req_handle in fwd_wait_send_handles[master_chunk_id].items(): for req, req_handle in fwd_wait_send_handles[master_chunk_id].items():
...@@ -804,14 +858,10 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -804,14 +858,10 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensor_slave, fwd_wait_recv_handles[slave_chunk_id] = recv_forward(tensor_shape, config, slave_chunk_id, async_op=True) input_tensor_slave, fwd_wait_recv_handles[slave_chunk_id] = recv_forward(tensor_shape, config, slave_chunk_id, async_op=True)
input_tensors[slave_chunk_id].append(input_tensor_slave) input_tensors[slave_chunk_id].append(input_tensor_slave)
is_first_microbatch = parallel_state.is_pipeline_last_stage(ignore_virtual=True) and (i == 0) output_tensor, _ = forward_backward_helper_wrapper(
is_first_microbatch = check_first_val_step(first_val_step, forward_only, is_first_microbatch) fwd_model_chunk_id=master_chunk_id,
output_tensor_master = forward_step_helper( checkpoint_activations_microbatch=checkpoint_activations_microbatch,
master_chunk_id,
cur_fwd_chunk_microbatch[master_chunk_id],
is_first_microbatch=is_first_microbatch
) )
cur_fwd_chunk_microbatch[master_chunk_id] += 1
if not parallel_state.is_pipeline_last_stage(): if not parallel_state.is_pipeline_last_stage():
wait_comm_handles(fwd_wait_send_handles[master_chunk_id]) wait_comm_handles(fwd_wait_send_handles[master_chunk_id])
...@@ -819,20 +869,20 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -819,20 +869,20 @@ def forward_backward_pipelining_with_cutinhalf(
if not forward_only: if not forward_only:
deallocate_output_tensor(output_tensor_master_send, config.deallocate_pipeline_outputs) deallocate_output_tensor(output_tensor_master_send, config.deallocate_pipeline_outputs)
output_tensor_master_send = output_tensor_master output_tensor_master_send = output_tensor
fwd_wait_send_handles[master_chunk_id] = send_forward( fwd_wait_send_handles[master_chunk_id] = send_forward(
output_tensor_master_send, tensor_shape, config, master_chunk_id, async_op=True) output_tensor_master_send, tensor_shape, config, master_chunk_id, async_op=True)
# prepare input for slave chunk # prepare input for slave chunk
if parallel_state.is_pipeline_last_stage(): if parallel_state.is_pipeline_last_stage():
if not forward_only: if not forward_only:
input_tensor_slave = output_tensor_master.detach() input_tensor_slave = output_tensor.detach()
input_tensor_slave.requires_grad = True input_tensor_slave.requires_grad = True
else: else:
input_tensor_slave = output_tensor_master input_tensor_slave = output_tensor
input_tensors[slave_chunk_id].append(input_tensor_slave) input_tensors[slave_chunk_id].append(input_tensor_slave)
if not forward_only: if not forward_only:
deallocate_output_tensor(output_tensor_master, config.deallocate_pipeline_outputs) deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
else: else:
wait_comm_handles(fwd_wait_recv_handles[slave_chunk_id]) wait_comm_handles(fwd_wait_recv_handles[slave_chunk_id])
...@@ -841,19 +891,16 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -841,19 +891,16 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensors[master_chunk_id].append(input_tensor) input_tensors[master_chunk_id].append(input_tensor)
# slave forward # slave forward
is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0) output_tensor, _ = forward_backward_helper_wrapper(
output_tensor_slave = forward_step_helper( fwd_model_chunk_id=slave_chunk_id,
slave_chunk_id, checkpoint_activations_microbatch=checkpoint_activations_microbatch,
cur_fwd_chunk_microbatch[slave_chunk_id],
is_first_microbatch=is_first_microbatch
) )
cur_fwd_chunk_microbatch[slave_chunk_id] += 1
wait_comm_handles(fwd_wait_send_handles[slave_chunk_id]) wait_comm_handles(fwd_wait_send_handles[slave_chunk_id])
if not forward_only: if not forward_only:
deallocate_output_tensor(output_tensor_slave_send, config.deallocate_pipeline_outputs) deallocate_output_tensor(output_tensor_slave_send, config.deallocate_pipeline_outputs)
output_tensor_slave_send = output_tensor_slave output_tensor_slave_send = output_tensor
fwd_wait_send_handles[slave_chunk_id] = send_forward(output_tensor_slave_send, tensor_shape, config, slave_chunk_id, async_op=True) fwd_wait_send_handles[slave_chunk_id] = send_forward(output_tensor_slave_send, tensor_shape, config, slave_chunk_id, async_op=True)
# check whether data transmission is completed. # check whether data transmission is completed.
...@@ -884,8 +931,7 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -884,8 +931,7 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensors[slave_chunk_id].append(input_tensor_slave) input_tensors[slave_chunk_id].append(input_tensor_slave)
if not forward_only: if not forward_only:
input_tensor_grad = backward_step_helper(slave_chunk_id) _, input_tensor_grad = forward_backward_helper_wrapper(bwd_model_chunk_id=slave_chunk_id)
cur_bwd_chunk_microbatch[slave_chunk_id] += 1
# If asynchronous, the memory will rise. # If asynchronous, the memory will rise.
bwd_wait_send_handles[slave_chunk_id] = send_backward(input_tensor_grad, tensor_shape, config, slave_chunk_id) bwd_wait_send_handles[slave_chunk_id] = send_backward(input_tensor_grad, tensor_shape, config, slave_chunk_id)
...@@ -905,19 +951,22 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -905,19 +951,22 @@ def forward_backward_pipelining_with_cutinhalf(
handle.wait() handle.wait()
fwd_wait_recv_handles[slave_chunk_id] = None fwd_wait_recv_handles[slave_chunk_id] = None
output_tensor_slave = forward_step_helper( output_tensor, _ = forward_backward_helper_wrapper(
slave_chunk_id, fwd_model_chunk_id=slave_chunk_id,
cur_fwd_chunk_microbatch[slave_chunk_id], checkpoint_activations_microbatch=checkpoint_activations_microbatch,
is_first_microbatch=False
) )
cur_fwd_chunk_microbatch[slave_chunk_id] += 1
# check whether backward data transmission is completed. # check whether backward data transmission is completed.
wait_comm_handles(bwd_wait_send_handles[slave_chunk_id]) wait_comm_handles(bwd_wait_send_handles[slave_chunk_id])
output_tensor_slave_send = output_tensor_slave output_tensor_slave_send = output_tensor
fwd_wait_send_handles[slave_chunk_id] = send_forward(output_tensor_slave_send, tensor_shape, config, slave_chunk_id, async_op=True) fwd_wait_send_handles[slave_chunk_id] = send_forward(output_tensor_slave_send, tensor_shape, config, slave_chunk_id, async_op=True)
# check whether forward data transmission is completed.
wait_comm_handles(fwd_wait_send_handles[slave_chunk_id])
if not forward_only:
deallocate_output_tensor(output_tensor_slave_send, config.deallocate_pipeline_outputs)
# Run overlapping f&bw stages # Run overlapping f&bw stages
fwd_wait_send_recv_handles = None fwd_wait_send_recv_handles = None
bwd_wait_send_recv_handles = None bwd_wait_send_recv_handles = None
...@@ -938,6 +987,9 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -938,6 +987,9 @@ def forward_backward_pipelining_with_cutinhalf(
# wait input for current step # wait input for current step
wait_comm_handles(fwd_wait_recv_handles[fwd_model_chunk_id]) wait_comm_handles(fwd_wait_recv_handles[fwd_model_chunk_id])
if not forward_only:
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
def pp_post_forward(output_tensor): def pp_post_forward(output_tensor):
nonlocal cur_fwd_chunk_microbatch nonlocal cur_fwd_chunk_microbatch
nonlocal num_chunk_max_microbatch nonlocal num_chunk_max_microbatch
...@@ -973,6 +1025,7 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -973,6 +1025,7 @@ def forward_backward_pipelining_with_cutinhalf(
wait_comm_handles(bwd_wait_send_recv_handles) wait_comm_handles(bwd_wait_send_recv_handles)
def pp_post_backward(input_tensor_grad): def pp_post_backward(input_tensor_grad):
nonlocal output_tensor_grads
nonlocal fwd_wait_send_handles nonlocal fwd_wait_send_handles
nonlocal fwd_wait_send_recv_handles nonlocal fwd_wait_send_recv_handles
nonlocal bwd_wait_send_recv_handles nonlocal bwd_wait_send_recv_handles
...@@ -981,9 +1034,6 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -981,9 +1034,6 @@ def forward_backward_pipelining_with_cutinhalf(
wait_comm_handles(fwd_wait_send_handles[fwd_model_chunk_id]) wait_comm_handles(fwd_wait_send_handles[fwd_model_chunk_id])
wait_comm_handles(fwd_wait_send_recv_handles) wait_comm_handles(fwd_wait_send_recv_handles)
if not forward_only:
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
if not forward_only: if not forward_only:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id: if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad = input_tensor_grad output_tensor_grad = input_tensor_grad
...@@ -999,31 +1049,21 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -999,31 +1049,21 @@ def forward_backward_pipelining_with_cutinhalf(
return input_tensor_grad return input_tensor_grad
# forward output_tensor, input_tensor_grad = forward_backward_helper_wrapper(
pp_pre_forward() fwd_model_chunk_id=fwd_model_chunk_id,
output_tensor = forward_step_helper( bwd_model_chunk_id=None if forward_only else bwd_model_chunk_id,
fwd_model_chunk_id, pre_forward=pp_pre_forward,
cur_fwd_chunk_microbatch[fwd_model_chunk_id], pre_backward=pp_pre_backward,
is_first_microbatch=False post_forward=pp_post_forward,
post_backward=pp_post_backward,
checkpoint_activations_microbatch=checkpoint_activations_microbatch,
) )
cur_fwd_chunk_microbatch[fwd_model_chunk_id] += 1
output_tensor = pp_post_forward(output_tensor)
# backward
pp_pre_backward()
if not forward_only:
try:
input_tensor_grad = backward_step_helper(bwd_model_chunk_id)
except Exception as e:
print(f"step_id: {step_id}, rank: {torch.distributed.get_rank()}, bwd_model_chunk_id: {bwd_model_chunk_id}", flush=True)
raise Exception(f"{e}")
cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
else:
input_tensor_grad = None
_ = pp_post_backward(input_tensor_grad)
# only run backward # only run backward
else: else:
if not forward_only:
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
if bwd_model_chunk_id == slave_chunk_id and cur_fwd_chunk_microbatch[slave_chunk_id] < num_chunk_max_microbatch[slave_chunk_id]: if bwd_model_chunk_id == slave_chunk_id and cur_fwd_chunk_microbatch[slave_chunk_id] < num_chunk_max_microbatch[slave_chunk_id]:
input_tensor, fwd_wait_recv_handles[slave_chunk_id] = recv_forward(tensor_shape, config, slave_chunk_id, async_op=True) input_tensor, fwd_wait_recv_handles[slave_chunk_id] = recv_forward(tensor_shape, config, slave_chunk_id, async_op=True)
input_tensors[slave_chunk_id].append(input_tensor) input_tensors[slave_chunk_id].append(input_tensor)
...@@ -1031,11 +1071,9 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -1031,11 +1071,9 @@ def forward_backward_pipelining_with_cutinhalf(
wait_comm_handles(bwd_wait_send_handles[1 - bwd_model_chunk_id]) wait_comm_handles(bwd_wait_send_handles[1 - bwd_model_chunk_id])
wait_comm_handles(bwd_wait_send_recv_handles) wait_comm_handles(bwd_wait_send_recv_handles)
input_tensor_grad = backward_step_helper( _, input_tensor_grad = forward_backward_helper_wrapper(
bwd_model_chunk_id, bwd_model_chunk_id=bwd_model_chunk_id,
bwd_cur_microbatch=cur_bwd_chunk_microbatch[bwd_model_chunk_id]
) )
cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id: if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad = input_tensor_grad output_tensor_grad = input_tensor_grad
......
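Taken together, the steady-state loop threads pipeline communication through the pre_/post_ callbacks so that async sends and receives overlap with the combined forward/backward compute. A stripped-down sketch of that callback pattern (hypothetical names, no real communication, mirroring the non-combined branch of forward_backward_helper_wrapper):

def run_combined_step(forward_fn, backward_fn,
                      pre_forward=None, post_forward=None,
                      pre_backward=None, post_backward=None):
    # Wait for inputs, compute, then hand results to callbacks that launch the next async transfer.
    if pre_forward is not None:
        pre_forward()                  # e.g. wait on the async recv of the input activation
    output = forward_fn()
    if post_forward is not None:
        output = post_forward(output)  # e.g. async send_forward of the output
    grad = None
    if backward_fn is not None:
        if pre_backward is not None:
            pre_backward()             # e.g. wait on the async recv of the output grad
        grad = backward_fn()
        if post_backward is not None:
            grad = post_backward(grad) # e.g. async send_backward of the input grad
    return output, grad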