"scripts/playground/load_tokenizer.py" did not exist on "a385ee27bd0025781eba61578889e470a1c027fb"
Commit c964fcca authored by dongcl

only transformer-engine>=2.4.0.dev supports split_bw

parent ab2a8334
__pycache__
*.bak
......@@ -7,7 +7,6 @@ def a2a_overlap_adaptation(patches_manager):
"""
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from ..core.transformer.transformer_block import TransformerBlock
from ..core.transformer.transformer_layer import TransformerLayer
from ..core.models.gpt.gpt_model import GPTModel
from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
......@@ -32,19 +31,18 @@ def a2a_overlap_adaptation(patches_manager):
patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
MoEAlltoAllTokenDispatcher)
patches_manager.register_patch('megatron.core.transformer.transformer_block.TransformerBlock',
TransformerBlock)
patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
TransformerLayer)
patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel',
GPTModel)
patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
GPTModel.build_schedule_plan,
create_dummy=True)
# backward_dw
# patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
# _get_extra_te_kwargs_wrapper,
# apply_wrapper=True)
if is_te_min_version("2.4.0.dev0"):
patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
_get_extra_te_kwargs_wrapper,
apply_wrapper=True)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
TELinear)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
......
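For context, the `is_te_min_version("2.4.0.dev0")` gate above decides whether the split_bw / delay_wgrad_compute patches are registered at all, matching the commit message. Below is a minimal stand-alone sketch of that kind of check; it is an illustration only, reading `transformer_engine.__version__` directly rather than using Megatron's helper, and the threshold string comes from this commit.

```python
# Illustration only: a version gate equivalent in spirit to is_te_min_version("2.4.0.dev0").
# Assumes transformer_engine exposes __version__; returns False when TE is not installed.
from packaging.version import Version

def te_supports_split_bw(min_version: str = "2.4.0.dev0") -> bool:
    try:
        import transformer_engine as te
    except ImportError:
        return False
    return Version(str(te.__version__)) >= Version(min_version)

if te_supports_split_bw():
    print("TE new enough: register delay_wgrad_compute / TELinear patches")
else:
    print("TE too old: skip the split_bw patches")
```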
......@@ -101,11 +101,11 @@ class CoreAdaptation(MegatronAdaptationABC):
from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
gpt_model_forward)
# MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
# gpt_model_init_wrapper,
# apply_wrapper=True)
# MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
# gpt_model_forward)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
......
......@@ -29,7 +29,8 @@ def _get_extra_te_kwargs_wrapper(fn):
@wraps(fn)
def wrapper(config: TransformerConfig):
extra_transformer_engine_kwargs = fn(config)
extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw if hasattr(config, "split_bw") else False
if hasattr(config, "split_bw"):
extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw
return extra_transformer_engine_kwargs
return wrapper
......
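A quick, self-contained demonstration of what the guarded wrapper above does once applied (via `apply_wrapper=True` in the patch registration). The config and base function here are stand-ins, not Megatron's `TransformerConfig` or `_get_extra_te_kwargs`.

```python
from functools import wraps

class FakeConfig:
    """Stand-in for TransformerConfig; split_bw would come from the training args."""
    split_bw = True

def fake_get_extra_te_kwargs(config):
    # Placeholder for the real _get_extra_te_kwargs return value.
    return {"params_dtype": "bf16"}

def get_extra_te_kwargs_wrapper(fn):
    @wraps(fn)
    def wrapper(config):
        extra = fn(config)
        # Only forward delay_wgrad_compute when the config actually defines split_bw,
        # mirroring the guarded assignment in this commit.
        if hasattr(config, "split_bw"):
            extra["delay_wgrad_compute"] = config.split_bw
        return extra
    return wrapper

patched = get_extra_te_kwargs_wrapper(fake_get_extra_te_kwargs)
print(patched(FakeConfig()))  # {'params_dtype': 'bf16', 'delay_wgrad_compute': True}
```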
......@@ -261,9 +261,6 @@ class TransformerLayerNode(ScheduleNode):
def backward_impl(self, outputs, output_grad):
detached_grad = tuple([e.grad for e in self.detached])
grads = output_grad + detached_grad
# if len(detached_grad):
# print(f"output_grad: {grads}")
self.default_backward_func(outputs + self.before_detached, grads)
self.before_detached = None
self.detached = None
......@@ -344,7 +341,7 @@ class MoeMlPNode(TransformerLayerNode):
)
assert mlp_bias is None
# pre_mlp_layernorm_output used
# pre_mlp_layernorm_output used
self.common_state.pre_mlp_layernorm_output = None
return expert_output, shared_expert_output
......
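The `backward_impl` above resumes autograd over saved outputs together with gradients pulled from `.grad` of tensors that were detached at a node boundary. A minimal PyTorch sketch of that splice follows; toy tensors only, and the names do not refer to the classes in this diff.

```python
import torch

x = torch.randn(4, requires_grad=True)
a = x * 2                              # intermediate handed across the node boundary
d = a.detach().requires_grad_(True)    # detached copy consumed by the next node
(d * 3).sum().backward()               # next node's backward populates d.grad

y = a.sum()                            # this node's own output
# Resume this node's backward, splicing d.grad back in at the detach point,
# mirroring default_backward_func(outputs + before_detached, grads) above.
torch.autograd.backward((y, a), (torch.ones(()), d.grad))
print(x.grad)                          # tensor([8., 8., 8., 8.]) = 2*1 + 2*3 per element
```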
......@@ -11,7 +11,6 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
def gpt_model_init_wrapper(fn):
......@@ -232,17 +231,10 @@ def gpt_model_forward(
return loss
class GPTModel(MegatronCoreGPTModel):
class GPTModel:
"""
patch megatron GPTModel
"""
def get_transformer_callables_by_layer(self, layer_number: int):
"""
Get the callables for the layer at the given transformer layer number.
"""
return self.decoder.get_layer_callables(layer_number)
def build_schedule_plan(
self,
input_ids: Tensor,
......
......@@ -37,7 +37,8 @@ def set_current_microbatch(model, microbatch_id):
except RuntimeError:
decoder_exists = False
if decoder_exists and decoder is not None:
decoder.current_microbatch = microbatch_id
for layer in decoder.layers:
layer.current_microbatch = microbatch_id
def get_pp_rank_microbatches(
......@@ -87,6 +88,16 @@ def get_pp_rank_microbatches(
)
def print_rank_4(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 4:
print(message, flush=True)
else:
print(message, flush=True)
from megatron.training import print_rank_0, print_rank_last
def forward_backward_pipelining_with_interleaving(
*,
forward_step_func,
......@@ -297,6 +308,9 @@ def forward_backward_pipelining_with_interleaving(
# Both tables are indexed with virtual_microbatch_id.
microbatch_id_table, model_chunk_id_table = zip(*schedule_table)
print_rank_4(f"rank last. microbatch_id_table: {microbatch_id_table}. model_chunk_id_table: {model_chunk_id_table}")
print_rank_0(f"rank first. microbatch_id_table: {microbatch_id_table}. model_chunk_id_table: {model_chunk_id_table}")
def get_model_chunk_id(virtual_microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
model_chunk_id = model_chunk_id_table[virtual_microbatch_id % total_num_microbatches]
......@@ -687,6 +701,7 @@ def forward_backward_pipelining_with_interleaving(
post_backward=post_backward,
)
else:
output_tensor = None
input_tensor_grad = None
if f_virtual_microbatch_id is not None:
# forward pass
......@@ -711,7 +726,7 @@ def forward_backward_pipelining_with_interleaving(
input_tensor_grad = backward_step_helper(b_virtual_microbatch_id)
if post_backward is not None:
input_tensor_grad = post_backward(input_tensor_grad)
return output_tensor if f_virtual_microbatch_id is not None else None, input_tensor_grad
return output_tensor, input_tensor_grad
# Run warmup forward passes.
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
......@@ -897,11 +912,13 @@ def forward_backward_pipelining_with_interleaving(
output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])
# Run 1F1B in steady state.
output_tensor = None
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
print_rank_0(f"rank first. 1F1B in steady state: {k}/{num_microbatches_remaining}")
print_rank_4(f"rank last. 1F1B in steady state: {k}/{num_microbatches_remaining}")
# Decide to checkpoint all layers' activations of the current micro-batch.
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
......@@ -1053,6 +1070,9 @@ def forward_backward_pipelining_with_interleaving(
post_backward=pp_post_backward,
checkpoint_activations_microbatch=checkpoint_activations_microbatch,
)
print_rank_0(f"rank first. 1F1B in steady state: {k}/{num_microbatches_remaining} end")
print_rank_4(f"rank last. 1F1B in steady state: {k}/{num_microbatches_remaining} end")
else: # No p2p overlap.
backward_k = k
output_tensor, input_tensor_grad = forward_backward_helper_wrapper(
......@@ -1109,6 +1129,11 @@ def forward_backward_pipelining_with_interleaving(
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
if k == 0:
print_rank_0(f"input_tensor_grad: {input_tensor_grad}")
print_rank_0(f"rank first. 1F1B in steady state end")
print_rank_4(f"rank last. 1F1B in steady state end")
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Run cooldown backward passes (flush out pipeline).
......
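For reference, the schedule-table prints added above dump the two tables produced by `zip(*schedule_table)`. A toy illustration of that unzip; the ordering below is invented for readability, while the real table comes from Megatron's interleaved scheduler.

```python
# Toy (microbatch_id, model_chunk_id) pairs; not the real interleaved ordering.
schedule_table = [(0, 0), (1, 0), (0, 1), (1, 1), (2, 0), (3, 0), (2, 1), (3, 1)]
microbatch_id_table, model_chunk_id_table = zip(*schedule_table)
print(microbatch_id_table)   # (0, 1, 0, 1, 2, 3, 2, 3)
print(model_chunk_id_table)  # (0, 0, 1, 1, 0, 0, 1, 1)
```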
from functools import wraps
from megatron.core.transformer.transformer_block import TransformerBlock as MegatronCoreTransformerBlock
def transformer_block_init_wrapper(fn):
@wraps(fn)
......@@ -14,12 +13,3 @@ def transformer_block_init_wrapper(fn):
self.final_layernorm = None
return wrapper
class TransformerBlock(MegatronCoreTransformerBlock):
def get_layer_callables(self, layer_number: int):
"""
Get the callables for the layer at the given layer number.
"""
return self.layers[layer_number].get_submodule_callables()
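The `transformer_block_init_wrapper` above follows a common wrap-the-constructor patching pattern: run the original `__init__`, then adjust the instance. A self-contained sketch of that pattern with a toy class (not Megatron's `TransformerBlock`):

```python
from functools import wraps

class ToyBlock:
    def __init__(self):
        self.final_layernorm = "LayerNorm"

def init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        # Post-process the instance after the original constructor runs, as the
        # real wrapper does when it sets self.final_layernorm = None.
        self.final_layernorm = None
    return wrapper

ToyBlock.__init__ = init_wrapper(ToyBlock.__init__)
assert ToyBlock().final_layernorm is None
```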
......@@ -45,6 +45,65 @@ class TransformerLayer(MegatronCoreTransformerLayer):
torch.cuda.nvtx.range_pop()
return outputs, detached_output_tensors
def forward(
self,
hidden_states: Tensor,
context: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
rotary_pos_emb: Optional[Tensor] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
inference_context: Optional[Any] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[Tensor] = None,
*,
inference_params: Optional[Any] = None,
):
(
hidden_states,
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
) = self._submodule_attention_router_compound_forward(
hidden_states,
attention_mask,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
attention_bias,
inference_context,
packed_seq_params,
sequence_len_offset,
inference_params=inference_params,
)
(tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs
)
(expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
tokens_per_expert,
global_input_tokens,
global_probs,
pre_mlp_layernorm_output
)
expert_output = self._submodule_combine_forward(expert_output)[0]
output = self._submodule_post_combine_forward(
expert_output,
shared_expert_output,
mlp_bias,
hidden_states
)
return output, None
def _submodule_attention_forward(
self,
hidden_states: Tensor,
......@@ -182,14 +241,14 @@ class TransformerLayer(MegatronCoreTransformerLayer):
return output
def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_prob, pre_mlp_layernorm_output):
def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output):
"""
Performs a forward pass for the MLP submodule, including both expert-based
and optional shared-expert computations.
"""
shared_expert_output = None
(dispatched_input, tokens_per_expert, permuted_probs) = (
self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_prob)
self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
)
expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
......@@ -221,141 +280,10 @@ class TransformerLayer(MegatronCoreTransformerLayer):
return output
def _submodule_attention_backward(
self, hidden_states, pre_mlp_layernorm_output, detached_inputs
):
pre_mlp_layernorm_output.backward(detached_inputs[1].grad)
hidden_states.backward(detached_inputs[0].grad)
def _submodule_attention_router_compound_backward(
self,
hidden_states,
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
probs,
detached_inputs,
):
permutated_local_input_tokens.backward(detached_inputs[3].grad)
probs.backward(detached_inputs[4].grad)
# tokens_per_expert.backward(detached_inputs[2].grad)
pre_mlp_layernorm_output.backward(detached_inputs[1].grad)
hidden_states.backward(detached_inputs[0].grad)
def _submodule_dispatch_backward(self, global_input_tokens, detached_inputs):
global_input_tokens.backward(detached_inputs[0].grad)
def _submodule_dense_backward(self, output, detached_inputs):
output.backward(detached_inputs[0].grad)
def _submodule_moe_backward(
self, expert_output, shared_expert_output, mlp_bias, detached_inputs
):
expert_output.backward(detached_inputs[0].grad)
shared_expert_output.backward(detached_inputs[1].grad)
if mlp_bias is not None:
mlp_bias.backward(detached_inputs[2].grad)
def _submodule_combine_backward(self, hidden_states, detached_inputs):
hidden_states.backward(detached_inputs[0].grad)
def _submodule_post_combine_backward(self, output, output_grad):
output.backward(output_grad)
def _submodule_attention_router_compound_dgrad(self):
raise NotImplementedError("Not implemented")
def _submodule_attention_router_compound_dw(self):
self.self_attention.backward_dw()
# raise NotImplementedError("Not implemented")
def _submodule_dispatch_dgrad(self):
raise NotImplementedError("Not implemented")
def _submodule_mlp_dgrad(self):
raise NotImplementedError("Not implemented")
def _submodule_mlp_dw(self):
self.mlp.backward_dw()
# raise NotImplementedError("Not implemented")
def _submodule_combine_dgrad(self):
raise NotImplementedError("Not implemented")
def _submodule_identity_forward(self, *args):
return args
def _submodule_identity_backward(self, *args):
pass
def get_submodule_callables(self):
"""
Returns a dictionary of submodule callables for the transformer layer.
"""
from megatron.core.transformer.moe.moe_layer import MoELayer
def get_func_with_default(func, default_func):
if isinstance(self.mlp, MoELayer):
return func
return default_func
attention_func = get_func_with_default(
self._submodule_attention_router_compound_forward, self._submodule_attention_forward
)
attention_backward_func = get_func_with_default(
self._submodule_attention_router_compound_backward, self._submodule_attention_backward
)
dispatch_func = get_func_with_default(
self._submodule_dispatch_forward, self._submodule_identity_forward
)
dispatch_backward_func = get_func_with_default(
self._submodule_dispatch_backward, self._submodule_identity_backward
)
mlp_func = get_func_with_default(self._submodule_moe_forward, self._submodule_dense_forward)
mlp_backward_func = get_func_with_default(
self._submodule_moe_backward, self._submodule_dense_backward
)
combine_func = get_func_with_default(
self._submodule_combine_forward, self._submodule_identity_forward
)
combine_backward_func = get_func_with_default(
self._submodule_combine_backward, self._submodule_identity_backward
)
post_combine_func = get_func_with_default(
self._submodule_post_combine_forward, self._submodule_identity_forward
)
post_combine_backward_func = get_func_with_default(
self._submodule_post_combine_backward, self._submodule_identity_backward
)
callables = TransformerLayerSubmoduleCallables(
attention=SubmoduleCallables(
forward=partial(self._callable_wrapper, True, attention_func, skip_detach=True),
backward=partial(self._callable_wrapper, False, attention_backward_func),
# dgrad=partial(self._callable_wrapper, False,self._submodule_attention_router_compound_dgrad),
dw=partial(
self._callable_wrapper, False, self._submodule_attention_router_compound_dw
),
),
dispatch=SubmoduleCallables(
forward=partial(self._callable_wrapper, True, dispatch_func),
backward=partial(self._callable_wrapper, False, dispatch_backward_func),
# dgrad=partial(self._callable_wrapper, False, self._submodule_dispatch_dgrad),
),
mlp=SubmoduleCallables(
forward=partial(self._callable_wrapper, True, mlp_func),
backward=partial(self._callable_wrapper, False, mlp_backward_func),
# dgrad=partial(self._callable_wrapper, False, self._submodule_mlp_dgrad),
dw=partial(self._callable_wrapper, False, self._submodule_mlp_dw),
),
combine=SubmoduleCallables(
forward=partial(self._callable_wrapper, True, combine_func),
backward=partial(self._callable_wrapper, False, combine_backward_func),
# dgrad=partial(self._callable_wrapper, False, self._submodule_combine_dgrad),
),
post_combine=SubmoduleCallables(
forward=partial(self._callable_wrapper, True, post_combine_func),
backward=partial(self._callable_wrapper, False, post_combine_backward_func),
),
)
return callables
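The `get_func_with_default` helper above selects the MoE-specific callable when the layer's MLP is a `MoELayer`, and otherwise falls back to the dense/identity variant. A toy version of that selection rule, with stand-in classes rather than Megatron's:

```python
class MoELayer: ...
class DenseMLP: ...

class ToyLayer:
    def __init__(self, mlp):
        self.mlp = mlp

    def _moe_forward(self, x):
        return f"moe({x})"

    def _dense_forward(self, x):
        return f"dense({x})"

    def get_func_with_default(self, func, default_func):
        # Same selection rule as in get_submodule_callables above.
        return func if isinstance(self.mlp, MoELayer) else default_func

for mlp in (MoELayer(), DenseMLP()):
    layer = ToyLayer(mlp)
    mlp_func = layer.get_func_with_default(layer._moe_forward, layer._dense_forward)
    print(mlp_func("tokens"))   # moe(tokens), then dense(tokens)
```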