Commit c964fcca authored by dongcl

only transformer-engine>=2.4.0.dev supports split_bw

parent ab2a8334
__pycache__
*.bak
@@ -7,7 +7,6 @@ def a2a_overlap_adaptation(patches_manager):
     """
     from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
     from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
-    from ..core.transformer.transformer_block import TransformerBlock
     from ..core.transformer.transformer_layer import TransformerLayer
     from ..core.models.gpt.gpt_model import GPTModel
     from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
@@ -32,19 +31,18 @@ def a2a_overlap_adaptation(patches_manager):
     patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
                                    MoEAlltoAllTokenDispatcher)
-    patches_manager.register_patch('megatron.core.transformer.transformer_block.TransformerBlock',
-                                   TransformerBlock)
     patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
                                    TransformerLayer)
-    patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel',
-                                   GPTModel)
+    patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
+                                   GPTModel.build_schedule_plan,
+                                   create_dummy=True)

     # backward_dw
-    # patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
-    #                                _get_extra_te_kwargs_wrapper,
-    #                                apply_wrapper=True)
+    if is_te_min_version("2.4.0.dev0"):
+        patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
+                                       _get_extra_te_kwargs_wrapper,
+                                       apply_wrapper=True)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
                                    TELinear)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
...
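The gate introduced above is the heart of the commit: delay_wgrad_compute (the knob behind split_bw) only exists in transformer-engine >= 2.4.0.dev0, so the kwargs wrapper must not be registered on older installs. A minimal standalone sketch of the same pattern, assuming is_te_min_version lives in megatron.core.utils and reusing the MindSpeed-style patches_manager call shown in the hunk:

# Sketch only: mirrors the version gate in the hunk above. Assumes
# megatron.core.utils.is_te_min_version and a MindSpeed-style patches_manager
# exposing register_patch(); adjust names to the local patch framework.
from megatron.core.utils import is_te_min_version


def register_split_bw_patches(patches_manager, te_kwargs_wrapper):
    # transformer-engine < 2.4.0.dev0 does not understand delay_wgrad_compute,
    # so the kwargs wrapper is only patched in when the version check passes.
    if is_te_min_version("2.4.0.dev0"):
        patches_manager.register_patch(
            'megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
            te_kwargs_wrapper,
            apply_wrapper=True,
        )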
@@ -101,11 +101,11 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward

         # GPT Model
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
-                                    gpt_model_init_wrapper,
-                                    apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
-                                    gpt_model_forward)
+        # MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
+        #                             gpt_model_init_wrapper,
+        #                             apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
+        #                             gpt_model_forward)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
...
@@ -29,7 +29,8 @@ def _get_extra_te_kwargs_wrapper(fn):
     @wraps(fn)
     def wrapper(config: TransformerConfig):
         extra_transformer_engine_kwargs = fn(config)
-        extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw if hasattr(config, "split_bw") else False
+        if hasattr(config, "split_bw"):
+            extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw
         return extra_transformer_engine_kwargs
     return wrapper
...
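A short usage sketch of the rewritten wrapper above: delay_wgrad_compute is now only injected when the config actually defines split_bw, instead of always being set (False by default), so older TE builds never receive an unknown kwarg. The stand-in base function and SimpleNamespace configs below are illustrative, not Megatron objects:

# Sketch only: re-implements the wrapper from the hunk above against a
# stand-in base function, to show when delay_wgrad_compute appears.
from functools import wraps
from types import SimpleNamespace


def _get_extra_te_kwargs_wrapper(fn):
    @wraps(fn)
    def wrapper(config):
        extra_transformer_engine_kwargs = fn(config)
        # Only forward the flag when the config defines it; the kwarg is
        # omitted entirely otherwise, rather than defaulting to False.
        if hasattr(config, "split_bw"):
            extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw
        return extra_transformer_engine_kwargs
    return wrapper


def fake_get_extra_te_kwargs(config):
    # Stand-in for megatron.core.extensions.transformer_engine._get_extra_te_kwargs.
    return {"params_dtype": config.params_dtype}


wrapped = _get_extra_te_kwargs_wrapper(fake_get_extra_te_kwargs)
print(wrapped(SimpleNamespace(params_dtype="bf16", split_bw=True)))
# {'params_dtype': 'bf16', 'delay_wgrad_compute': True}
print(wrapped(SimpleNamespace(params_dtype="bf16")))
# {'params_dtype': 'bf16'}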
@@ -261,9 +261,6 @@ class TransformerLayerNode(ScheduleNode):
     def backward_impl(self, outputs, output_grad):
         detached_grad = tuple([e.grad for e in self.detached])
         grads = output_grad + detached_grad
-        # if len(detached_grad):
-        #     print(f"output_grad: {grads}")
         self.default_backward_func(outputs + self.before_detached, grads)
         self.before_detached = None
         self.detached = None
...
@@ -11,7 +11,6 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.utils import WrappedTensor, deprecate_inference_params
-from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel


 def gpt_model_init_wrapper(fn):
@@ -232,17 +231,10 @@ def gpt_model_forward(
     return loss


-class GPTModel(MegatronCoreGPTModel):
+class GPTModel:
    """
    patch megatron GPTModel
    """
-
-    def get_transformer_callables_by_layer(self, layer_number: int):
-        """
-        Get the callables for the layer at the given transformer layer number.
-        """
-        return self.decoder.get_layer_callables(layer_number)
-
    def build_schedule_plan(
        self,
        input_ids: Tensor,
...
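Together with the registration change in the first file, GPTModel here is reduced from a subclass of the core model to a plain holder whose build_schedule_plan is grafted onto the upstream class at patch time. A minimal sketch of that method-attachment idea, with attach_method and the toy class names standing in for register_patch(..., create_dummy=True):

# Sketch only: CoreGPTModel / GPTModelPatch / attach_method are illustrative
# stand-ins; the real mechanism is patches_manager.register_patch(
#     'megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
#     GPTModel.build_schedule_plan, create_dummy=True).
class CoreGPTModel:
    def forward(self, input_ids):
        return f"forward({input_ids})"


class GPTModelPatch:
    # Plain class: no inheritance from the core model is needed any more.
    def build_schedule_plan(self, input_ids):
        # At call time, self is the real model instance the method was attached to.
        return f"schedule plan for {type(self).__name__}: {input_ids}"


def attach_method(target_cls, name, func):
    # Attach the plain function so it binds like a normal method on target_cls.
    setattr(target_cls, name, func)


attach_method(CoreGPTModel, "build_schedule_plan", GPTModelPatch.build_schedule_plan)
print(CoreGPTModel().build_schedule_plan("tokens"))
# schedule plan for CoreGPTModel: tokens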
@@ -37,7 +37,8 @@ def set_current_microbatch(model, microbatch_id):
     except RuntimeError:
         decoder_exists = False
     if decoder_exists and decoder is not None:
-        decoder.current_microbatch = microbatch_id
+        for layer in decoder.layers:
+            layer.current_microbatch = microbatch_id


 def get_pp_rank_microbatches(
@@ -87,6 +88,16 @@ def get_pp_rank_microbatches(
     )

+def print_rank_4(message):
+    """If distributed is initialized, print only on rank 4."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 4:
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
+
+
+from megatron.training import print_rank_0, print_rank_last
+
 def forward_backward_pipelining_with_interleaving(
     *,
     forward_step_func,
@@ -297,6 +308,9 @@ def forward_backward_pipelining_with_interleaving(
     # Both tables are indexed with virtual_microbatch_id.
     microbatch_id_table, model_chunk_id_table = zip(*schedule_table)
+    print_rank_4(f"rank last. microbatch_id_table: {microbatch_id_table}. model_chunk_id_table: {model_chunk_id_table}")
+    print_rank_0(f"rank first. microbatch_id_table: {microbatch_id_table}. model_chunk_id_table: {model_chunk_id_table}")

     def get_model_chunk_id(virtual_microbatch_id, forward):
         """Helper method to get the model chunk ID given the iteration number."""
         model_chunk_id = model_chunk_id_table[virtual_microbatch_id % total_num_microbatches]
@@ -687,6 +701,7 @@ def forward_backward_pipelining_with_interleaving(
                 post_backward=post_backward,
             )
         else:
+            output_tensor = None
             input_tensor_grad = None
             if f_virtual_microbatch_id is not None:
                 # forward pass
@@ -711,7 +726,7 @@ def forward_backward_pipelining_with_interleaving(
                 input_tensor_grad = backward_step_helper(b_virtual_microbatch_id)
                 if post_backward is not None:
                     input_tensor_grad = post_backward(input_tensor_grad)
-        return output_tensor if f_virtual_microbatch_id is not None else None, input_tensor_grad
+        return output_tensor, input_tensor_grad

     # Run warmup forward passes.
     parallel_state.set_virtual_pipeline_model_parallel_rank(0)
@@ -897,11 +912,13 @@ def forward_backward_pipelining_with_interleaving(
             output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])

     # Run 1F1B in steady state.
+    output_tensor = None
     for k in range(num_microbatches_remaining):
         # Forward pass.
         forward_k = k + num_warmup_microbatches
+        print_rank_0(f"rank first. 1F1B in steady state: {k}/{num_microbatches_remaining}")
+        print_rank_4(f"rank last. 1F1B in steady state: {k}/{num_microbatches_remaining}")

         # Decide to checkpoint all layers' activations of the current micro-batch.
         if max_outstanding_backprops is not None:
             checkpoint_activations_microbatch = (
@@ -1053,6 +1070,9 @@ def forward_backward_pipelining_with_interleaving(
                 post_backward=pp_post_backward,
                 checkpoint_activations_microbatch=checkpoint_activations_microbatch,
             )
+            print_rank_0(f"rank first. 1F1B in steady state: {k}/{num_microbatches_remaining} end")
+            print_rank_4(f"rank last. 1F1B in steady state: {k}/{num_microbatches_remaining} end")
         else:  # No p2p overlap.
             backward_k = k
             output_tensor, input_tensor_grad = forward_backward_helper_wrapper(
@@ -1109,6 +1129,11 @@ def forward_backward_pipelining_with_interleaving(
             if recv_next:
                 output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
+        if k == 0:
+            print_rank_0(f"input_tensor_grad: {input_tensor_grad}")
+        print_rank_0(f"rank first. 1F1B in steady state end")
+        print_rank_4(f"rank last. 1F1B in steady state end")
+
         deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

     # Run cooldown backward passes (flush out pipeline).
...
 from functools import wraps
-from megatron.core.transformer.transformer_block import TransformerBlock as MegatronCoreTransformerBlock


 def transformer_block_init_wrapper(fn):
     @wraps(fn)
@@ -14,12 +13,3 @@ def transformer_block_init_wrapper(fn):
             self.final_layernorm = None

     return wrapper
-
-
-class TransformerBlock(MegatronCoreTransformerBlock):
-    def get_layer_callables(self, layer_number: int):
-        """
-        Get the callables for the layer at the given layer number.
-        """
-        return self.layers[layer_number].get_submodule_callables()
@@ -45,6 +45,65 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         torch.cuda.nvtx.range_pop()
         return outputs, detached_output_tensors

+    def forward(
+        self,
+        hidden_states: Tensor,
+        context: Optional[Tensor] = None,
+        context_mask: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        rotary_pos_emb: Optional[Tensor] = None,
+        rotary_pos_cos: Optional[Tensor] = None,
+        rotary_pos_sin: Optional[Tensor] = None,
+        attention_bias: Optional[Tensor] = None,
+        inference_context: Optional[Any] = None,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+        sequence_len_offset: Optional[Tensor] = None,
+        *,
+        inference_params: Optional[Any] = None,
+    ):
+        (
+            hidden_states,
+            pre_mlp_layernorm_output,
+            tokens_per_expert,
+            permutated_local_input_tokens,
+            permuted_probs,
+        ) = self._submodule_attention_router_compound_forward(
+            hidden_states,
+            attention_mask,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
+            attention_bias,
+            inference_context,
+            packed_seq_params,
+            sequence_len_offset,
+            inference_params=inference_params,
+        )
+        (tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
+            tokens_per_expert,
+            permutated_local_input_tokens,
+            permuted_probs
+        )
+        (expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
+            tokens_per_expert,
+            global_input_tokens,
+            global_probs,
+            pre_mlp_layernorm_output
+        )
+        expert_output = self._submodule_combine_forward(expert_output)[0]
+        output = self._submodule_post_combine_forward(
+            expert_output,
+            shared_expert_output,
+            mlp_bias,
+            hidden_states
+        )
+        return output, None
+
     def _submodule_attention_forward(
         self,
         hidden_states: Tensor,
@@ -182,14 +241,14 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         return output

-    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_prob, pre_mlp_layernorm_output):
+    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output):
         """
         Performs a forward pass for the MLP submodule, including both expert-based
         and optional shared-expert computations.
         """
         shared_expert_output = None
         (dispatched_input, tokens_per_expert, permuted_probs) = (
-            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_prob)
+            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
         )
         expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
         expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
@@ -221,141 +280,10 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         return output

-    def _submodule_attention_backward(
-        self, hidden_states, pre_mlp_layernorm_output, detached_inputs
-    ):
-        pre_mlp_layernorm_output.backward(detached_inputs[1].grad)
-        hidden_states.backward(detached_inputs[0].grad)
-
-    def _submodule_attention_router_compound_backward(
-        self,
-        hidden_states,
-        pre_mlp_layernorm_output,
-        tokens_per_expert,
-        permutated_local_input_tokens,
-        probs,
-        detached_inputs,
-    ):
-        permutated_local_input_tokens.backward(detached_inputs[3].grad)
-        probs.backward(detached_inputs[4].grad)
-        # tokens_per_expert.backward(detached_inputs[2].grad)
-        pre_mlp_layernorm_output.backward(detached_inputs[1].grad)
-        hidden_states.backward(detached_inputs[0].grad)
-
-    def _submodule_dispatch_backward(self, global_input_tokens, detached_inputs):
-        global_input_tokens.backward(detached_inputs[0].grad)
-
-    def _submodule_dense_backward(self, output, detached_inputs):
-        output.backward(detached_inputs[0].grad)
-
-    def _submodule_moe_backward(
-        self, expert_output, shared_expert_output, mlp_bias, detached_inputs
-    ):
-        expert_output.backward(detached_inputs[0].grad)
-        shared_expert_output.backward(detached_inputs[1].grad)
-        if mlp_bias is not None:
-            mlp_bias.backward(detached_inputs[2].grad)
-
-    def _submodule_combine_backward(self, hidden_states, detached_inputs):
-        hidden_states.backward(detached_inputs[0].grad)
-
-    def _submodule_post_combine_backward(self, output, output_grad):
-        output.backward(output_grad)
-
-    def _submodule_attention_router_compound_dgrad(self):
-        raise NotImplementedError("Not implemented")
-
     def _submodule_attention_router_compound_dw(self):
         self.self_attention.backward_dw()
         # raise NotImplementedError("Not implemented")

-    def _submodule_dispatch_dgrad(self):
-        raise NotImplementedError("Not implemented")
-
-    def _submodule_mlp_dgrad(self):
-        raise NotImplementedError("Not implemented")
-
     def _submodule_mlp_dw(self):
         self.mlp.backward_dw()
         # raise NotImplementedError("Not implemented")

-    def _submodule_combine_dgrad(self):
-        raise NotImplementedError("Not implemented")
-
-    def _submodule_identity_forward(self, *args):
-        return args
-
-    def _submodule_identity_backward(self, *args):
-        pass
-
-    def get_submodule_callables(self):
-        """
-        Returns a dictionary of submodule callables for the transformer layer.
-        """
-        from megatron.core.transformer.moe.moe_layer import MoELayer
-
-        def get_func_with_default(func, default_func):
-            if isinstance(self.mlp, MoELayer):
-                return func
-            return default_func
-
-        attention_func = get_func_with_default(
-            self._submodule_attention_router_compound_forward, self._submodule_attention_forward
-        )
-        attention_backward_func = get_func_with_default(
-            self._submodule_attention_router_compound_backward, self._submodule_attention_backward
-        )
-        dispatch_func = get_func_with_default(
-            self._submodule_dispatch_forward, self._submodule_identity_forward
-        )
-        dispatch_backward_func = get_func_with_default(
-            self._submodule_dispatch_backward, self._submodule_identity_backward
-        )
-        mlp_func = get_func_with_default(self._submodule_moe_forward, self._submodule_dense_forward)
-        mlp_backward_func = get_func_with_default(
-            self._submodule_moe_backward, self._submodule_dense_backward
-        )
-        combine_func = get_func_with_default(
-            self._submodule_combine_forward, self._submodule_identity_forward
-        )
-        combine_backward_func = get_func_with_default(
-            self._submodule_combine_backward, self._submodule_identity_backward
-        )
-        post_combine_func = get_func_with_default(
-            self._submodule_post_combine_forward, self._submodule_identity_forward
-        )
-        post_combine_backward_func = get_func_with_default(
-            self._submodule_post_combine_backward, self._submodule_identity_backward
-        )
-
-        callables = TransformerLayerSubmoduleCallables(
-            attention=SubmoduleCallables(
-                forward=partial(self._callable_wrapper, True, attention_func, skip_detach=True),
-                backward=partial(self._callable_wrapper, False, attention_backward_func),
-                # dgrad=partial(self._callable_wrapper, False, self._submodule_attention_router_compound_dgrad),
-                dw=partial(
-                    self._callable_wrapper, False, self._submodule_attention_router_compound_dw
-                ),
-            ),
-            dispatch=SubmoduleCallables(
-                forward=partial(self._callable_wrapper, True, dispatch_func),
-                backward=partial(self._callable_wrapper, False, dispatch_backward_func),
-                # dgrad=partial(self._callable_wrapper, False, self._submodule_dispatch_dgrad),
-            ),
-            mlp=SubmoduleCallables(
-                forward=partial(self._callable_wrapper, True, mlp_func),
-                backward=partial(self._callable_wrapper, False, mlp_backward_func),
-                # dgrad=partial(self._callable_wrapper, False, self._submodule_mlp_dgrad),
-                dw=partial(self._callable_wrapper, False, self._submodule_mlp_dw),
-            ),
-            combine=SubmoduleCallables(
-                forward=partial(self._callable_wrapper, True, combine_func),
-                backward=partial(self._callable_wrapper, False, combine_backward_func),
-                # dgrad=partial(self._callable_wrapper, False, self._submodule_combine_dgrad),
-            ),
-            post_combine=SubmoduleCallables(
-                forward=partial(self._callable_wrapper, True, post_combine_func),
-                backward=partial(self._callable_wrapper, False, post_combine_backward_func),
-            ),
-        )
-        return callables