Commit 56819e16 authored by dongcl

Merge branch 'a2a_overlap', support 1f1b overlap

parent 1e8185f4
@@ -39,10 +39,9 @@ def a2a_overlap_adaptation(patches_manager):
                                     create_dummy=True)
     # backward_dw
-    if is_te_min_version("2.4.0.dev0"):
-        patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
-                                       _get_extra_te_kwargs_wrapper,
-                                       apply_wrapper=True)
+    patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
+                                   _get_extra_te_kwargs_wrapper,
+                                   apply_wrapper=True)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
                                    TELinear)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
...
 import os
+import copy
 import torch
 import dataclasses
 import transformer_engine as te
@@ -7,6 +8,7 @@ from functools import wraps
 from typing import Any, Optional, Callable
 from packaging.version import Version as PkgVersion

+from megatron.training import get_args
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.tensor_parallel import get_cuda_rng_tracker
 from megatron.core.utils import get_te_version, is_te_min_version
@@ -25,15 +27,17 @@ from megatron.core.parallel_state import (
 )

-def _get_extra_te_kwargs_wrapper(fn):
-    @wraps(fn)
+def _get_extra_te_kwargs_wrapper(_get_extra_te_kwargs_func):
+    @wraps(_get_extra_te_kwargs_func)
     def wrapper(config: TransformerConfig):
-        extra_transformer_engine_kwargs = fn(config)
+        extra_transformer_engine_kwargs = _get_extra_te_kwargs_func(config)
         if hasattr(config, "split_bw"):
             extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw
         return extra_transformer_engine_kwargs
-    return wrapper
+
+    if is_te_min_version("2.3.0.dev0"):
+        return wrapper
+    return _get_extra_te_kwargs_func


 class TELinear(MegatronCoreTELinear):
@@ -66,8 +70,14 @@ class TELinear(MegatronCoreTELinear):
         tp_comm_buffer_name: Optional[str] = None,
         is_expert: bool = False,
     ):
-        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-        assert not self.split_bw, "split_bw is currently not supported"
+        args = get_args()
+        self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+        if not is_te_min_version("2.3.0.dev0"):
+            assert not self.split_bw, "split_bw is currently not supported"
+        if self.split_bw:
+            config = copy.copy(config)
+            config.split_bw = True
+
         super().__init__(
             input_size,
@@ -86,6 +96,8 @@ class TELinear(MegatronCoreTELinear):
         if not self.split_bw:
             return
+
+        return super(MegatronCoreTELinear, self).backward_dw()


 class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
     """
@@ -107,8 +119,14 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
         skip_weight_param_allocation: bool = False,
         tp_comm_buffer_name: Optional[str] = None,
     ):
-        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-        assert not self.split_bw, "split_bw is currently not supported"
+        args = get_args()
+        self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+        if not is_te_min_version("2.3.0.dev0"):
+            assert not self.split_bw, "split_bw is currently not supported"
+        if self.split_bw:
+            config = copy.copy(config)
+            config.split_bw = True
+
         super().__init__(
             input_size,
@@ -127,6 +145,8 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
         if not self.split_bw:
             return
+
+        return super(MegatronCoreTELayerNormColumnParallelLinear, self).backward_dw()


 class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
     def __init__(
@@ -289,8 +309,14 @@ if is_te_min_version("1.9.0.dev0"):
             is_expert: bool = False,
             tp_comm_buffer_name: Optional[str] = None,
         ):
-            self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-            assert not self.split_bw, "split_bw is currently not supported"
+            args = get_args()
+            self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+            if not is_te_min_version("2.3.0.dev0"):
+                assert not self.split_bw, "split_bw is currently not supported"
+            if self.split_bw:
+                config = copy.copy(config)
+                config.split_bw = True
+
             super().__init__(
                 num_gemms,
@@ -308,3 +334,5 @@ if is_te_min_version("1.9.0.dev0"):
         def backward_dw(self):
             if not self.split_bw:
                 return
+
+            return super(MegatronCoreTEGroupedLinear, self).backward_dw()
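The changes above hinge on Transformer Engine's delay_wgrad_compute option (surfaced here as split_bw and gated on TE 2.3.0.dev0): the linear layers compute only the data gradient during backward and defer the weight gradient until backward_dw() is called, so dgrad and wgrad can be scheduled separately. As a rough, self-contained illustration of that split only (a toy autograd sketch, not the TE or dcu_megatron implementation; SplitBwLinear and backward_dw below are made-up names):

import torch

class SplitBwLinear(torch.autograd.Function):
    """Toy linear op: returns dgrad immediately, defers wgrad to backward_dw()."""

    _pending_wgrad = []  # deferred weight-gradient closures

    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        dgrad = grad_out @ weight  # needed right away by the previous layer

        def _wgrad():
            # weight gradient can run later, e.g. overlapped with communication
            wg = grad_out.t() @ inp
            weight.grad = wg if weight.grad is None else weight.grad + wg

        SplitBwLinear._pending_wgrad.append(_wgrad)
        return dgrad, None  # weight grad is produced later by backward_dw()

def backward_dw():
    """Run all deferred weight-gradient computations."""
    for fn in SplitBwLinear._pending_wgrad:
        fn()
    SplitBwLinear._pending_wgrad.clear()

A caller would run loss.backward() first (producing only dgrads along the way) and invoke backward_dw() at a point where the extra work overlaps well, which is what the 1F1B schedule does with b_layer.mlp.dw().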
@@ -540,6 +540,8 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
     def state(self):
         return self._model_chunk_state

+# F_DISPATCH_B_MLP_SYNC_EVENT = torch.cuda.Event()
+F_DISPATCH_B_MLP_SYNC_EVENT = None

 def schedule_layer_1f1b(
     f_layer,
@@ -579,13 +581,17 @@ def schedule_layer_1f1b(
         with f_context:
             f_input = f_layer.attn.forward(f_input)

+    f_dispatch_b_mlp_sync_event = None
+    if f_layer is not None and b_layer is not None:
+        f_dispatch_b_mlp_sync_event = F_DISPATCH_B_MLP_SYNC_EVENT
+
     if f_layer is not None:
         with f_context:
-            f_input = f_layer.dispatch.forward(f_input)
+            f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)

     if b_layer is not None:
         with b_context:
-            b_grad = b_layer.mlp.backward(b_grad)
+            b_grad = b_layer.mlp.backward(b_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
             b_grad = b_layer.dispatch.backward(b_grad)
             b_layer.mlp.dw()
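The stream_record_event / stream_wait_event arguments introduced above are a standard CUDA event handoff: the forward dispatch records an event on its stream, and the backward MLP waits on that event from its own stream, so the two streams are ordered at exactly that point without a full device synchronization. A minimal standalone sketch of the pattern (the streams, tensors, and names below are illustrative, not the scheduler's API):

import torch

comm_stream = torch.cuda.Stream()
comp_stream = torch.cuda.Stream()
sync_event = torch.cuda.Event()

x = torch.randn(1024, 1024, device="cuda")

with torch.cuda.stream(comm_stream):
    y = x @ x                        # stands in for f_layer.dispatch.forward
    sync_event.record(comm_stream)   # stream_record_event: mark when y is ready

with torch.cuda.stream(comp_stream):
    sync_event.wait(comp_stream)     # stream_wait_event: do not start before the record
    z = y * 2.0                      # stands in for b_layer.mlp.backward
torch.cuda.synchronize()

Since F_DISPATCH_B_MLP_SYNC_EVENT is left as None in this commit, the wait/record calls are skipped; the synchronization only takes effect if the module-level event is created (the commented-out torch.cuda.Event()).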
@@ -688,32 +694,28 @@ def schedule_chunk_1f1b(
             )
             torch.cuda.nvtx.range_pop()

-    # tail forward
-    f_input = layer_pre_forward()
-    del layer_pre_forward
-
     # tail backward
     grad = layer_pre_backward()
     del layer_pre_backward
     with b_context:
         for i in range(overlaped_layers, b_num_layers):
             b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
             torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
-            tmp, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
+            _, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
             torch.cuda.nvtx.range_pop()

-        if b_schedule_plan is not None:
-            b_schedule_plan.pre_process.backward(grad)
-
+    # tail forward
+    f_input = layer_pre_forward()
+    del layer_pre_forward
     with f_context:
         for i in range(overlaped_layers, f_num_layers):
             f_layer = f_schedule_plan.get_layer(i)
             torch.cuda.nvtx.range_push(f"layer_{i}f")
-            f_input, tmp, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
+            f_input, _, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
             torch.cuda.nvtx.range_pop()

-        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
-            f_input = f_schedule_plan.post_process.forward(f_input)
-
     # output pp send receive, overlapped with attn backward
     if f_schedule_plan is not None and post_forward is not None:
         with f_context:
@@ -730,6 +732,13 @@ def schedule_chunk_1f1b(
         layer_pre_backward_dw()
         del layer_pre_backward_dw

+    with f_context:
+        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
+            f_input = f_schedule_plan.post_process.forward(f_input)
+
+    with b_context:
+        if b_schedule_plan is not None:
+            b_schedule_plan.pre_process.backward(grad)
+
     if f_schedule_plan:
         f_schedule_plan.wait_current_stream()
     if b_schedule_plan:
@@ -764,7 +773,7 @@ def build_model_chunk_schedule_plan(
     state.attention_mask = attention_mask
     state.decoder_input = decoder_input
     state.labels = labels
-    state.inference_context =inference_context
+    state.inference_context = inference_context
     state.packed_seq_params = packed_seq_params
     state.extra_block_kwargs = extra_block_kwargs
     state.runtime_gather_output = runtime_gather_output
...
 import contextlib
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, List, Tuple, Union
+from typing import List, Union

 import torch
 from torch import Tensor
@@ -12,13 +12,10 @@ from megatron.core.distributed import DistributedDataParallel
 from megatron.core.transformer.module import Float16Module
 from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
+from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor

-# Types
-Shape = Union[List[int], torch.Size]
-

 def make_viewless(e):
     """make_viewless util func"""
     e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
@@ -56,6 +53,11 @@ class ScheduleNode:
         self.outputs = None

     def default_backward_func(self, outputs, output_grad):
+        # Handle scalar output
+        if output_grad is None:
+            assert outputs.numel() == 1, "implicit grad requires scalar output."
+            output_grad = torch.ones_like(outputs, memory_format=torch.preserve_format)
+
         Variable._execution_engine.run_backward(
             tensors=outputs,
             grad_tensors=output_grad,
@@ -67,17 +69,20 @@ class ScheduleNode:
         )
         return output_grad

-    def forward(self, inputs=()):
+    def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):
         """schedule node forward"""
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
-        return self._forward(*inputs)
+        return self._forward(*inputs, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)

-    def _forward(self, *inputs):
+    def _forward(self, *inputs, stream_wait_event=None, stream_record_event=None):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} forward")
             with torch.cuda.stream(self.stream):
+                if stream_wait_event is not None:
+                    stream_wait_event.wait(self.stream)
+
                 self.inputs = [make_viewless(e).detach() if e is not None else None for e in inputs]
                 for i, input in enumerate(self.inputs):
                     if input is not None:
@@ -92,6 +97,10 @@ class ScheduleNode:
                 data = tuple([make_viewless(e) if isinstance(e, Tensor) else e for e in data])
                 self.output = data

+                if stream_record_event is not None:
+                    stream_record_event.record(self.stream)
+
             torch.cuda.nvtx.range_pop()

         if self.free_inputs:
@@ -105,16 +114,19 @@ class ScheduleNode:
         """get the forward output"""
         return self.output

-    def backward(self, output_grad):
+    def backward(self, output_grad, stream_wait_event=None, stream_record_event=None):
         """schedule node backward"""
         if not isinstance(output_grad, tuple):
             output_grad = (output_grad,)
-        return self._backward(*output_grad)
+        return self._backward(*output_grad, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)

-    def _backward(self, *output_grad):
+    def _backward(self, *output_grad, stream_wait_event=None, stream_record_event=None):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} backward")
             with torch.cuda.stream(self.stream):
+                if stream_wait_event is not None:
+                    stream_wait_event.wait(self.stream)
+
                 outputs = self.output
                 if not isinstance(outputs, tuple):
                     outputs = (outputs,)
@@ -125,6 +137,10 @@ class ScheduleNode:
                     output_grad = self.backward_func(outputs, output_grad)
                 else:
                     output_grad = self.default_backward_func(outputs, output_grad)
+
+                if stream_record_event is not None:
+                    stream_record_event.record(self.stream)
+
             torch.cuda.nvtx.range_pop()

         # output_grad maybe from another stream
@@ -192,17 +208,6 @@ def schedule_chunk_1f1b(
     )

-def schedule_chunk_forward(schedule_plan):
-    """model level fine-grained forward schedule"""
-    f_input = schedule_chunk_1f1b(schedule_plan, None, None)
-    return f_input
-
-
-def schedule_chunk_backward(schedule_plan, grad):
-    """model level fine-grained backward schedule"""
-    tmp = schedule_chunk_1f1b(None, schedule_plan, grad)
-
-
 _COMP_STREAM = None
 _COM_STREAM = None
@@ -215,7 +220,7 @@ def set_streams(comp_stream=None, com_stream=None):
         return
     if comp_stream is None:
-        comp_stream = torch.cuda.Stream(device="cuda")
+        comp_stream = torch.cuda.current_stream()
     if com_stream is None:
         com_stream = torch.cuda.Stream(device="cuda")
@@ -342,7 +347,7 @@ def forward_backward_step(
         Tensor or list[Tensor]: The output object(s) from the forward step.
         Tensor: The number of tokens.
     """
-    from .schedules import set_current_microbatch
+    from megatron.core.pipeline_parallel.schedules import set_current_microbatch

     if config.timers is not None:
         config.timers('forward-compute', log_level=2).start()
@@ -441,12 +446,13 @@ def forward_backward_step(
                 output_tensor, num_tokens, loss_reduced = outputs
                 if not config.calculate_per_token_loss:
                     output_tensor /= num_tokens
+                    output_tensor *= parallel_state.get_context_parallel_world_size()
                     output_tensor /= num_microbatches
             else:
-                # preserve legacy loss averaging behavior
-                # (ie, over the number of microbatches)
+                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                 assert len(outputs) == 2
                 output_tensor, loss_reduced = outputs
+                output_tensor *= parallel_state.get_context_parallel_world_size()
                 output_tensor = output_tensor / num_microbatches
             forward_data_store.append(loss_reduced)
@@ -464,12 +470,11 @@ def forward_backward_step(
     # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
     # explicitly.
     if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
-        # Calculate the loss scale based on the grad_scale_func if available,
-        # else default to 1.
+        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
         loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
-            else torch.tensor(1.0)
+            else torch.ones(1, device=output_tensor.device)
        )
        # Set the loss scale
        if config.calculate_per_token_loss:
@@ -477,8 +482,23 @@ def forward_backward_step(
        else:
            MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

+    # Set the loss scale for Multi-Token Prediction (MTP) loss.
+    if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
+        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
+        loss_scale = (
+            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
+            if config.grad_scale_func is not None
+            else torch.ones(1, device=output_tensor.device)
+        )
+        # Set the loss scale
+        if config.calculate_per_token_loss:
+            MTPLossAutoScaler.set_loss_scale(loss_scale)
+        else:
+            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
+
    if not unwrap_output_tensor:
        output_tensor, num_tokens = [output_tensor], num_tokens

    # backward post process
    input_tensor_grad = None
    if b_model is not None:
...
 import contextlib
-from typing import Callable, Iterator, List, Optional, Union
+from typing import Iterator, List, Union

 import torch
@@ -7,10 +7,8 @@ from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.enums import ModelType
 from megatron.core.pipeline_parallel import p2p_communication
-from megatron.core.pipeline_parallel.schedules import set_current_microbatch
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
 from megatron.core.utils import (
-    get_attr_wrapped_model,
     get_model_config,
     get_model_type,
     get_model_xattn,
@@ -448,6 +446,8 @@ def forward_backward_pipelining_with_interleaving(
         """Helper method to run backward step with model split into chunks
         (run set_virtual_pipeline_model_parallel_rank() before calling
         backward_step())."""
+        nonlocal output_tensor_grads
+
         model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
         parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
...
@@ -40,8 +40,7 @@ class ExtraTransformerConfig:
     combined_1f1b_recipe: str = 'ep_a2a'
     """Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""

-    split_bw: bool = False
-    """If true, split dgrad and wgrad for better overlapping in combined 1F1B."""
+    # split_bw: bool = False

 @dataclass
...
-from functools import partial
 from typing import Any, Optional

-import torch
 from torch import Tensor

 from megatron.core import tensor_parallel
@@ -12,8 +10,7 @@ from megatron.core.utils import (
 )
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
-from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
+from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher


 class TransformerLayer(MegatronCoreTransformerLayer):
@@ -34,7 +31,10 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         inference_params: Optional[Any] = None,
     ):
-        if not isinstance(self.mlp, MoELayer):
+        if (
+            not isinstance(self.mlp, MoELayer)
+            or not isinstance(self.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher)
+        ):
             return super().forward(
                 hidden_states=hidden_states,
                 context=context,
@@ -55,7 +55,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             pre_mlp_layernorm_output,
             tokens_per_expert,
             permutated_local_input_tokens,
-            probs,
+            _,
         ) = self._submodule_attention_router_compound_forward(
             hidden_states,
             attention_mask,
...
@@ -97,7 +97,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
     print("> initializing torch distributed ...", flush=True)
     # Manually set the device ids.
     if device_count > 0:
-        torch.cuda.set_device(args.local_rank % device_count)
+        torch.cuda.set_device(args.local_rank)
         device_id = torch.device(f'cuda:{args.local_rank}')
     else:
         device_id = None
...
@@ -195,7 +195,8 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                 active=args.profile_step_end-args.profile_step_start,
                 repeat=1),
             on_trace_ready=trace_handler,
             record_shapes=True,
+            with_stack=True,
         )
         prof.start()
     elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
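with_stack=True makes the PyTorch profiler record Python stack traces for each op, which makes kernels in the resulting trace much easier to attribute to model code, at the cost of extra profiling overhead and larger trace files. A standalone usage sketch with the same options (the schedule values, trace directory, and placeholder work below are illustrative, not the values used by this training script):

import torch
from torch.profiler import profile, schedule, tensorboard_trace_handler

prof = profile(
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
    on_trace_ready=tensorboard_trace_handler("./profile_traces"),
    record_shapes=True,
    with_stack=True,
)
prof.start()
for step in range(8):
    torch.randn(1024, 1024).mm(torch.randn(1024, 1024))  # placeholder work for one step
    prof.step()
prof.stop()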
...
@@ -16,9 +16,6 @@ from megatron.core.enums import ModelType
 from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
 from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
 from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
-from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
-    get_gpt_heterogeneous_layer_spec,
-)
 from megatron.core.rerun_state_machine import get_rerun_state_machine
 import megatron.legacy.model
 from megatron.core.models.gpt import GPTModel
@@ -38,7 +35,6 @@ from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_layer_with_transformer_engine_spec,
     get_gpt_mtp_block_spec,
 )
-from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from dcu_megatron import megatron_adaptor
@@ -102,8 +98,6 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
     if args.num_experts:
         # Define the decoder block spec
         transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization)
-    elif args.heterogeneous_layers_config_path is not None:
-        transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
     else:
         # Define the decoder layer spec
         if use_te:
...