Unverified Commit e39674b9 authored by Tim Moon, committed by GitHub

[PyTorch] Add option to pass kwargs to CUDA graph module (#945)



* Add option to pass kwargs to CUDA graph module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug unit tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Tweak comments
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6c579267
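
For context, the new `sample_kwargs` option lets keyword-only tensor inputs (such as attention masks) be captured alongside the positional `sample_args`. A minimal usage sketch, mirroring the kwargs unit test added below (module choice, shapes, and tensor names are illustrative, not part of the change itself):

import torch
from transformer_engine.pytorch import TransformerLayer, make_graphed_callables

# Illustrative sizes; the test below uses its own ModelConfig values.
hidden_size, num_heads, seq_len, batch_size = 768, 12, 128, 2

layer = TransformerLayer(
    hidden_size,
    hidden_size,
    num_heads,
    self_attn_mask_type="arbitrary",
    params_dtype=torch.float32,
)
sample_input = torch.randn(seq_len, batch_size, hidden_size, device="cuda")
sample_mask = torch.zeros(
    (batch_size, 1, seq_len, seq_len), dtype=torch.bool, device="cuda"
)

# Keyword arguments are captured through the new sample_kwargs option.
graphed_layer = make_graphed_callables(
    layer,
    (sample_input,),
    sample_kwargs=dict(attention_mask=sample_mask),
    allow_unused_input=True,
)

# At replay, the same kwarg names must be supplied again.
out = graphed_layer(sample_input, attention_mask=sample_mask)
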
@@ -3,7 +3,8 @@
# See LICENSE for license information.
from dataclasses import dataclass
from typing import List, Tuple
import itertools
from typing import Iterable, List, Tuple, Union
import pytest
import torch
@@ -88,7 +89,7 @@ def generate_data(
dpa: bool = False,
warmup: bool = False,
return_grad_output: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[List[torch.Tensor], torch.Tensor]:
"""Generate synthetic data."""
gen_func = torch.ones if warmup else torch.randn
if dpa:
@@ -129,14 +130,20 @@ def generate_data(
return inputs, grad_output
def get_outputs(model, output):
def get_outputs(
model: torch.nn.Module,
output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
"""Return grads and params for comparsion."""
values = []
for param in model.parameters():
values.append(param)
if param.grad is not None:
values.append(param.grad)
values.append(output)
if isinstance(output, torch.Tensor):
values.append(output)
else:
values.extend(output)
return values
@@ -161,7 +168,7 @@ def _test_cuda_graphs(
module: str,
graph_mode: str,
) -> List[torch.Tensor]:
"""Helper function for test."""
"""Helper function for CUDA graph test."""
reset_rng_states()
FP8GlobalStateManager.reset()
dpa = module == "dpa"
@@ -247,7 +254,7 @@ def _test_cuda_graphs(
else:
model = modules[0] if dpa else _Sequential(*modules)
# Loss function and optimizer.
# Optimizer.
if not dpa:
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
@@ -312,3 +319,193 @@ def test_gpt_make_graphed_callables(
# Check that results match
assert_all_equal(outputs, graph_outputs_mode1)
assert_all_equal(outputs, graph_outputs_mode2)
def _test_cuda_graphs_with_kwargs(
*,
config: ModelConfig,
dtype: torch.dtype,
with_graph: bool,
) -> List[torch.Tensor]:
"""Simulate Megatron-LM interleaved pipeline parallelism."""
reset_rng_states()
# Initialize model.
model = TransformerLayer(
config.hidden_size,
config.hidden_size,
config.num_heads,
hidden_dropout=0.0,
attention_dropout=0.0,
self_attn_mask_type="arbitrary",
fuse_qkv_params=True,
params_dtype=dtype,
)
# Initialize gradient buffers.
for param in model.parameters():
param.grad = torch.empty_like(param)
# Make graphed version of model if needed.
if with_graph:
attn_mask = torch.zeros(
(config.batch_size, 1, config.sequence_length, config.sequence_length),
dtype=torch.bool,
device="cuda",
)
model = make_graphed_callables(
model,
generate_data(config, dtype, warmup=True),
sample_kwargs=dict(attention_mask=attn_mask),
allow_unused_input=True,
)
# Optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Training loop.
for _ in range(3):
optimizer.zero_grad(set_to_none=False)
for grad_accumulation_step in range(2):
inputs, grad_output = generate_data(config, dtype, return_grad_output=True)
attn_mask = torch.randint(
2,
(config.batch_size, 1, config.sequence_length, config.sequence_length),
dtype=torch.bool,
device="cuda",
)
output = model(*inputs, attention_mask=attn_mask)
output.backward(grad_output)
optimizer.step()
return get_outputs(model, output)
def test_make_graphed_callables_with_kwargs(
dtype: torch.dtype = torch.float32,
model: str = "small",
) -> None:
"""Test CUDA graphs with keyword arguments."""
config = model_configs[model]
kwargs = dict(config=config, dtype=dtype)
outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
assert_all_equal(outputs, graph_outputs)
def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
*,
config: ModelConfig,
dtype: torch.dtype,
with_graph: bool,
) -> List[torch.Tensor]:
"""Simulate Megatron-LM interleaved pipeline parallelism."""
reset_rng_states()
# Pipeline parallel configuration.
num_layers = 2
num_microbatches = 3
layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]
# Initialize model.
model = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
config.hidden_size,
params_dtype=dtype,
)
for _ in range(num_layers)
]
)
# Initialize gradient buffers.
for param in model.parameters():
param.grad = torch.empty_like(param)
# Make graphed version of model if needed.
layer_forwards = {
(i % num_layers, i // num_layers): model[i % num_layers]
for i in range(num_layers * num_microbatches)
}
if with_graph:
sample_args = tuple(
generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches)
)
layer_forwards = make_graphed_callables(
tuple(model),
sample_args,
allow_unused_input=True,
_order=layer_order,
)
layer_forwards = {
(i // num_microbatches, i % num_microbatches): forward
for i, forward in enumerate(layer_forwards)
}
# Optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Training loop.
for _ in range(3):
optimizer.zero_grad(set_to_none=False)
# Generate data.
inputs = {}
grad_outputs = {}
for layer_idx in range(num_layers):
for microbatch_idx in range(num_microbatches):
x, dy = generate_data(config, dtype, return_grad_output=True)
idxs = (layer_idx, microbatch_idx)
inputs[idxs] = x[0]
grad_outputs[idxs] = dy
# Cache for layer outputs.
outputs = {}
def forward(layer_idx: int, microbatch_idx: int):
"""Helper function for forward steps"""
idxs = (layer_idx, microbatch_idx)
outputs[idxs] = layer_forwards[idxs](inputs[idxs])
def backward(layer_idx: int, microbatch_idx: int):
"""Helper function for backward steps"""
outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx])
# Forward and backward steps.
forward(0, 0)
forward(1, 0)
forward(0, 1)
forward(1, 1)
backward(1, 0)
backward(0, 0)
forward(0, 2)
forward(1, 2)
backward(1, 1)
backward(0, 1)
backward(1, 2)
backward(0, 2)
# Optimizer step.
optimizer.step()
outputs = [y for _, y in sorted(outputs.items())]
return get_outputs(model, outputs)
def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
dtype: torch.dtype = torch.float16,
model: str = "small",
) -> None:
"""Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
config = model_configs[model]
kwargs = dict(config=config, dtype=dtype)
outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=False,
**kwargs,
)
graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=True,
**kwargs,
)
assert_all_equal(outputs, graph_outputs)
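
The explicit forward/backward calls in the training loop above follow the schedule encoded by `layer_order`: positive entries are forward passes of a (1-indexed) model chunk, negative entries are backward passes, and the microbatch index increments per chunk in call order. A small illustrative sketch that expands the same list into that call sequence:

def expand_schedule(layer_order):
    """Expand an interleaved schedule into (phase, layer_idx, microbatch_idx) steps."""
    fwd_counts, bwd_counts, steps = {}, {}, []
    for entry in layer_order:
        layer_idx = abs(entry) - 1  # 1-indexed chunk id -> 0-indexed layer index
        counts = fwd_counts if entry > 0 else bwd_counts
        microbatch_idx = counts.get(layer_idx, 0)
        counts[layer_idx] = microbatch_idx + 1
        steps.append(("forward" if entry > 0 else "backward", layer_idx, microbatch_idx))
    return steps

# Reproduces the call order used above:
# forward(0, 0), forward(1, 0), forward(0, 1), forward(1, 1), backward(1, 0), backward(0, 0), ...
print(expand_schedule([1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]))
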
@@ -3,11 +3,14 @@
# See LICENSE for license information.
"""Functions for CUDA Graphs support in FP8"""
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch
from torch.utils._pytree import tree_flatten as _tree_flatten
from torch.utils._pytree import tree_unflatten as _tree_unflatten
from torch._C import _graph_pool_handle
from transformer_engine.common.recipe import DelayedScaling
from .fp8 import (
fp8_autocast,
FP8GlobalStateManager,
@@ -22,6 +25,9 @@ __all__ = ["make_graphed_callables"]
_IS_GRAPH_CAPTURING = False
_T = TypeVar("_T")
SingleOrTuple = Union[_T, Tuple[_T, ...]]
def set_capture_start() -> None:
"""Record beginning of `make_graphed_callables`."""
@@ -48,13 +54,14 @@ def graph_pool_handle():
def _make_graphed_callables(
callables,
sample_args,
num_warmup_iters=3,
allow_unused_input=False,
fp8_weight_caching=False,
_order=None,
):
callables: SingleOrTuple[Callable],
sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
num_warmup_iters: int = 3,
allow_unused_input: bool = False,
fp8_weight_caching: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
_order: Optional[List[int]] = None,
) -> SingleOrTuple[Callable]:
"""
Helper method for `make_graphed_callables`
"""
@@ -65,16 +72,38 @@ def _make_graphed_callables(
"caching. Please set `cache_enabled=False`."
)
just_one_callable = False
# Default is to pass no kwargs to callables
if sample_kwargs is None:
if isinstance(callables, tuple):
sample_kwargs = tuple({} for _ in range(len(sample_args)))
else:
sample_kwargs = {}
# Canonicalize args as tuples
just_one_callable = False
if not isinstance(callables, tuple):
just_one_callable = True
callables = (callables,)
sample_args = (sample_args,)
sample_kwargs = (sample_kwargs,)
flatten_sample_args = []
if _order is not None:
# order is a list containing 1..model_chunk values in the order of microbatch schedule
# Check sizes of args
if _order is None:
assert len(sample_args) == len(callables)
assert len(sample_kwargs) == len(callables)
else:
# Custom logic for interleaved pipeline parallelism
# Note: This is tightly coupled with the Megatron-core
# implementation of interleaved pipeline parallelism at
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py.
# Note: The model is assumed to consist of layers
# (corresponding to callables) that are grouped into
# equally-sized model chunks. _order is a list of chunk
# indices (1-indexed) that indicates the order in which the
# layers are evaluated. Positive values indicate forward
# passes and negative values indicate backward passes. Each
# entry in sample_args corresponds to one of the forward
# passes.
num_model_chunks = max(_order)
num_microbatches = len(_order) // num_model_chunks // 2
assert num_model_chunks * num_microbatches * 2 == len(_order)
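
As a worked example of these size constraints (using the `_order` from the interleaved test above): a schedule of length 12 with two model chunks implies three microbatches, so `sample_args` (and `sample_kwargs`, if given) must contain six per-forward-pass entries.

_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]        # from the test above
num_model_chunks = max(_order)                              # 2
num_microbatches = len(_order) // num_model_chunks // 2     # 12 // 2 // 2 = 3
assert num_model_chunks * num_microbatches * 2 == len(_order)
# Expected number of sample_args / sample_kwargs entries: 2 * 3 = 6
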
@@ -90,10 +119,13 @@ def _make_graphed_callables(
f"Expected {num_model_chunks * num_microbatches}"
+ f"args tuple, but got {len(sample_args)}."
)
assert len(sample_kwargs) == len(sample_args)
if fp8_weight_caching:
# Initialize flag that controls FP8 weight updates
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
# Check callables
for c in callables:
if isinstance(c, torch.nn.Module):
assert (
@@ -110,9 +142,14 @@ def _make_graphed_callables(
+ ":func:`~make_graphed_callables`, only parameters may be trainable. "
+ "All buffers must have ``requires_grad=False``."
)
for args in sample_args:
# Flatten callable arguments
per_callable_kwargs_keys = [list(kwargs.keys()) for kwargs in sample_kwargs]
flatten_sample_args = []
for args, kwargs, kwargs_keys in zip(sample_args, sample_kwargs, per_callable_kwargs_keys):
flatten_arg, _ = _tree_flatten(args)
flatten_sample_args.append(tuple(flatten_arg))
flatten_kwarg, _ = _tree_flatten([kwargs[key] for key in kwargs_keys])
flatten_sample_args.append(tuple(flatten_arg + flatten_kwarg))
assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
"In the beta API, sample_args "
+ "for each callable must contain only Tensors. Other types are not allowed."
@@ -120,6 +157,10 @@ def _make_graphed_callables(
# If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
# passes to forward (ie, its sample_args) AND the module's parameter attributes.
# Note: These per_callable_* variables are not actually
# per-callable, but per-forward-pass (see description of _order).
# The names are kept for consistency with
# torch.cuda.make_graphed_callables.
per_callable_len_user_args = [len(args) for args in flatten_sample_args]
if _order is None:
per_callable_module_params = [
@@ -144,6 +185,7 @@ def _make_graphed_callables(
fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
graph_callables = [None for _ in range(len(flatten_sample_args))]
# For cases with multiple active RNG states, e.g. TP.
if graph_safe_rng_available():
for _, state in get_all_rng_states().items():
@@ -158,11 +200,12 @@ def _make_graphed_callables(
# from ending up in any captures.
torch.cuda.synchronize()
with torch.cuda.stream(torch.cuda.Stream()):
for c_i, func in enumerate(callables):
args = sample_args[c_i]
static_input_surface = per_callable_static_input_surfaces[c_i]
for func_idx, func in enumerate(callables):
args = sample_args[func_idx]
kwargs = sample_kwargs[func_idx]
static_input_surface = per_callable_static_input_surfaces[func_idx]
for _ in range(num_warmup_iters):
outputs, _ = _tree_flatten(func(*args))
outputs, _ = _tree_flatten(func(*args, **kwargs))
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
@@ -194,9 +237,10 @@ def _make_graphed_callables(
fwd_idx[m_chunk] * num_layers + l_no
)
args = sample_args[per_callable_fwd_idx]
kwargs = sample_kwargs[per_callable_fwd_idx]
fwd_graph = fwd_graphs[per_callable_fwd_idx]
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args)
outputs = func(*args, **kwargs)
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec
@@ -245,9 +289,9 @@ def _make_graphed_callables(
per_callable_static_outputs = []
per_callable_output_unflatten_spec = []
graph_id = 0
for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs):
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args)
outputs = func(*args, **kwargs)
graph_callables[graph_id] = func
graph_id += 1
@@ -300,6 +344,7 @@ def _make_graphed_callables(
fwd_graph,
bwd_graph,
module_params,
kwargs_keys,
len_user_args,
output_unflatten_spec,
static_input_surface,
@@ -312,14 +357,18 @@ def _make_graphed_callables(
@staticmethod
def forward(ctx, skip_fp8_weight_update, *inputs):
# At this stage, only the user args may (potentially) be new tensors.
# Set flag for whether to skip FP8 weight updates
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
if ctx.is_first_module and skip_fp8_weight_update is not None:
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)
# Copy values from new tensors into static tensors
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
static_input_surface[i].copy_(inputs[i])
# Replay forward graph
fwd_graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() for o in static_outputs)
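
The comments added to `Graphed.forward` describe the standard CUDA graph replay protocol: tensors captured at graph time act as static buffers, fresh inputs are copied into them, and the recorded kernels are replayed. A condensed sketch of that pattern with plain `torch.cuda.CUDAGraph` (standalone illustration, independent of Transformer Engine):

import torch

model = torch.nn.Linear(16, 16).cuda()
static_x = torch.randn(8, 16, device="cuda")   # static input buffer

# Warm up on a side stream before capture, as PyTorch requires.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture the forward pass into a graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = model(static_x)

# Replay: copy new data into the static buffer, then replay the captured kernels.
static_x.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
# static_y now holds the output for the newly copied-in input.
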
@@ -327,6 +376,8 @@ def _make_graphed_callables(
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, *grads):
# Replay backward graph
assert len(grads) == len(static_grad_outputs)
for g, grad in zip(static_grad_outputs, grads):
if g is not None:
@@ -336,6 +387,7 @@ def _make_graphed_callables(
g.copy_(grad)
bwd_graph.replay()
# Update FP8 scale factors if needed
if ctx.is_first_module:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
@@ -346,10 +398,8 @@ def _make_graphed_callables(
)
def functionalized(*user_args, **user_kwargs):
# Runs the autograd function with inputs == all
# inputs to the graph that might require grad
# (explicit user args + module parameters)
# Assumes module params didn't change since capture.
# Decide whether to update FP8 weights
skip_fp8_weight_update = None
if fp8_weight_caching:
assert "is_first_microbatch" in user_kwargs and isinstance(
@@ -358,8 +408,22 @@ def _make_graphed_callables(
skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]
# Check that required kwargs are provided
for key in kwargs_keys:
if key not in user_kwargs:
raise TypeError(
f"Graphed callable was initialized with kwarg {key} ,"
"but it was not provided in graph replay"
)
# Runs the autograd function with inputs == all inputs to
# the graph that might require grad (explicit user args +
# module parameters)
# Assumes module params didn't change since capture.
flatten_user_args, _ = _tree_flatten(user_args)
out = Graphed.apply(skip_fp8_weight_update, *(tuple(flatten_user_args) + module_params))
flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys])
func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params
out = Graphed.apply(skip_fp8_weight_update, *func_args)
return _tree_unflatten(out, output_unflatten_spec)
return functionalized
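
A brief note on the replay contract implemented by `functionalized`: every kwarg name recorded at capture must be provided again at replay (otherwise the `TypeError` above is raised), and kwarg values are flattened in the recorded key order before being appended to the positional inputs and module parameters. A small sketch of how the replay inputs are assembled (tensor shapes are placeholders):

import torch
from torch.utils._pytree import tree_flatten

kwargs_keys = ["attention_mask"]                            # keys recorded at capture time
user_args = (torch.randn(4, 8, device="cuda"),)             # placeholder positional inputs
user_kwargs = {"attention_mask": torch.zeros(4, 4, dtype=torch.bool, device="cuda")}

flat_args, _ = tree_flatten(user_args)
flat_kwargs, _ = tree_flatten([user_kwargs[key] for key in kwargs_keys])
graph_inputs = tuple(flat_args) + tuple(flat_kwargs)        # module parameters are appended after this
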
@@ -371,6 +435,7 @@ def _make_graphed_callables(
fwd_graphs[i],
bwd_graphs[i],
per_callable_module_params[i],
per_callable_kwargs_keys[i],
per_callable_len_user_args[i],
per_callable_output_unflatten_spec[i],
per_callable_static_input_surfaces[i],
@@ -443,25 +508,42 @@ def restore_fp8_tensors(modules, fp8_tensors):
def make_graphed_callables(
modules,
sample_args,
num_warmup_iters=3,
allow_unused_input=False,
fp8_enabled=False,
fp8_calibrating=False,
fp8_recipe=None,
fp8_weight_caching=False,
_order=None,
):
modules: SingleOrTuple[Callable],
sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
num_warmup_iters: int = 3,
allow_unused_input: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
fp8_enabled: bool = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None,
) -> Union[Callable, Tuple[Callable, ...]]:
"""
A version of PyTorch's `make_graphed_callables` utility function with support for
TransformerEngine modules and FP8. Please see the original version in upstream PyTorch
`here <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
for extensive documentation. The documentation for additional parameters which are
specific to FP8 are given below.
FP8 specific parameters
-----------------------
Make CUDA graph version of Transformer Engine modules
A variation of PyTorch's `make_graphed_callables` utility function
with support for Transformer Engine modules and FP8. Please see
the
`original PyTorch implementation <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
for more documentation.
Graphing parameters
-------------------
modules: (tuple of) callable
Callable or callables to graph.
sample_args: (tuple of) tuple of torch.Tensor
Positional arguments to callable(s).
num_warmup_iters: int, default = 3
Number of warmup iterations.
allow_unused_input: bool, default = `False`
Whether to handle case where callable inputs
and outputs are disconnected in compute graph.
sample_kwargs: (tuple of) dict, optional
Keyword arguments to callable(s).
FP8-related parameters
----------------------
fp8_enabled: bool, default = `False`
Whether or not to enable FP8.
fp8_calibrating: bool, default = `False`
@@ -478,6 +560,7 @@ def make_graphed_callables(
using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg
must be set to `False` if calculating weight transposes' outside TE, e.g.,
in the optimizer step.
"""
set_capture_start()
@@ -532,6 +615,7 @@ def make_graphed_callables(
num_warmup_iters=num_warmup_iters,
allow_unused_input=allow_unused_input,
fp8_weight_caching=fp8_weight_caching,
sample_kwargs=sample_kwargs,
_order=_order,
)
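
When a tuple of callables is graphed, `sample_kwargs` is expected to be a matching tuple of dicts (or omitted, in which case empty dicts are assumed); with `_order`, one entry per forward pass is expected, matching `sample_args`. A sketch with two hypothetical layers that take an `attention_mask` keyword:

graphed_layers = make_graphed_callables(
    (layer0, layer1),                          # hypothetical modules
    (
        (sample_input0,),                      # sample_args for layer0
        (sample_input1,),                      # sample_args for layer1
    ),
    sample_kwargs=(
        dict(attention_mask=sample_mask0),
        dict(attention_mask=sample_mask1),
    ),
)
out0 = graphed_layers[0](new_input0, attention_mask=new_mask0)
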