Unverified Commit ae393e81 authored by buptzyb's avatar buptzyb Committed by GitHub
Browse files

Support CUDA Graph for MoE models (#1233)



* Align RNG tracker with megatron
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarYifei Song <yifeis@nvidia.com>

* Fix module_params order and warmup bug in cudagraph
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarYifei Song <yifeis@nvidia.com>

* Add fp8_group argument and fix fp8 accuracy issue for cudagraph
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarYifei Song <yifeis@nvidia.com>

* Add TE modules and weights filters to support MoE models
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarYifei Song <yifeis@nvidia.com>

* Revert self.fp8
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* Use hooks to filter module params
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* Filter all TE modules in hooks
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarYifei Song <yifeis@nvidia.com>

* Format code
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update graph.py
Signed-off-by: default avatarXin Yao <yaox12@outlook.com>

* Revert CudaRNGStatesTracker
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* Format Update
Signed-off-by: default avatarYifei Song <yifeis@nvidia.com>

* Revert "Use hooks to filter module params"

This reverts commit 73a22e2e8bcf43ec84c23bc844b8d16d06626e26.
Signed-off-by: default avatarYifei Song <yifeis@nvidia.com>

* Remove filtering module params
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

---------
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Signed-off-by: default avatarXin Yao <yaox12@outlook.com>
Signed-off-by: default avatarYifei Song <yifeis@nvidia.com>
Co-authored-by: default avatarYifei Song <yifeis@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <yaox12@outlook.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 8952bc41
......@@ -442,16 +442,16 @@ class FP8GlobalStateManager:
stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()
# Replace amaxes and scales with stashed values for phase 2 forward
fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]
fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])
fp8_meta["scaling_fwd"].scale_inv.copy_(stashed_fp8_meta[2])
@staticmethod
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""
fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])
fp8_meta["scaling_fwd"].scale_inv.copy_(fp8_meta["updated_scale_inv_fwd"])
@contextmanager
......
......@@ -12,6 +12,7 @@ from torch.utils._pytree import tree_unflatten as _tree_unflatten
from torch._C import _graph_pool_handle
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.constants import dist_group_type
from .fp8 import (
fp8_autocast,
FP8GlobalStateManager,
......@@ -173,11 +174,14 @@ def _make_graphed_callables(
]
else:
per_callable_module_params = []
for c in callables:
for i in range(num_microbatches):
per_callable_module_params.append(
tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
)
for m_chunk in range(num_model_chunks):
for _ in range(num_microbatches):
for l_no in range(num_layers):
per_callable_module_params.append(
tuple(callables[m_chunk * num_layers + l_no].parameters())
if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
else ()
)
assert len(per_callable_module_params) == len(flatten_sample_args)
per_callable_static_input_surfaces = [
flatten_sample_args[i] + per_callable_module_params[i]
......@@ -201,13 +205,55 @@ def _make_graphed_callables(
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
# from ending up in any captures.
torch.cuda.synchronize()
with torch.cuda.stream(torch.cuda.Stream()):
# Get warmup func and func_idx.
warmup_func_idx = []
warmup_func = []
if _order is None:
for func_idx, func in enumerate(callables):
warmup_func_idx.append(func_idx)
warmup_func.append(func)
else:
fwd_idx = [0] * num_model_chunks
for c_id in _order:
if c_id > 0:
m_chunk = c_id - 1
for l_no in range(num_layers):
func = callables[m_chunk * num_layers + l_no]
func_idx = (m_chunk * num_microbatches * num_layers) + (
fwd_idx[m_chunk] * num_layers + l_no
)
warmup_func_idx.append(func_idx)
warmup_func.append(func)
fwd_idx[m_chunk] += 1
assert len(warmup_func) == len(
sample_args
), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}."
assert len(warmup_func_idx) == len(
set(warmup_func_idx)
), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."
# Filter the TE modules that cudagraph can access.
visited_te_modules = set()
def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
if isinstance(module, TransformerEngineBaseModule):
visited_te_modules.add(module)
# Run warmup and do the above filtering.
with torch.cuda.stream(torch.cuda.Stream()):
for func_idx, func in zip(warmup_func_idx, warmup_func):
args = sample_args[func_idx]
kwargs = sample_kwargs[func_idx]
static_input_surface = per_callable_static_input_surfaces[func_idx]
for _ in range(num_warmup_iters):
hooks = []
for module in func.modules():
hook = module.register_forward_hook(hook_fn)
hooks.append(hook)
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
hook.remove()
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
......@@ -216,6 +262,11 @@ def _make_graphed_callables(
allow_unused=allow_unused_input,
)
del outputs, grad_inputs
# The following code is added specifically for MCore's special requirements,
# aimed at preventing warmup from altering the control flow.
for module in func.modules():
if hasattr(module, "is_first_microbatch"):
module.is_first_microbatch = True
torch.cuda.synchronize()
# All captures here share a mempool. To avoid replays corrupting each other's memory,
......@@ -462,6 +513,19 @@ def _make_graphed_callables(
isinstance(m, TransformerEngineBaseModule)
and FP8GlobalStateManager.is_fp8_enabled()
):
if m not in visited_te_modules:
# Only Set the FP8 meta for the modules included by forward
continue
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
from transformer_engine.pytorch.attention import DotProductAttention
if (
isinstance(m, DotProductAttention)
and not fp8_recipe.fp8_mha
and not fp8_recipe.fp8_dpa
):
# Don't need to update FP8 meta for non-FP8 DPA
continue
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
......@@ -538,6 +602,7 @@ def make_graphed_callables(
fp8_enabled: bool = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None,
......@@ -579,6 +644,9 @@ def make_graphed_callables(
using a higher precision.
fp8_recipe: recipe.DelayedScaling, default = `None`
recipe used for FP8 training.
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
fp8_weight_caching: bool, default = `False`
Whether or not to cache FP8 weights across microbatches. if set to `True`,
the `is_first_microbatch` boolean argument must be passed into the forward
......@@ -607,7 +675,11 @@ def make_graphed_callables(
def forward_func(*args, **kwargs):
with fp8_autocast(
enabled=fp8_enabled, calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, _graph=True
enabled=fp8_enabled,
calibrating=fp8_calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group,
_graph=True,
):
outputs = old_forward(*args, **kwargs)
return outputs
......
......@@ -1152,7 +1152,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
produced)
"""
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False
......
......@@ -1484,7 +1484,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
produced)
"""
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False
......
......@@ -938,8 +938,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment