"vscode:/vscode.git/clone" did not exist on "80b2b3207a086061f419b7ae38ccec7dc5562f69"
Commit b7c994f5 authored by dongcl


Patch for Megatron v0.12.0.rc3: when tp > 1 and combined_1f1b = true, Megatron (main) cannot execute properly.
parent 08efd4ec
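
The first hunks below go through a register-based adaptation layer: replacement callables are registered against dotted import paths such as 'megatron.core.models.gpt.gpt_model.GPTModel.forward' and later applied as monkey patches. The adaptation machinery itself is not part of this diff, so the following is only a minimal sketch of how such a registry could work; register, apply_patches, and the apply_wrapper semantics are assumptions for illustration, not the project's actual implementation.

import importlib

_registry = {}

def register(dotted_path, replacement, apply_wrapper=False):
    # Queue a replacement for the attribute at dotted_path (assumed semantics).
    _registry[dotted_path] = (replacement, apply_wrapper)

def _resolve_owner(dotted_path):
    # Import the longest importable module prefix, then walk the remaining
    # attributes down to the object that owns the final attribute.
    parts = dotted_path.split('.')
    for i in range(len(parts) - 1, 0, -1):
        try:
            obj = importlib.import_module('.'.join(parts[:i]))
        except ImportError:
            continue
        for name in parts[i:-1]:
            obj = getattr(obj, name)
        return obj, parts[-1]
    raise ImportError(f"cannot resolve {dotted_path}")

def apply_patches():
    for dotted_path, (replacement, apply_wrapper) in _registry.items():
        owner, attr = _resolve_owner(dotted_path)
        original = getattr(owner, attr)
        # apply_wrapper=True is assumed to mean "replacement wraps the original
        # callable" (as with gpt_model_init_wrapper) rather than replacing it.
        setattr(owner, attr, replacement(original) if apply_wrapper else replacement)
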
......@@ -66,8 +66,7 @@ class MegatronAdaptation:
"""
Execute after other adaptations.
"""
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.transformer_block import TransformerBlock
pass
class MegatronAdaptationABC:
......@@ -101,11 +100,11 @@ class CoreAdaptation(MegatronAdaptationABC):
from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
# GPT Model
# MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
# gpt_model_init_wrapper,
# apply_wrapper=True)
# MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
# gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
gpt_model_forward)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
......
......@@ -13,12 +13,12 @@ from megatron.core.utils import get_te_version, is_te_min_version
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_hierarchical_context_parallel_groups,
get_tensor_model_parallel_group,
......@@ -65,7 +65,6 @@ class TELinear(MegatronCoreTELinear):
skip_weight_param_allocation: bool,
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
......@@ -81,7 +80,6 @@ class TELinear(MegatronCoreTELinear):
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
is_expert=is_expert,
tp_group=tp_group,
)
def backward_dw(self):
......@@ -108,7 +106,6 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
......@@ -124,7 +121,6 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
is_expert=is_expert,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
)
def backward_dw(self):
......@@ -144,7 +140,6 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
k_channels: Optional[int] = None,
v_channels: Optional[int] = None,
cp_comm_type: str = "p2p",
model_comm_pgs: ModelCommProcessGroups = None,
):
self.config = config
self.te_forward_mask_type = False
......@@ -171,26 +166,6 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if model_comm_pgs is None:
# For backward compatibility, remove in v0.14 and raise error
# raise ValueError("TEDotProductAttention was called without ModelCommProcessGroups")
model_comm_pgs = ModelCommProcessGroups(
tp=get_tensor_model_parallel_group(check_initialized=False),
cp=get_context_parallel_group(check_initialized=False),
hcp=get_hierarchical_context_parallel_groups(check_initialized=False),
)
else:
assert hasattr(
model_comm_pgs, 'tp'
), "TEDotProductAttention model_comm_pgs must have tp pg"
assert hasattr(
model_comm_pgs, 'cp'
), "TEDotProductAttention model_comm_pgs must have cp pg"
if cp_comm_type == "a2a+p2p":
assert hasattr(
model_comm_pgs, 'hcp'
), "TEDotProductAttention model_comm_pgs must have hierarchical cp pg"
if is_te_min_version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older version don't need attention_type
......@@ -206,9 +181,9 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = model_comm_pgs.cp
extra_kwargs["cp_global_ranks"] = torch.distributed.get_process_group_ranks(
model_comm_pgs.cp
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
if is_te_min_version("1.10.0"):
......@@ -282,7 +257,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
tp_group=model_comm_pgs.tp,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
**extra_kwargs,
)
......@@ -313,7 +288,6 @@ if is_te_min_version("1.9.0.dev0"):
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
......@@ -329,7 +303,6 @@ if is_te_min_version("1.9.0.dev0"):
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
)
def backward_dw(self):
......
......@@ -289,7 +289,7 @@ class MoeAttnNode(TransformerLayerNode):
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
probs,
) = self.layer._submodule_attention_router_compound_forward(
hidden_states,
attention_mask=attention_mask,
......@@ -305,10 +305,11 @@ class MoeAttnNode(TransformerLayerNode):
self.common_state.tokens_per_expert = tokens_per_expert
# detached here
self.common_state.probs = self.detach(probs)
self.common_state.residual = self.detach(hidden_states)
self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
return permutated_local_input_tokens, permuted_probs
return permutated_local_input_tokens
def dw(self):
with torch.cuda.nvtx.range(f"{self.name} wgrad"):
......@@ -317,27 +318,26 @@ class MoeAttnNode(TransformerLayerNode):
class MoeDispatchNode(TransformerLayerNode):
def forward_impl(self, permutated_local_input_tokens, permuted_probs):
def forward_impl(self, permutated_local_input_tokens):
token_dispatcher = self.layer.mlp.token_dispatcher
with token_dispatcher.per_batch_state_context(self.common_state):
inputs = permutated_local_input_tokens
tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
self.common_state.tokens_per_expert, permutated_local_input_tokens, permuted_probs
tokens_per_expert, global_input_tokens = token_dispatcher.dispatch_all_to_all(
self.common_state.tokens_per_expert, permutated_local_input_tokens
)
# release tensor not used by backward
# inputs.untyped_storage().resize_(0)
self.common_state.tokens_per_expert = tokens_per_expert
return global_input_tokens, global_probs
return global_input_tokens
class MoeMlPNode(TransformerLayerNode):
def forward_impl(self, global_input_tokens, global_probs):
def forward_impl(self, global_input_tokens):
pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
token_dispatcher = self.layer.mlp.token_dispatcher
with token_dispatcher.per_batch_state_context(self.common_state):
expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
self.common_state.tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
self.common_state.tokens_per_expert, global_input_tokens, pre_mlp_layernorm_output
)
assert mlp_bias is None
......@@ -364,7 +364,9 @@ class MoeCombineNode(TransformerLayerNode):
)
cur_stream = torch.cuda.current_stream()
self.common_state.residual.record_stream(cur_stream)
self.common_state.probs.record_stream(cur_stream)
self.common_state.residual = None
self.common_state.probs = None
return output
......
......@@ -13,8 +13,6 @@ from megatron.core.utils import (
get_model_config,
get_model_type,
get_model_xattn,
nvtx_range_pop,
nvtx_range_push,
)
from megatron.core.pipeline_parallel.schedules import (
forward_step,
......@@ -90,16 +88,6 @@ def get_pp_rank_microbatches(
)
def print_rank_4(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 4:
print(message, flush=True)
else:
print(message, flush=True)
from megatron.training import print_rank_0, print_rank_last
def forward_backward_pipelining_with_interleaving(
*,
forward_step_func,
......@@ -108,11 +96,10 @@ def forward_backward_pipelining_with_interleaving(
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: Optional[int] = None,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: Optional[bool] = None,
adjust_tensor_shapes_fn: Optional[Callable] = None, # unused
first_val_step: bool = None,
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
......@@ -133,9 +120,6 @@ def forward_backward_pipelining_with_interleaving(
assert isinstance(
data_iterator, list
), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
assert (
adjust_tensor_shapes_fn is None
), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism"
config = get_model_config(model[0])
......@@ -310,9 +294,6 @@ def forward_backward_pipelining_with_interleaving(
# Both tables are indexed with virtual_microbatch_id.
microbatch_id_table, model_chunk_id_table = zip(*schedule_table)
print_rank_4(f"rank last. microbatch_id_table: {microbatch_id_table}. model_chunk_id_table: {model_chunk_id_table}")
print_rank_0(f"rank first. microbatch_id_table: {microbatch_id_table}. model_chunk_id_table: {model_chunk_id_table}")
def get_model_chunk_id(virtual_microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
model_chunk_id = model_chunk_id_table[virtual_microbatch_id % total_num_microbatches]
......@@ -432,7 +413,7 @@ def forward_backward_pipelining_with_interleaving(
)
# forward step
if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
......@@ -460,7 +441,6 @@ def forward_backward_pipelining_with_interleaving(
is_first_microbatch_for_model_chunk(virtual_microbatch_id),
),
current_microbatch=microbatch_id,
vp_stage=model_chunk_id,
)
output_tensors[model_chunk_id].append(output_tensor)
......@@ -491,7 +471,7 @@ def forward_backward_pipelining_with_interleaving(
synchronized_model_chunks.add(model_chunk_id)
# pylint: disable=E0606
if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
......@@ -730,14 +710,9 @@ def forward_backward_pipelining_with_interleaving(
input_tensor_grad = post_backward(input_tensor_grad)
return output_tensor, input_tensor_grad
is_vp_first_stage = partial(parallel_state.is_pipeline_first_stage, ignore_virtual=False)
is_vp_last_stage = partial(parallel_state.is_pipeline_last_stage, ignore_virtual=False)
# Run warmup forward passes.
nvtx_range_push(suffix="warmup")
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(tensor_shape, config, is_vp_first_stage())
)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
fwd_wait_handles = None
fwd_wait_recv_handles = None
......@@ -767,7 +742,7 @@ def forward_backward_pipelining_with_interleaving(
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if config.overlap_p2p_comm_warmup_flush:
if not is_vp_first_stage(vp_stage=cur_model_chunk_id) and k != 0:
if not parallel_state.is_pipeline_first_stage(ignore_virtual=False) and k != 0:
assert recv_prev_wait_handles, (
f'pp rank {pipeline_parallel_rank}, iteration {k},'
'should have registered recv handle'
......@@ -814,7 +789,7 @@ def forward_backward_pipelining_with_interleaving(
)
# Don't send tensor downstream if on last stage.
if is_vp_last_stage(vp_stage=cur_model_chunk_id):
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
......@@ -917,17 +892,12 @@ def forward_backward_pipelining_with_interleaving(
if recv_next:
output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])
nvtx_range_pop(suffix="warmup")
# Run 1F1B in steady state.
nvtx_range_push(suffix="steady")
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
print_rank_0(f"rank first. 1F1B in steady state: {k}/{num_microbatches_remaining}")
print_rank_4(f"rank last. 1F1B in steady state: {k}/{num_microbatches_remaining}")
# Decide to checkpoint all layers' activations of the current micro-batch.
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
......@@ -944,7 +914,7 @@ def forward_backward_pipelining_with_interleaving(
cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if not is_vp_first_stage(vp_stage=cur_model_chunk_id):
if not parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if config.overlap_p2p_comm_warmup_flush:
assert recv_prev_wait_handles, (
f'pp rank {pipeline_parallel_rank}, fwd iteration {forward_k}, '
......@@ -972,7 +942,7 @@ def forward_backward_pipelining_with_interleaving(
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
# Last virtual stage no activation tensor to send.
if is_vp_last_stage(vp_stage=forward_model_chunk_id):
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
output_tensor = None
recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
......@@ -999,9 +969,7 @@ def forward_backward_pipelining_with_interleaving(
send_next_wait_handle.wait()
if fwd_wait_handles is not None:
send_next_wait_handle = (
fwd_wait_handles.pop("send_next")
if "send_next" in fwd_wait_handles
else None
fwd_wait_handles.pop("send_next") if "send_next" in fwd_wait_handles else None
)
if "recv_prev" in fwd_wait_handles:
recv_prev_wait_handles.append(fwd_wait_handles.pop("recv_prev"))
......@@ -1024,7 +992,7 @@ def forward_backward_pipelining_with_interleaving(
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if not is_vp_last_stage(vp_stage=backward_model_chunk_id):
if not parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if config.overlap_p2p_comm_warmup_flush:
assert recv_next_wait_handles, (
f'pp rank {pipeline_parallel_rank}, bwd iteration {backward_k}, '
......@@ -1048,7 +1016,7 @@ def forward_backward_pipelining_with_interleaving(
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
# First virtual stage no activation gradient tensor to send.
if is_vp_first_stage(vp_stage=backward_model_chunk_id):
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
input_tensor_grad = None
recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage(
......@@ -1091,9 +1059,6 @@ def forward_backward_pipelining_with_interleaving(
post_backward=pp_post_backward,
checkpoint_activations_microbatch=checkpoint_activations_microbatch,
)
print_rank_0(f"rank first. 1F1B in steady state: {k}/{num_microbatches_remaining} end")
print_rank_4(f"rank last. 1F1B in steady state: {k}/{num_microbatches_remaining} end")
else: # No p2p overlap.
backward_k = k
output_tensor, input_tensor_grad = forward_backward_helper_wrapper(
......@@ -1109,12 +1074,12 @@ def forward_backward_pipelining_with_interleaving(
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if is_vp_last_stage(vp_stage=forward_model_chunk_id):
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if is_vp_first_stage(vp_stage=backward_model_chunk_id):
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
input_tensor_grad = None
recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
......@@ -1150,16 +1115,9 @@ def forward_backward_pipelining_with_interleaving(
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
if k == 0:
print_rank_0(f"input_tensor_grad: {input_tensor_grad}")
print_rank_0(f"rank first. 1F1B in steady state end")
print_rank_4(f"rank last. 1F1B in steady state end")
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
nvtx_range_pop(suffix="steady")
# Run cooldown backward passes (flush out pipeline).
nvtx_range_push(suffix="cooldown")
if not forward_only:
if bwd_wait_handles is not None:
for bwd_wait_handle in bwd_wait_handles.values():
......@@ -1167,14 +1125,12 @@ def forward_backward_pipelining_with_interleaving(
if are_all_microbatches_in_warmup:
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(
tensor_shape, config=config, is_last_stage=is_vp_last_stage()
)
p2p_communication.recv_backward(tensor_shape, config=config)
)
for k in range(num_microbatches_remaining, total_num_microbatches):
cur_model_chunk_id = get_model_chunk_id(k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if not is_vp_last_stage(vp_stage=cur_model_chunk_id) and k != 0:
if not parallel_state.is_pipeline_last_stage(ignore_virtual=False) and k != 0:
if config.overlap_p2p_comm_warmup_flush:
assert recv_next_wait_handles, (
f'pp rank {pipeline_parallel_rank}, backward iteration {k}, '
......@@ -1214,7 +1170,7 @@ def forward_backward_pipelining_with_interleaving(
_, input_tensor_grad = forward_backward_helper_wrapper(b_virtual_microbatch_id=k)
# First virtual stage no activation gradient tensor to send.
if is_vp_first_stage(vp_stage=cur_model_chunk_id):
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
input_tensor_grad = None
if config.overlap_p2p_comm_warmup_flush:
......@@ -1271,9 +1227,7 @@ def forward_backward_pipelining_with_interleaving(
if model_chunk_id not in synchronized_model_chunks:
config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
synchronized_model_chunks.add(model_chunk_id)
nvtx_range_pop(suffix="cooldown")
nvtx_range_push(suffix="misc")
assert (
not recv_prev_wait_handles
), 'recv_prev_wait_handles should be cleared at the end of a step'
......@@ -1303,7 +1257,5 @@ def forward_backward_pipelining_with_interleaving(
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
nvtx_range_pop(suffix="misc")
return forward_data_store
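
As a side note on the interleaved schedule above, the microbatch_id_table / model_chunk_id_table lookup that the patch keeps can be illustrated with a small standalone sketch. The schedule-table construction below is an assumed round-robin layout and the backward-direction reversal is a guess at the surrounding upstream code, not the exact Megatron ordering.

# Standalone illustration of the model-chunk lookup used above; the table
# construction is a simplified, assumed layout.
num_model_chunks = 2
pipeline_parallel_size = 4
num_microbatches = 8
total_num_microbatches = num_microbatches * num_model_chunks

schedule_table = []
for group_start in range(0, num_microbatches, pipeline_parallel_size):
    group = range(group_start, min(group_start + pipeline_parallel_size, num_microbatches))
    for chunk in range(num_model_chunks):
        for microbatch_id in group:
            schedule_table.append((microbatch_id, chunk))

microbatch_id_table, model_chunk_id_table = zip(*schedule_table)

def get_model_chunk_id(virtual_microbatch_id, forward):
    # Same modular lookup as in the patched schedule; the reversal for the
    # backward direction is an assumption about the surrounding code.
    model_chunk_id = model_chunk_id_table[virtual_microbatch_id % total_num_microbatches]
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

print(get_model_chunk_id(0, forward=True), get_model_chunk_id(0, forward=False))  # 0 1
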
......@@ -108,20 +108,18 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
self.hidden_shape_before_permute = hidden_states.shape
(
permutated_local_input_tokens,
permuted_probs,
self.reversed_local_input_permutation_mapping,
) = permute(
hidden_states,
routing_map,
probs=self.probs,
num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion,
drop_and_pad=self.drop_and_pad,
)
return tokens_per_expert, permutated_local_input_tokens, permuted_probs
return tokens_per_expert, permutated_local_input_tokens
def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens):
# Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert
......@@ -129,13 +127,10 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
global_input_tokens = all_to_all(
self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
)
global_probs = all_to_all(
self.ep_group, permuted_probs, self.output_splits, self.input_splits
)
return tokens_per_expert, global_input_tokens, global_probs
return tokens_per_expert, global_input_tokens
def dispatch_postprocess(self, tokens_per_expert, global_input_tokens, global_probs):
def dispatch_postprocess(self, tokens_per_expert, global_input_tokens):
if self.shared_experts is not None:
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
......@@ -148,9 +143,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
global_input_tokens = gather_from_sequence_parallel_region(
global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
)
global_probs = gather_from_sequence_parallel_region(
global_probs, group=self.tp_group, output_split_sizes=output_split_sizes
)
# Permutation 2: Sort tokens by local expert.
tokens_per_expert = self._maybe_dtoh_and_synchronize(
......@@ -169,28 +161,16 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
.contiguous()
.flatten(start_dim=0, end_dim=2)
)
global_probs = (
global_probs.view(
self.tp_size * self.ep_size,
self.num_local_experts,
self.capacity,
*global_probs.size()[1:],
)
.transpose(0, 1)
.contiguous()
.flatten(start_dim=0, end_dim=2)
)
else:
global_input_tokens, global_probs = sort_chunks_by_idxs(
global_input_tokens = sort_chunks_by_idxs(
global_input_tokens,
self.num_global_tokens_per_local_expert.ravel(),
self.sort_input_by_local_experts,
probs=global_probs,
fused=self.config.moe_permute_fusion,
)
tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return global_input_tokens, tokens_per_expert, global_probs
return global_input_tokens, tokens_per_expert
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
......@@ -218,15 +198,15 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
# Preprocess: Get the metadata for communication, permutation and computation operations.
# Permutation 1: input to AlltoAll input
tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
# Perform expert parallel AlltoAll communication
tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
tokens_per_expert, global_input_tokens = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens)
# Permutation 2: Sort tokens by local expert.
global_input_tokens, tokens_per_expert, global_probs = self.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
global_input_tokens, tokens_per_expert = self.dispatch_postprocess(tokens_per_expert, global_input_tokens)
return global_input_tokens, tokens_per_expert, global_probs
return global_input_tokens, tokens_per_expert
def combine_preprocess(self, hidden_states):
# Unpermutation 2: Unsort tokens by local expert.
......@@ -283,6 +263,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
restore_shape=self.hidden_shape_before_permute,
probs=self.probs,
routing_map=self.routing_map,
fused=self.config.moe_permute_fusion,
drop_and_pad=self.drop_and_pad,
......
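
The token-dispatcher hunks above revert to keeping the routing probabilities on the dispatcher (self.probs) and applying them only at unpermute/combine time, instead of carrying a permuted_probs tensor through the dispatch AlltoAll. The following is a single-process sketch of that order of operations under simplifying assumptions: the expert-parallel AlltoAll is omitted, and the toy permute/unpermute pair only mimics the shape bookkeeping, so all names and shapes are illustrative.

import torch

def permute(tokens, routing_map):
    # Gather each token once per expert it routes to (expert-major order).
    num_tokens, num_experts = routing_map.shape
    flat_map = routing_map.T.contiguous().view(-1)          # expert-major mask
    token_idx = torch.arange(num_tokens).repeat(num_experts)
    sorted_indices = token_idx[flat_map]
    return tokens[sorted_indices], sorted_indices

def unpermute(expert_output, sorted_indices, probs, routing_map, restore_shape):
    # Scatter expert outputs back to token order, weighting by the routing
    # probabilities only here, at combine time (the order the patch reverts to).
    flat_probs = probs.T.contiguous().view(-1)[routing_map.T.contiguous().view(-1)]
    combined = torch.zeros(restore_shape, dtype=expert_output.dtype)
    combined.index_add_(0, sorted_indices, expert_output * flat_probs.unsqueeze(-1))
    return combined

tokens = torch.randn(4, 8)                                   # 4 tokens, hidden=8
routing_map = torch.tensor([[1, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.bool)
probs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.5, 0.5]])

permuted_tokens, sorted_indices = permute(tokens, routing_map)
expert_output = permuted_tokens * 2.0                        # stand-in for the expert MLPs
output = unpermute(expert_output, sorted_indices, probs, routing_map, tokens.shape)
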
......@@ -8,7 +8,11 @@ def transformer_block_init_wrapper(fn):
# MTP requires separate layernorms for the main model and MTP modules, so move the final layernorm out of the block
config = args[0] if len(args) > 1 else kwargs['config']
if hasattr(config, "mtp_num_layers") and config.mtp_num_layers is not None:
if (
hasattr(config, "mtp_num_layers")
and config.mtp_num_layers is not None
and config.mtp_num_layers > 0
):
self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None
......
......@@ -66,7 +66,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
probs,
) = self._submodule_attention_router_compound_forward(
hidden_states,
attention_mask,
......@@ -80,16 +80,14 @@ class TransformerLayer(MegatronCoreTransformerLayer):
inference_params=inference_params,
)
(tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
(tokens_per_expert, global_input_tokens) = self._submodule_dispatch_forward(
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs
)
(expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
tokens_per_expert,
global_input_tokens,
global_probs,
pre_mlp_layernorm_output
)
......@@ -125,13 +123,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
residual = hidden_states
# Optional Input Layer norm
if self.recompute_input_layernorm:
self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
self.input_layernorm, hidden_states
)
else:
input_layernorm_output = self.input_layernorm(hidden_states)
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
......@@ -146,13 +138,6 @@ class TransformerLayer(MegatronCoreTransformerLayer):
sequence_len_offset=sequence_len_offset,
)
if self.recompute_input_layernorm:
# discard the output of the input layernorm and register the recompute
# as a gradient hook of attention_output_with_bias[0]
self.input_layernorm_checkpoint.discard_output_and_register_recompute(
attention_output_with_bias[0]
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
......@@ -193,19 +178,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
)
# Optional Layer norm post the cross-attention.
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
self.pre_mlp_layernorm, hidden_states
)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
pre_mlp_layernorm_output, probs, routing_map
)
tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.mlp.token_dispatcher.dispatch_preprocess(
tokens_per_expert, permutated_local_input_tokens = self.mlp.token_dispatcher.dispatch_preprocess(
pre_mlp_layernorm_output, routing_map, tokens_per_expert
)
......@@ -214,18 +193,18 @@ class TransformerLayer(MegatronCoreTransformerLayer):
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
probs,
]
return tuple(outputs)
def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens):
"""
Dispatches tokens to the appropriate experts based on the router output.
"""
tokens_per_expert, global_input_tokens, global_probs = self.mlp.token_dispatcher.dispatch_all_to_all(
tokens_per_expert, permutated_local_input_tokens, permuted_probs
tokens_per_expert, global_input_tokens = self.mlp.token_dispatcher.dispatch_all_to_all(
tokens_per_expert, permutated_local_input_tokens
)
return [tokens_per_expert, global_input_tokens, global_probs]
return [tokens_per_expert, global_input_tokens]
def _submodule_dense_forward(self, hidden_states):
residual = hidden_states
......@@ -241,16 +220,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
return output
def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output):
def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, pre_mlp_layernorm_output):
"""
Performs a forward pass for the MLP submodule, including both expert-based
and optional shared-expert computations.
"""
shared_expert_output = None
(dispatched_input, tokens_per_expert, permuted_probs) = (
self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
(dispatched_input, tokens_per_expert) = (
self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens)
)
expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert)
expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
......
......@@ -3,7 +3,6 @@ import argparse
from typing import Union
from megatron.training.arguments import add_megatron_arguments
from megatron.core.msc_utils import MultiStorageClientFeature
def remove_original_params(parser, param_names: Union[list, str]):
......@@ -60,12 +59,6 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
# args.rank = int(os.getenv('RANK', '0'))
# args.world_size = int(os.getenv("WORLD_SIZE", '1'))
# Args to disable MSC
if not args.enable_msc:
MultiStorageClientFeature.disable()
assert MultiStorageClientFeature.is_enabled() is False
print('WARNING: The MSC feature is disabled.')
return args
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT."""
import os, sys
current_dir = os.path.dirname(os.path.abspath(__file__))
megatron_path = os.path.join(current_dir, "Megatron-LM")
sys.path.append(megatron_path)
from functools import partial
from typing import List, Optional, Tuple, Union
import os
import torch
from functools import partial
from contextlib import nullcontext
import inspect
from megatron.core import parallel_state
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
from typing import List, Optional, Tuple, Union
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_decoder_block_spec,
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
get_gpt_mtp_block_spec,
)
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
get_gpt_heterogeneous_layer_spec,
)
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.transformer.spec_utils import import_module
import megatron.legacy.model
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.core.utils import StragglerDetector
from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.core.transformer.spec_utils import import_module
from megatron.training.utils import (
get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank,
get_blend_and_blend_per_split,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
import megatron.legacy.model # isort: skip
# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import
try:
from megatron.post_training.arguments import add_modelopt_args, modelopt_args_enabled
from megatron.post_training.loss_func import loss_func as loss_func_modelopt
from megatron.post_training.model_provider import model_provider as model_provider_modelopt
has_nvidia_modelopt = True
except ImportError:
has_nvidia_modelopt = False
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_decoder_block_spec,
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
get_gpt_mtp_block_spec,
)
from dcu_megatron import megatron_adaptor
stimer = StragglerDetector()
def model_provider(
pre_process=True, post_process=True
) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
"""Builds the model.
If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.
......@@ -72,34 +55,24 @@ def model_provider(
Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
"""
args = get_args()
if has_nvidia_modelopt and modelopt_args_enabled(args): # [ModelOpt]
return model_provider_modelopt(pre_process, post_process)
if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))):
assert args.transformer_impl == "transformer_engine"
use_te = args.transformer_impl == "transformer_engine"
if args.record_memory_history:
torch.cuda.memory._record_memory_history(
True,
torch.cuda.memory._record_memory_history(True,
# keep 100,000 alloc/free events from before the snapshot
trace_alloc_max_entries=100000,
# record stack information for the trace events
trace_alloc_record_context=True,
)
trace_alloc_record_context=True)
def oom_observer(device, alloc, device_alloc, device_free):
# snapshot right after an OOM happened
print('saving allocated state during OOM')
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
dump(
snapshot,
open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'),
)
dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
torch._C._cuda_attach_out_of_memory_observer(oom_observer)
......@@ -118,41 +91,27 @@ def model_provider(
pre_process=pre_process,
post_process=post_process,
)
else: # using core models
else: # using core models
if args.spec is not None:
transformer_layer_spec = import_module(args.spec)
else:
if args.num_experts:
# Define the decoder block spec
transformer_layer_spec = get_gpt_decoder_block_spec(
config, use_transformer_engine=use_te, normalization=args.normalization
)
elif args.heterogeneous_layers_config_path is not None:
transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization)
else:
# Define the decoder layer spec
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
args.num_experts,
args.moe_grouped_gemm,
args.qk_layernorm,
args.multi_latent_attention,
args.moe_use_legacy_grouped_gemm,
)
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
else:
transformer_layer_spec = get_gpt_layer_local_spec(
args.num_experts,
args.moe_grouped_gemm,
args.qk_layernorm,
args.multi_latent_attention,
args.moe_use_legacy_grouped_gemm,
normalization=args.normalization,
)
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm,
normalization=args.normalization)
mtp_block_spec = None
if args.mtp_num_layers is not None:
mtp_block_spec = get_gpt_mtp_block_spec(
config, transformer_layer_spec, use_transformer_engine=use_te
)
mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)
model = GPTModel(
config=config,
......@@ -170,6 +129,7 @@ def model_provider(
rope_scaling=args.use_rope_scaling,
mtp_block_spec=mtp_block_spec,
)
print_rank_0(model)
return model
......@@ -178,9 +138,7 @@ def get_batch(data_iterator):
"""Generate a batch."""
# TODO: this is pretty hacky, find a better way
if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
):
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
# get batches based on the TP rank you are on
......@@ -196,15 +154,12 @@ def get_batch(data_iterator):
SPIKY_LOSS_FACTOR = 10
def loss_func(
loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
):
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
"""Loss function.
Args:
loss_mask (torch.Tensor): Used to mask out some portions of the loss
output_tensor (torch.Tensor): The tensor with the losses
model (GPTModel, optional): The model (can be wrapped)
Returns:
the loss scalar for this micro-batch
......@@ -214,16 +169,13 @@ def loss_func(
"""
args = get_args()
if has_nvidia_modelopt and modelopt_args_enabled(args): # [ModelOpt]
return loss_func_modelopt(loss_mask, output_tensor, model=model)
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
if args.context_parallel_size > 1:
torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
# Check individual rank losses are not NaN prior to DP all-reduce.
rerun_state_machine = get_rerun_state_machine()
......@@ -232,14 +184,14 @@ def loss_func(
result=loss[0],
rejection_func=torch.isnan,
message="found NaN in local forward loss calculation",
tolerance=0.0, # forward pass calculations are determinisic
tolerance=0.0, # forward pass calculations are determinisic
fatal=True,
)
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=torch.isinf,
message="found Inf in local forward loss calculation",
tolerance=0.0, # forward pass calculations are determinisic
tolerance=0.0, # forward pass calculations are determinisic
fatal=True,
)
# Check for spiky loss
......@@ -252,18 +204,22 @@ def loss_func(
context="loss",
),
message="Spiky loss",
tolerance=0.0, # forward pass calculations are determinisic
tolerance=0.0, # forward pass calculations are determinisic
fatal=False,
)
# Reduce loss for logging.
reporting_loss = loss.clone().detach()
torch.distributed.all_reduce(reporting_loss, group=parallel_state.get_data_parallel_group())
torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
# loss[0] is a view of loss, so it has ._base not None, which triggers assert error
# in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone()
# on loss[0] fixes this
local_num_tokens = loss[1].clone().detach().to(torch.int)
return (loss[0].clone(), local_num_tokens, {'lm loss': (reporting_loss[0], reporting_loss[1])})
return (
loss[0].clone(),
local_num_tokens,
{'lm loss': (reporting_loss[0], reporting_loss[1])},
)
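
For reference, the loss bookkeeping above packs the masked loss sum and the token count into a two-element vector before the context-parallel and data-parallel all-reduces. A minimal single-process sketch of that arithmetic, with made-up numbers and the all_reduce calls omitted:

import torch

losses = torch.tensor([[1.0, 2.0], [3.0, 4.0]])      # per-token losses
loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])   # mask out e.g. padding

loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

# After the (omitted) reductions, the reported per-token 'lm loss' is the
# ratio of the two accumulated entries.
print(loss[0].item() / loss[1].item())               # (1 + 2 + 3) / 3 = 2.0
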
def forward_step(data_iterator, model: GPTModel):
......@@ -280,26 +236,25 @@ def forward_step(data_iterator, model: GPTModel):
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
with stimer:
if args.use_legacy_models:
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else:
output_tensor = model(
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels, loss_mask=loss_mask)
# [ModelOpt]: model is needed to access ModelOpt distillation losses
return output_tensor, partial(loss_func, loss_mask, model=model)
return output_tensor, partial(loss_func, loss_mask)
def is_dataset_built_on_rank():
return (
parallel_state.is_pipeline_first_stage(ignore_virtual=True)
or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
) and parallel_state.get_tensor_model_parallel_rank() == 0
mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
) and mpu.get_tensor_model_parallel_rank() == 0
def core_gpt_dataset_config_from_args(args):
......@@ -324,8 +279,7 @@ def core_gpt_dataset_config_from_args(args):
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader,
object_storage_cache_path=args.object_storage_cache_path,
mid_level_dataset_surplus=args.mid_level_dataset_surplus,
s3_cache_path=args.s3_cache_path,
)
......@@ -347,7 +301,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
print_rank_0("> building train, validation, and test datasets for GPT ...")
train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config
dataset_type,
train_val_test_num_samples,
is_dataset_built_on_rank,
config
).build()
print_rank_0("> finished creating GPT datasets ...")
......@@ -366,5 +323,4 @@ if __name__ == "__main__":
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
)