Commit b7c994f5 authored by dongcl


patch for megatron v0.12.0.rc3. When tp > 1 and combined_1f1b = true, megatron (main) cannot execute properly
parent 08efd4ec
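Background note: the modules patched below are swapped in through the repository's MegatronAdaptation.register('<dotted.path>', replacement, apply_wrapper=...) calls (re-enabled in the first hunk) and take effect when dcu_megatron.megatron_adaptor is imported, as in the pretrain_gpt.py hunk at the end. As a rough, hedged illustration of that mechanism only — the register/apply_all names and the path resolution below are assumptions for this sketch, not the repository's actual implementation — such a dotted-path monkey-patch registry can look like this:

import importlib

_PATCHES = {}  # dotted path -> (replacement, apply_wrapper)

def register(dotted_path, replacement, apply_wrapper=False):
    # Remember which attribute should be replaced (or wrapped) later.
    _PATCHES[dotted_path] = (replacement, apply_wrapper)

def _resolve_owner(dotted_path):
    # Import the longest importable module prefix, then walk attributes down to
    # the owner of the final name (e.g. the GPTModel class for "...GPTModel.forward").
    parts = dotted_path.split('.')
    for i in range(len(parts) - 1, 0, -1):
        try:
            obj = importlib.import_module('.'.join(parts[:i]))
        except ImportError:
            continue
        for name in parts[i:-1]:
            obj = getattr(obj, name)
        return obj, parts[-1]
    raise ImportError(f"cannot resolve {dotted_path}")

def apply_all():
    for dotted_path, (replacement, apply_wrapper) in _PATCHES.items():
        owner, name = _resolve_owner(dotted_path)
        original = getattr(owner, name)
        # apply_wrapper=True means the replacement wraps the original callable
        # (as gpt_model_init_wrapper does); otherwise it replaces it outright.
        setattr(owner, name, replacement(original) if apply_wrapper else replacement)

Under that assumed shape, the first hunk below amounts to register(..., gpt_model_init_wrapper, apply_wrapper=True) followed by one apply pass at adaptor import time.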
@@ -66,8 +66,7 @@ class MegatronAdaptation:
        """
        Execute after other adaptations.
        """
-        from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
-        from megatron.core.transformer.transformer_block import TransformerBlock
+        pass


class MegatronAdaptationABC:
@@ -101,11 +100,11 @@ class CoreAdaptation(MegatronAdaptationABC):
        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward

        # GPT Model
-        # MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
-        #                             gpt_model_init_wrapper,
-        #                             apply_wrapper=True)
-        # MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
-        #                             gpt_model_forward)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
+                                    gpt_model_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
+                                    gpt_model_forward)

    def patch_core_transformers(self):
        from ..core import transformer_block_init_wrapper
...
@@ -13,12 +13,12 @@ from megatron.core.utils import get_te_version, is_te_min_version
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
from megatron.core.parallel_state import (
+    get_context_parallel_global_ranks,
    get_context_parallel_group,
    get_hierarchical_context_parallel_groups,
    get_tensor_model_parallel_group,
@@ -65,7 +65,6 @@ class TELinear(MegatronCoreTELinear):
        skip_weight_param_allocation: bool,
        tp_comm_buffer_name: Optional[str] = None,
        is_expert: bool = False,
-        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
        assert not self.split_bw, "split_bw is currently not supported"
@@ -81,7 +80,6 @@ class TELinear(MegatronCoreTELinear):
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
            is_expert=is_expert,
-            tp_group=tp_group,
        )

    def backward_dw(self):
@@ -108,7 +106,6 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
        is_expert: bool,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: Optional[str] = None,
-        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
        assert not self.split_bw, "split_bw is currently not supported"
@@ -124,7 +121,6 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
-            tp_group=tp_group,
        )

    def backward_dw(self):
@@ -144,7 +140,6 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
        k_channels: Optional[int] = None,
        v_channels: Optional[int] = None,
        cp_comm_type: str = "p2p",
-        model_comm_pgs: ModelCommProcessGroups = None,
    ):
        self.config = config
        self.te_forward_mask_type = False
@@ -171,26 +166,6 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
                f"num_attention_heads ({self.config.num_attention_heads}))"
            )

-        if model_comm_pgs is None:
-            # For backward compatibility, remove in v0.14 and raise error
-            # raise ValueError("TEDotProductAttention was called without ModelCommProcessGroups")
-            model_comm_pgs = ModelCommProcessGroups(
-                tp=get_tensor_model_parallel_group(check_initialized=False),
-                cp=get_context_parallel_group(check_initialized=False),
-                hcp=get_hierarchical_context_parallel_groups(check_initialized=False),
-            )
-        else:
-            assert hasattr(
-                model_comm_pgs, 'tp'
-            ), "TEDotProductAttention model_comm_pgs must have tp pg"
-            assert hasattr(
-                model_comm_pgs, 'cp'
-            ), "TEDotProductAttention model_comm_pgs must have cp pg"
-            if cp_comm_type == "a2a+p2p":
-                assert hasattr(
-                    model_comm_pgs, 'hcp'
-                ), "TEDotProductAttention model_comm_pgs must have hierarchical cp pg"

        if is_te_min_version("0.10.0"):
            extra_kwargs["attention_type"] = attention_type
        # older version don't need attention_type
@@ -206,9 +181,9 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
            ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
            if getattr(TEDotProductAttention, "cp_stream") is None:
                TEDotProductAttention.cp_stream = torch.cuda.Stream()
-            extra_kwargs["cp_group"] = model_comm_pgs.cp
-            extra_kwargs["cp_global_ranks"] = torch.distributed.get_process_group_ranks(
-                model_comm_pgs.cp
+            extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
+            extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
+                check_initialized=False
            )
            extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
            if is_te_min_version("1.10.0"):
@@ -282,7 +257,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
            get_rng_state_tracker=(
                get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
            ),
-            tp_group=model_comm_pgs.tp,
+            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            layer_number=layer_number,
            **extra_kwargs,
        )
@@ -313,7 +288,6 @@ if is_te_min_version("1.9.0.dev0"):
            skip_bias_add: bool,
            is_expert: bool = False,
            tp_comm_buffer_name: Optional[str] = None,
-            tp_group: Optional[torch.distributed.ProcessGroup] = None,
        ):
            self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
            assert not self.split_bw, "split_bw is currently not supported"
@@ -329,7 +303,6 @@ if is_te_min_version("1.9.0.dev0"):
                skip_bias_add=skip_bias_add,
                is_expert=is_expert,
                tp_comm_buffer_name=tp_comm_buffer_name,
-                tp_group=tp_group,
            )

        def backward_dw(self):
...
@@ -289,7 +289,7 @@ class MoeAttnNode(TransformerLayerNode):
                pre_mlp_layernorm_output,
                tokens_per_expert,
                permutated_local_input_tokens,
-                permuted_probs,
+                probs,
            ) = self.layer._submodule_attention_router_compound_forward(
                hidden_states,
                attention_mask=attention_mask,
@@ -305,10 +305,11 @@ class MoeAttnNode(TransformerLayerNode):
            self.common_state.tokens_per_expert = tokens_per_expert

            # detached here
+            self.common_state.probs = self.detach(probs)
            self.common_state.residual = self.detach(hidden_states)
            self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)

-            return permutated_local_input_tokens, permuted_probs
+            return permutated_local_input_tokens

    def dw(self):
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
@@ -317,27 +318,26 @@


class MoeDispatchNode(TransformerLayerNode):
-    def forward_impl(self, permutated_local_input_tokens, permuted_probs):
+    def forward_impl(self, permutated_local_input_tokens):
        token_dispatcher = self.layer.mlp.token_dispatcher
        with token_dispatcher.per_batch_state_context(self.common_state):
-            inputs = permutated_local_input_tokens
-            tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
-                self.common_state.tokens_per_expert, permutated_local_input_tokens, permuted_probs
+            tokens_per_expert, global_input_tokens = token_dispatcher.dispatch_all_to_all(
+                self.common_state.tokens_per_expert, permutated_local_input_tokens
            )
            # release tensor not used by backward
            # inputs.untyped_storage().resize_(0)
            self.common_state.tokens_per_expert = tokens_per_expert
-            return global_input_tokens, global_probs
+            return global_input_tokens


class MoeMlPNode(TransformerLayerNode):
-    def forward_impl(self, global_input_tokens, global_probs):
+    def forward_impl(self, global_input_tokens):
        pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
        token_dispatcher = self.layer.mlp.token_dispatcher
        with token_dispatcher.per_batch_state_context(self.common_state):
            expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
-                self.common_state.tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
+                self.common_state.tokens_per_expert, global_input_tokens, pre_mlp_layernorm_output
            )

            assert mlp_bias is None
@@ -364,7 +364,9 @@ class MoeCombineNode(TransformerLayerNode):
            )
            cur_stream = torch.cuda.current_stream()
            self.common_state.residual.record_stream(cur_stream)
+            self.common_state.probs.record_stream(cur_stream)
            self.common_state.residual = None
+            self.common_state.probs = None
            return output
...
@@ -13,8 +13,6 @@ from megatron.core.utils import (
    get_model_config,
    get_model_type,
    get_model_xattn,
-    nvtx_range_pop,
-    nvtx_range_push,
)
from megatron.core.pipeline_parallel.schedules import (
    forward_step,
@@ -90,16 +88,6 @@ def get_pp_rank_microbatches(
    )


-def print_rank_4(message):
-    """If distributed is initialized, print only on rank 0."""
-    if torch.distributed.is_initialized():
-        if torch.distributed.get_rank() == 4:
-            print(message, flush=True)
-    else:
-        print(message, flush=True)
-
-
-from megatron.training import print_rank_0, print_rank_last

def forward_backward_pipelining_with_interleaving(
    *,
    forward_step_func,
@@ -108,11 +96,10 @@ def forward_backward_pipelining_with_interleaving(
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
-    decoder_seq_length: Optional[int] = None,
+    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
-    first_val_step: Optional[bool] = None,
-    adjust_tensor_shapes_fn: Optional[Callable] = None,  # unused
+    first_val_step: bool = None,
):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.
@@ -133,9 +120,6 @@ def forward_backward_pipelining_with_interleaving(
    assert isinstance(
        data_iterator, list
    ), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
-    assert (
-        adjust_tensor_shapes_fn is None
-    ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism"

    config = get_model_config(model[0])
@@ -310,9 +294,6 @@ def forward_backward_pipelining_with_interleaving(
    # Both tables are indexed with virtual_microbatch_id.
    microbatch_id_table, model_chunk_id_table = zip(*schedule_table)

-    print_rank_4(f"rank last. microbatch_id_table: {microbatch_id_table}. model_chunk_id_table: {model_chunk_id_table}")
-    print_rank_0(f"rank first. microbatch_id_table: {microbatch_id_table}. model_chunk_id_table: {model_chunk_id_table}")

    def get_model_chunk_id(virtual_microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        model_chunk_id = model_chunk_id_table[virtual_microbatch_id % total_num_microbatches]
@@ -432,7 +413,7 @@ def forward_backward_pipelining_with_interleaving(
        )

        # forward step
-        if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=model_chunk_id):
+        if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
            if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
@@ -460,7 +441,6 @@ def forward_backward_pipelining_with_interleaving(
                is_first_microbatch_for_model_chunk(virtual_microbatch_id),
            ),
            current_microbatch=microbatch_id,
-            vp_stage=model_chunk_id,
        )
        output_tensors[model_chunk_id].append(output_tensor)
@@ -491,7 +471,7 @@ def forward_backward_pipelining_with_interleaving(
                synchronized_model_chunks.add(model_chunk_id)

        # pylint: disable=E0606
-        if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=model_chunk_id):
+        if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
@@ -730,14 +710,9 @@ def forward_backward_pipelining_with_interleaving(
        input_tensor_grad = post_backward(input_tensor_grad)
        return output_tensor, input_tensor_grad

-    is_vp_first_stage = partial(parallel_state.is_pipeline_first_stage, ignore_virtual=False)
-    is_vp_last_stage = partial(parallel_state.is_pipeline_last_stage, ignore_virtual=False)

    # Run warmup forward passes.
-    nvtx_range_push(suffix="warmup")
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
-    input_tensors[0].append(
-        p2p_communication.recv_forward(tensor_shape, config, is_vp_first_stage())
-    )
+    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))

    fwd_wait_handles = None
    fwd_wait_recv_handles = None
@@ -767,7 +742,7 @@ def forward_backward_pipelining_with_interleaving(
        parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)

        if config.overlap_p2p_comm_warmup_flush:
-            if not is_vp_first_stage(vp_stage=cur_model_chunk_id) and k != 0:
+            if not parallel_state.is_pipeline_first_stage(ignore_virtual=False) and k != 0:
                assert recv_prev_wait_handles, (
                    f'pp rank {pipeline_parallel_rank}, iteration {k},'
                    'should have registered recv handle'
@@ -814,7 +789,7 @@ def forward_backward_pipelining_with_interleaving(
        )

        # Don't send tensor downstream if on last stage.
-        if is_vp_last_stage(vp_stage=cur_model_chunk_id):
+        if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
@@ -917,17 +892,12 @@ def forward_backward_pipelining_with_interleaving(
            if recv_next:
                output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])
-    nvtx_range_pop(suffix="warmup")

    # Run 1F1B in steady state.
-    nvtx_range_push(suffix="steady")
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches

-        print_rank_0(f"rank first. 1F1B in steady state: {k}/{num_microbatches_remaining}")
-        print_rank_4(f"rank last. 1F1B in steady state: {k}/{num_microbatches_remaining}")

        # Decide to checkpoint all layers' activations of the current micro-batch.
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
@@ -944,7 +914,7 @@ def forward_backward_pipelining_with_interleaving(
            cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)

-            if not is_vp_first_stage(vp_stage=cur_model_chunk_id):
+            if not parallel_state.is_pipeline_first_stage(ignore_virtual=False):
                if config.overlap_p2p_comm_warmup_flush:
                    assert recv_prev_wait_handles, (
                        f'pp rank {pipeline_parallel_rank}, fwd iteration {forward_k}, '
@@ -972,7 +942,7 @@ def forward_backward_pipelining_with_interleaving(
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)

            # Last virtual stage no activation tensor to send.
-            if is_vp_last_stage(vp_stage=forward_model_chunk_id):
+            if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
                output_tensor = None

            recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
@@ -999,9 +969,7 @@ def forward_backward_pipelining_with_interleaving(
                    send_next_wait_handle.wait()
                if fwd_wait_handles is not None:
                    send_next_wait_handle = (
-                        fwd_wait_handles.pop("send_next")
-                        if "send_next" in fwd_wait_handles
-                        else None
+                        fwd_wait_handles.pop("send_next") if "send_next" in fwd_wait_handles else None
                    )
                    if "recv_prev" in fwd_wait_handles:
                        recv_prev_wait_handles.append(fwd_wait_handles.pop("recv_prev"))
@@ -1024,7 +992,7 @@ def forward_backward_pipelining_with_interleaving(
            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)

-            if not is_vp_last_stage(vp_stage=backward_model_chunk_id):
+            if not parallel_state.is_pipeline_last_stage(ignore_virtual=False):
                if config.overlap_p2p_comm_warmup_flush:
                    assert recv_next_wait_handles, (
                        f'pp rank {pipeline_parallel_rank}, bwd iteration {backward_k}, '
@@ -1048,7 +1016,7 @@ def forward_backward_pipelining_with_interleaving(
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)

            # First virtual stage no activation gradient tensor to send.
-            if is_vp_first_stage(vp_stage=backward_model_chunk_id):
+            if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
                input_tensor_grad = None

            recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage(
@@ -1091,9 +1059,6 @@ def forward_backward_pipelining_with_interleaving(
                post_backward=pp_post_backward,
                checkpoint_activations_microbatch=checkpoint_activations_microbatch,
            )
-            print_rank_0(f"rank first. 1F1B in steady state: {k}/{num_microbatches_remaining} end")
-            print_rank_4(f"rank last. 1F1B in steady state: {k}/{num_microbatches_remaining} end")

        else:  # No p2p overlap.
            backward_k = k
            output_tensor, input_tensor_grad = forward_backward_helper_wrapper(
@@ -1109,12 +1074,12 @@ def forward_backward_pipelining_with_interleaving(
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
-            if is_vp_last_stage(vp_stage=forward_model_chunk_id):
+            if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
                output_tensor = None

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
-            if is_vp_first_stage(vp_stage=backward_model_chunk_id):
+            if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
                input_tensor_grad = None

            recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
@@ -1150,16 +1115,9 @@ def forward_backward_pipelining_with_interleaving(
            if recv_next:
                output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

-        if k == 0:
-            print_rank_0(f"input_tensor_grad: {input_tensor_grad}")
-        print_rank_0(f"rank first. 1F1B in steady state end")
-        print_rank_4(f"rank last. 1F1B in steady state end")

        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-    nvtx_range_pop(suffix="steady")

    # Run cooldown backward passes (flush out pipeline).
-    nvtx_range_push(suffix="cooldown")
    if not forward_only:
        if bwd_wait_handles is not None:
            for bwd_wait_handle in bwd_wait_handles.values():
@@ -1167,14 +1125,12 @@ def forward_backward_pipelining_with_interleaving(
        if are_all_microbatches_in_warmup:
            output_tensor_grads[num_model_chunks - 1].append(
-                p2p_communication.recv_backward(
-                    tensor_shape, config=config, is_last_stage=is_vp_last_stage()
-                )
+                p2p_communication.recv_backward(tensor_shape, config=config)
            )
        for k in range(num_microbatches_remaining, total_num_microbatches):
            cur_model_chunk_id = get_model_chunk_id(k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
-            if not is_vp_last_stage(vp_stage=cur_model_chunk_id) and k != 0:
+            if not parallel_state.is_pipeline_last_stage(ignore_virtual=False) and k != 0:
                if config.overlap_p2p_comm_warmup_flush:
                    assert recv_next_wait_handles, (
                        f'pp rank {pipeline_parallel_rank}, backward iteration {k}, '
@@ -1214,7 +1170,7 @@ def forward_backward_pipelining_with_interleaving(
            _, input_tensor_grad = forward_backward_helper_wrapper(b_virtual_microbatch_id=k)

            # First virtual stage no activation gradient tensor to send.
-            if is_vp_first_stage(vp_stage=cur_model_chunk_id):
+            if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
                input_tensor_grad = None

            if config.overlap_p2p_comm_warmup_flush:
@@ -1271,9 +1227,7 @@ def forward_backward_pipelining_with_interleaving(
                if model_chunk_id not in synchronized_model_chunks:
                    config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
                    synchronized_model_chunks.add(model_chunk_id)
-    nvtx_range_pop(suffix="cooldown")

-    nvtx_range_push(suffix="misc")
    assert (
        not recv_prev_wait_handles
    ), 'recv_prev_wait_handles should be cleared at the end of a step'
@@ -1303,7 +1257,5 @@ def forward_backward_pipelining_with_interleaving(
    if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
        create_cudagraphs()

-    nvtx_range_pop(suffix="misc")

    return forward_data_store
@@ -108,20 +108,18 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        self.hidden_shape_before_permute = hidden_states.shape

        (
            permutated_local_input_tokens,
-            permuted_probs,
            self.reversed_local_input_permutation_mapping,
        ) = permute(
            hidden_states,
            routing_map,
-            probs=self.probs,
            num_out_tokens=self.num_out_tokens,
            fused=self.config.moe_permute_fusion,
            drop_and_pad=self.drop_and_pad,
        )

-        return tokens_per_expert, permutated_local_input_tokens, permuted_probs
+        return tokens_per_expert, permutated_local_input_tokens

-    def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
+    def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens):
        # Perform expert parallel AlltoAll communication
        tokens_per_expert = self._maybe_dtoh_and_synchronize(
            "before_ep_alltoall", tokens_per_expert
@@ -129,13 +127,10 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        global_input_tokens = all_to_all(
            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
        )
-        global_probs = all_to_all(
-            self.ep_group, permuted_probs, self.output_splits, self.input_splits
-        )

-        return tokens_per_expert, global_input_tokens, global_probs
+        return tokens_per_expert, global_input_tokens

-    def dispatch_postprocess(self, tokens_per_expert, global_input_tokens, global_probs):
+    def dispatch_postprocess(self, tokens_per_expert, global_input_tokens):
        if self.shared_experts is not None:
            self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
@@ -148,9 +143,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
            global_input_tokens = gather_from_sequence_parallel_region(
                global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
            )
-            global_probs = gather_from_sequence_parallel_region(
-                global_probs, group=self.tp_group, output_split_sizes=output_split_sizes
-            )

        # Permutation 2: Sort tokens by local expert.
        tokens_per_expert = self._maybe_dtoh_and_synchronize(
@@ -169,28 +161,16 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
                .contiguous()
                .flatten(start_dim=0, end_dim=2)
            )
-            global_probs = (
-                global_probs.view(
-                    self.tp_size * self.ep_size,
-                    self.num_local_experts,
-                    self.capacity,
-                    *global_probs.size()[1:],
-                )
-                .transpose(0, 1)
-                .contiguous()
-                .flatten(start_dim=0, end_dim=2)
-            )
        else:
-            global_input_tokens, global_probs = sort_chunks_by_idxs(
+            global_input_tokens = sort_chunks_by_idxs(
                global_input_tokens,
                self.num_global_tokens_per_local_expert.ravel(),
                self.sort_input_by_local_experts,
-                probs=global_probs,
                fused=self.config.moe_permute_fusion,
            )

        tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)

-        return global_input_tokens, tokens_per_expert, global_probs
+        return global_input_tokens, tokens_per_expert

    def token_permutation(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
@@ -218,15 +198,15 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        # Preprocess: Get the metadata for communication, permutation and computation operations.
        # Permutation 1: input to AlltoAll input
        tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
-        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
+        tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)

        # Perform expert parallel AlltoAll communication
-        tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
+        tokens_per_expert, global_input_tokens = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens)

        # Permutation 2: Sort tokens by local expert.
-        global_input_tokens, tokens_per_expert, global_probs = self.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
+        global_input_tokens, tokens_per_expert = self.dispatch_postprocess(tokens_per_expert, global_input_tokens)

-        return global_input_tokens, tokens_per_expert, global_probs
+        return global_input_tokens, tokens_per_expert

    def combine_preprocess(self, hidden_states):
        # Unpermutation 2: Unsort tokens by local expert.
@@ -283,6 +263,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
            permutated_local_input_tokens,
            self.reversed_local_input_permutation_mapping,
            restore_shape=self.hidden_shape_before_permute,
+            probs=self.probs,
            routing_map=self.routing_map,
            fused=self.config.moe_permute_fusion,
            drop_and_pad=self.drop_and_pad,
...
@@ -8,7 +8,11 @@ def transformer_block_init_wrapper(fn):
        # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
        config = args[0] if len(args) > 1 else kwargs['config']
-        if hasattr(config, "mtp_num_layers") and config.mtp_num_layers is not None:
+        if (
+            hasattr(config, "mtp_num_layers")
+            and config.mtp_num_layers is not None
+            and config.mtp_num_layers > 0
+        ):
            self.main_final_layernorm = self.final_layernorm
            self.final_layernorm = None
...
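The hunk above edits the body of transformer_block_init_wrapper, the kind of init wrapper that MegatronAdaptation.register(..., apply_wrapper=True) expects: run the original __init__, then post-process attributes on the instance. A minimal sketch of that general pattern only, under the assumption that the wrapper has this shape (it is not the repository's actual transformer_block_init_wrapper):

import functools

def init_wrapper_sketch(original_init):
    # Hypothetical wrapper shape: call the wrapped __init__ first, then move the
    # final layernorm out of the block when MTP layers are configured, mirroring
    # the condition added in the hunk above.
    @functools.wraps(original_init)
    def wrapped(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        config = args[0] if len(args) > 1 else kwargs['config']
        if getattr(config, "mtp_num_layers", None):
            self.main_final_layernorm = self.final_layernorm
            self.final_layernorm = None
    return wrapped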
@@ -66,7 +66,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
                pre_mlp_layernorm_output,
                tokens_per_expert,
                permutated_local_input_tokens,
-                permuted_probs,
+                probs,
            ) = self._submodule_attention_router_compound_forward(
                hidden_states,
                attention_mask,
@@ -80,16 +80,14 @@ class TransformerLayer(MegatronCoreTransformerLayer):
                inference_params=inference_params,
            )

-            (tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
+            (tokens_per_expert, global_input_tokens) = self._submodule_dispatch_forward(
                tokens_per_expert,
                permutated_local_input_tokens,
-                permuted_probs
            )

            (expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
                tokens_per_expert,
                global_input_tokens,
-                global_probs,
                pre_mlp_layernorm_output
            )
@@ -125,13 +123,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        residual = hidden_states

        # Optional Input Layer norm
-        if self.recompute_input_layernorm:
-            self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-            input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
-                self.input_layernorm, hidden_states
-            )
-        else:
-            input_layernorm_output = self.input_layernorm(hidden_states)
+        input_layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output_with_bias = self.self_attention(
@@ -146,13 +138,6 @@ class TransformerLayer(MegatronCoreTransformerLayer):
            sequence_len_offset=sequence_len_offset,
        )

-        if self.recompute_input_layernorm:
-            # discard the output of the input layernorm and register the recompute
-            # as a gradient hook of attention_output_with_bias[0]
-            self.input_layernorm_checkpoint.discard_output_and_register_recompute(
-                attention_output_with_bias[0]
-            )

        # TODO: could we move `bias_dropout_add_exec_handler` itself
        # inside the module provided in the `bias_dropout_add_spec` module?
        with self.bias_dropout_add_exec_handler():
@@ -193,19 +178,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        )

        # Optional Layer norm post the cross-attention.
-        if self.recompute_pre_mlp_layernorm:
-            self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-            pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
-                self.pre_mlp_layernorm, hidden_states
-            )
-        else:
-            pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
+        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
        tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
            pre_mlp_layernorm_output, probs, routing_map
        )
-        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.mlp.token_dispatcher.dispatch_preprocess(
+        tokens_per_expert, permutated_local_input_tokens = self.mlp.token_dispatcher.dispatch_preprocess(
            pre_mlp_layernorm_output, routing_map, tokens_per_expert
        )
@@ -214,18 +193,18 @@ class TransformerLayer(MegatronCoreTransformerLayer):
            pre_mlp_layernorm_output,
            tokens_per_expert,
            permutated_local_input_tokens,
-            permuted_probs,
+            probs,
        ]
        return tuple(outputs)

-    def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
+    def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens):
        """
        Dispatches tokens to the appropriate experts based on the router output.
        """
-        tokens_per_expert, global_input_tokens, global_probs = self.mlp.token_dispatcher.dispatch_all_to_all(
-            tokens_per_expert, permutated_local_input_tokens, permuted_probs
+        tokens_per_expert, global_input_tokens = self.mlp.token_dispatcher.dispatch_all_to_all(
+            tokens_per_expert, permutated_local_input_tokens
        )
-        return [tokens_per_expert, global_input_tokens, global_probs]
+        return [tokens_per_expert, global_input_tokens]

    def _submodule_dense_forward(self, hidden_states):
        residual = hidden_states
@@ -241,16 +220,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        return output

-    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output):
+    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, pre_mlp_layernorm_output):
        """
        Performs a forward pass for the MLP submodule, including both expert-based
        and optional shared-expert computations.
        """
        shared_expert_output = None
-        (dispatched_input, tokens_per_expert, permuted_probs) = (
-            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
+        (dispatched_input, tokens_per_expert) = (
+            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens)
        )
-        expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
+        expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert)
        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
...
@@ -3,7 +3,6 @@ import argparse
from typing import Union

from megatron.training.arguments import add_megatron_arguments
-from megatron.core.msc_utils import MultiStorageClientFeature


def remove_original_params(parser, param_names: Union[list, str]):
@@ -60,12 +59,6 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    # args.rank = int(os.getenv('RANK', '0'))
    # args.world_size = int(os.getenv("WORLD_SIZE", '1'))

-    # Args to disable MSC
-    if not args.enable_msc:
-        MultiStorageClientFeature.disable()
-        assert MultiStorageClientFeature.is_enabled() is False
-        print('WARNING: The MSC feature is disabled.')

    return args
...
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT.""" """Pretrain GPT."""
import os
import os, sys
current_dir = os.path.dirname(os.path.abspath(__file__))
megatron_path = os.path.join(current_dir, "Megatron-LM")
sys.path.append(megatron_path)
from functools import partial
from typing import List, Optional, Tuple, Union
import torch import torch
from functools import partial
from contextlib import nullcontext
import inspect
from megatron.core import parallel_state from typing import List, Optional, Tuple, Union
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.training import get_args
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.models.gpt.gpt_layer_specs import ( from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
get_gpt_decoder_block_spec, from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
get_gpt_mtp_block_spec,
)
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
get_gpt_heterogeneous_layer_spec,
)
from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.transformer.spec_utils import import_module import megatron.legacy.model
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.core.utils import StragglerDetector from megatron.core.utils import StragglerDetector
from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0 from megatron.core.transformer.spec_utils import import_module
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.utils import ( from megatron.training.utils import (
get_batch_on_this_cp_rank, get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank, get_batch_on_this_tp_rank,
get_blend_and_blend_per_split, get_blend_and_blend_per_split,
) )
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.core.models.gpt.gpt_layer_specs import (
import megatron.legacy.model # isort: skip get_gpt_decoder_block_spec,
get_gpt_layer_local_spec,
# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import get_gpt_layer_with_transformer_engine_spec,
get_gpt_mtp_block_spec,
try: )
from megatron.post_training.arguments import add_modelopt_args, modelopt_args_enabled
from megatron.post_training.loss_func import loss_func as loss_func_modelopt
from megatron.post_training.model_provider import model_provider as model_provider_modelopt
has_nvidia_modelopt = True
except ImportError:
has_nvidia_modelopt = False
from dcu_megatron import megatron_adaptor from dcu_megatron import megatron_adaptor
stimer = StragglerDetector() stimer = StragglerDetector()
def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
def model_provider(
pre_process=True, post_process=True
) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
"""Builds the model. """Builds the model.
If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.
...@@ -72,34 +55,24 @@ def model_provider( ...@@ -72,34 +55,24 @@ def model_provider(
Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
""" """
args = get_args() args = get_args()
if has_nvidia_modelopt and modelopt_args_enabled(args): # [ModelOpt]
return model_provider_modelopt(pre_process, post_process)
if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))): if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))):
assert args.transformer_impl == "transformer_engine" assert args.transformer_impl == "transformer_engine"
use_te = args.transformer_impl == "transformer_engine" use_te = args.transformer_impl == "transformer_engine"
if args.record_memory_history: if args.record_memory_history:
torch.cuda.memory._record_memory_history( torch.cuda.memory._record_memory_history(True,
True,
# keep 100,000 alloc/free events from before the snapshot # keep 100,000 alloc/free events from before the snapshot
trace_alloc_max_entries=100000, trace_alloc_max_entries=100000,
# record stack information for the trace events # record stack information for the trace events
trace_alloc_record_context=True, trace_alloc_record_context=True)
)
def oom_observer(device, alloc, device_alloc, device_free): def oom_observer(device, alloc, device_alloc, device_free):
# snapshot right after an OOM happened # snapshot right after an OOM happened
print('saving allocated state during OOM') print('saving allocated state during OOM')
snapshot = torch.cuda.memory._snapshot() snapshot = torch.cuda.memory._snapshot()
from pickle import dump from pickle import dump
dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
dump(
snapshot,
open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'),
)
torch._C._cuda_attach_out_of_memory_observer(oom_observer) torch._C._cuda_attach_out_of_memory_observer(oom_observer)
...@@ -118,41 +91,27 @@ def model_provider( ...@@ -118,41 +91,27 @@ def model_provider(
pre_process=pre_process, pre_process=pre_process,
post_process=post_process, post_process=post_process,
) )
else: # using core models else: # using core models
if args.spec is not None: if args.spec is not None:
transformer_layer_spec = import_module(args.spec) transformer_layer_spec = import_module(args.spec)
else: else:
if args.num_experts: if args.num_experts:
# Define the decoder block spec # Define the decoder block spec
transformer_layer_spec = get_gpt_decoder_block_spec( transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization)
config, use_transformer_engine=use_te, normalization=args.normalization
)
elif args.heterogeneous_layers_config_path is not None:
transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
else: else:
# Define the decoder layer spec # Define the decoder layer spec
if use_te: if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
args.num_experts, args.num_experts, args.moe_grouped_gemm,
args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
args.qk_layernorm,
args.multi_latent_attention,
args.moe_use_legacy_grouped_gemm,
)
else: else:
transformer_layer_spec = get_gpt_layer_local_spec( transformer_layer_spec = get_gpt_layer_local_spec(
args.num_experts, args.num_experts, args.moe_grouped_gemm,
args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm,
args.qk_layernorm, normalization=args.normalization)
args.multi_latent_attention,
args.moe_use_legacy_grouped_gemm,
normalization=args.normalization,
)
mtp_block_spec = None mtp_block_spec = None
if args.mtp_num_layers is not None: if args.mtp_num_layers is not None:
mtp_block_spec = get_gpt_mtp_block_spec( mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)
config, transformer_layer_spec, use_transformer_engine=use_te
)
model = GPTModel( model = GPTModel(
config=config, config=config,
...@@ -170,6 +129,7 @@ def model_provider( ...@@ -170,6 +129,7 @@ def model_provider(
rope_scaling=args.use_rope_scaling, rope_scaling=args.use_rope_scaling,
mtp_block_spec=mtp_block_spec, mtp_block_spec=mtp_block_spec,
) )
print_rank_0(model)
return model return model
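Editor's note (not part of the diff): after this patch, the layer-spec branching in model_provider reduces to the selection sketched below. Illustration only; it returns labels rather than real spec objects, and the argument names are simplified stand-ins for the corresponding args fields.

def pick_spec(spec=None, num_experts=0, use_te=True):
    # Mirrors the patched branching: an explicit --spec wins, then MoE models
    # use the decoder block spec, otherwise TE vs. local layer spec.
    if spec is not None:
        return "import_module(spec)"
    if num_experts:
        return "get_gpt_decoder_block_spec(...)"
    return ("get_gpt_layer_with_transformer_engine_spec(...)"
            if use_te else "get_gpt_layer_local_spec(...)")

print(pick_spec(num_experts=8))  # MoE models -> decoder block spec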
...@@ -178,9 +138,7 @@ def get_batch(data_iterator): ...@@ -178,9 +138,7 @@ def get_batch(data_iterator):
"""Generate a batch.""" """Generate a batch."""
# TODO: this is pretty hacky, find a better way # TODO: this is pretty hacky, find a better way
if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and ( if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
):
return None, None, None, None, None return None, None, None, None, None
# get batches based on the TP rank you are on # get batches based on the TP rank you are on
...@@ -196,15 +154,12 @@ def get_batch(data_iterator): ...@@ -196,15 +154,12 @@ def get_batch(data_iterator):
SPIKY_LOSS_FACTOR = 10 SPIKY_LOSS_FACTOR = 10
def loss_func( def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
):
"""Loss function. """Loss function.
Args: Args:
loss_mask (torch.Tensor): Used to mask out some portions of the loss loss_mask (torch.Tensor): Used to mask out some portions of the loss
output_tensor (torch.Tensor): The tensor with the losses output_tensor (torch.Tensor): The tensor with the losses
model (GPTModel, optional): The model (can be wrapped)
Returns: Returns:
the loss scalar for this micro-batch the loss scalar for this micro-batch
...@@ -214,16 +169,13 @@ def loss_func( ...@@ -214,16 +169,13 @@ def loss_func(
""" """
args = get_args() args = get_args()
if has_nvidia_modelopt and modelopt_args_enabled(args): # [ModelOpt]
return loss_func_modelopt(loss_mask, output_tensor, model=model)
losses = output_tensor.float() losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float() loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum() total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)]) loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
if args.context_parallel_size > 1: if args.context_parallel_size > 1:
torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
# Check individual rank losses are not NaN prior to DP all-reduce. # Check individual rank losses are not NaN prior to DP all-reduce.
rerun_state_machine = get_rerun_state_machine() rerun_state_machine = get_rerun_state_machine()
...@@ -232,14 +184,14 @@ def loss_func( ...@@ -232,14 +184,14 @@ def loss_func(
result=loss[0], result=loss[0],
rejection_func=torch.isnan, rejection_func=torch.isnan,
message="found NaN in local forward loss calculation", message="found NaN in local forward loss calculation",
tolerance=0.0, # forward pass calculations are deterministic tolerance=0.0, # forward pass calculations are deterministic
fatal=True, fatal=True,
) )
rerun_state_machine.validate_result( rerun_state_machine.validate_result(
result=loss[0], result=loss[0],
rejection_func=torch.isinf, rejection_func=torch.isinf,
message="found Inf in local forward loss calculation", message="found Inf in local forward loss calculation",
tolerance=0.0, # forward pass calculations are deterministic tolerance=0.0, # forward pass calculations are deterministic
fatal=True, fatal=True,
) )
# Check for spiky loss # Check for spiky loss
...@@ -252,18 +204,22 @@ def loss_func( ...@@ -252,18 +204,22 @@ def loss_func(
context="loss", context="loss",
), ),
message="Spiky loss", message="Spiky loss",
tolerance=0.0, # forward pass calculations are deterministic tolerance=0.0, # forward pass calculations are deterministic
fatal=False, fatal=False,
) )
# Reduce loss for logging. # Reduce loss for logging.
reporting_loss = loss.clone().detach() reporting_loss = loss.clone().detach()
torch.distributed.all_reduce(reporting_loss, group=parallel_state.get_data_parallel_group()) torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
# loss[0] is a view of loss, so its ._base is not None, which triggers an assert # loss[0] is a view of loss, so its ._base is not None, which triggers an assert
# error in core/pipeline_parallel/schedule.py::deallocate_output_tensor; calling # error in core/pipeline_parallel/schedule.py::deallocate_output_tensor; calling
# .clone() on loss[0] fixes this # .clone() on loss[0] fixes this
local_num_tokens = loss[1].clone().detach().to(torch.int) local_num_tokens = loss[1].clone().detach().to(torch.int)
return (loss[0].clone(), local_num_tokens, {'lm loss': (reporting_loss[0], reporting_loss[1])}) return (
loss[0].clone(),
local_num_tokens,
{'lm loss': (reporting_loss[0], reporting_loss[1])},
)
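Editor's note (not part of the diff): the comment above about loss[0] can be checked in isolation; indexing a tensor returns a view whose ._base points at the original storage, and .clone() breaks that link, which is why the cloned scalar is safe to hand to the pipeline schedule. A minimal sketch:

import torch

loss = torch.tensor([1.5, 7.0])
scalar = loss[0]                      # indexing produces a view of `loss`
print(scalar._base is None)           # False: the view keeps a reference
print(scalar.clone()._base is None)   # True: clone() owns its own storage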
def forward_step(data_iterator, model: GPTModel): def forward_step(data_iterator, model: GPTModel):
...@@ -280,26 +236,25 @@ def forward_step(data_iterator, model: GPTModel): ...@@ -280,26 +236,25 @@ def forward_step(data_iterator, model: GPTModel):
timers('batch-generator', log_level=2).start() timers('batch-generator', log_level=2).start()
global stimer global stimer
with stimer(bdata=True): with stimer(bdata=True):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator) tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop() timers('batch-generator').stop()
with stimer: with stimer:
if args.use_legacy_models: if args.use_legacy_models:
output_tensor = model(tokens, position_ids, attention_mask, labels=labels) output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else: else:
output_tensor = model( output_tensor = model(tokens, position_ids, attention_mask,
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask labels=labels, loss_mask=loss_mask)
)
# [ModelOpt]: model is needed to access ModelOpt distillation losses return output_tensor, partial(loss_func, loss_mask)
return output_tensor, partial(loss_func, loss_mask, model=model)
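Editor's note (not part of the diff): forward_step returns the output tensor plus a functools.partial that pre-binds loss_mask, so the pipeline schedule only has to supply output_tensor later. A standalone sketch of that pattern; loss_fn here is a toy stand-in, not the loss_func above.

from functools import partial

import torch

def loss_fn(loss_mask, output_tensor):
    # Toy stand-in for loss_func: masked sum of the per-token losses.
    return (output_tensor.float() * loss_mask).sum()

loss_mask = torch.ones(4)
output_tensor = torch.randn(4)
bound = partial(loss_fn, loss_mask)   # what forward_step hands back
print(bound(output_tensor))           # the schedule supplies only output_tensor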
def is_dataset_built_on_rank(): def is_dataset_built_on_rank():
return ( return (
parallel_state.is_pipeline_first_stage(ignore_virtual=True) mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
or parallel_state.is_pipeline_last_stage(ignore_virtual=True) ) and mpu.get_tensor_model_parallel_rank() == 0
) and parallel_state.get_tensor_model_parallel_rank() == 0
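Editor's note (not part of the diff): is_dataset_built_on_rank limits dataset construction to tensor-parallel rank 0 on the first and last pipeline stages. A self-contained illustration with hypothetical rank numbers:

def built(pp_rank, pp_size, tp_rank):
    # Same predicate as is_dataset_built_on_rank, with explicit rank arguments.
    return (pp_rank == 0 or pp_rank == pp_size - 1) and tp_rank == 0

# With pipeline parallel size 4 and tensor parallel size 2:
print([(pp, tp) for pp in range(4) for tp in range(2) if built(pp, 4, tp)])
# -> [(0, 0), (3, 0)]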
def core_gpt_dataset_config_from_args(args): def core_gpt_dataset_config_from_args(args):
...@@ -324,8 +279,7 @@ def core_gpt_dataset_config_from_args(args): ...@@ -324,8 +279,7 @@ def core_gpt_dataset_config_from_args(args):
reset_attention_mask=args.reset_attention_mask, reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss, eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader, create_attention_mask=args.create_attention_mask_in_dataloader,
object_storage_cache_path=args.object_storage_cache_path, s3_cache_path=args.s3_cache_path,
mid_level_dataset_surplus=args.mid_level_dataset_surplus,
) )
...@@ -347,7 +301,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -347,7 +301,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
print_rank_0("> building train, validation, and test datasets for GPT ...") print_rank_0("> building train, validation, and test datasets for GPT ...")
train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config dataset_type,
train_val_test_num_samples,
is_dataset_built_on_rank,
config
).build() ).build()
print_rank_0("> finished creating GPT datasets ...") print_rank_0("> finished creating GPT datasets ...")
...@@ -366,5 +323,4 @@ if __name__ == "__main__": ...@@ -366,5 +323,4 @@ if __name__ == "__main__":
ModelType.encoder_or_decoder, ModelType.encoder_or_decoder,
forward_step, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
) )