Unverified commit e39674b9, authored by Tim Moon, committed by GitHub

[PyTorch] Add option to pass kwargs to CUDA graph module (#945)



* Add option to pass kwargs to CUDA graph module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug unit tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Tweak comments
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6c579267
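For context, a minimal sketch of how the new `sample_kwargs` option is meant to be used, mirroring the unit test added in this commit. The module configuration, shapes, and device handling below are illustrative assumptions, not part of the commit:

import torch
import transformer_engine.pytorch as te

# Module whose forward method takes a keyword argument (attention_mask).
# (Illustrative sizes; device="cuda" assumed for parameters and inputs.)
layer = te.TransformerLayer(1024, 1024, 16, self_attn_mask_type="arbitrary", device="cuda")
x = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)
mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")

# Keyword arguments used during graph capture are passed via sample_kwargs.
graphed_layer = te.make_graphed_callables(
    layer,
    (x,),
    sample_kwargs=dict(attention_mask=mask),
    allow_unused_input=True,
)

# At replay time the same keys must be provided again (values may change,
# but every capture-time kwarg is required, otherwise a TypeError is raised).
y = graphed_layer(x, attention_mask=mask)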
@@ -3,7 +3,8 @@
 # See LICENSE for license information.

 from dataclasses import dataclass
-from typing import List, Tuple
+import itertools
+from typing import Iterable, List, Tuple, Union

 import pytest
 import torch
@@ -88,7 +89,7 @@ def generate_data(
     dpa: bool = False,
     warmup: bool = False,
     return_grad_output: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[List[torch.Tensor], torch.Tensor]:
     """Generate synthetic data."""
     gen_func = torch.ones if warmup else torch.randn
     if dpa:
@@ -129,14 +130,20 @@ def generate_data(
     return inputs, grad_output


-def get_outputs(model, output):
+def get_outputs(
+    model: torch.nn.Module,
+    output: Union[torch.Tensor, Iterable[torch.Tensor]],
+) -> List[torch.Tensor]:
     """Return grads and params for comparison."""
     values = []
     for param in model.parameters():
         values.append(param)
         if param.grad is not None:
             values.append(param.grad)
-    values.append(output)
+    if isinstance(output, torch.Tensor):
+        values.append(output)
+    else:
+        values.extend(output)
     return values
@@ -161,7 +168,7 @@ def _test_cuda_graphs(
     module: str,
     graph_mode: str,
 ) -> List[torch.Tensor]:
-    """Helper function for test."""
+    """Helper function for CUDA graph test."""
     reset_rng_states()
     FP8GlobalStateManager.reset()
     dpa = module == "dpa"
@@ -247,7 +254,7 @@ def _test_cuda_graphs(
     else:
         model = modules[0] if dpa else _Sequential(*modules)

-    # Loss function and optimizer.
+    # Optimizer.
     if not dpa:
         optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
@@ -312,3 +319,193 @@ def test_gpt_make_graphed_callables(
     # Check that results match
     assert_all_equal(outputs, graph_outputs_mode1)
     assert_all_equal(outputs, graph_outputs_mode2)
+
+
+def _test_cuda_graphs_with_kwargs(
+    *,
+    config: ModelConfig,
+    dtype: torch.dtype,
+    with_graph: bool,
+) -> List[torch.Tensor]:
+    """Helper function for CUDA graph test with keyword arguments."""
+    reset_rng_states()
+
+    # Initialize model.
+    model = TransformerLayer(
+        config.hidden_size,
+        config.hidden_size,
+        config.num_heads,
+        hidden_dropout=0.0,
+        attention_dropout=0.0,
+        self_attn_mask_type="arbitrary",
+        fuse_qkv_params=True,
+        params_dtype=dtype,
+    )
+
+    # Initialize gradient buffers.
+    for param in model.parameters():
+        param.grad = torch.empty_like(param)
+
+    # Make graphed version of model if needed.
+    if with_graph:
+        attn_mask = torch.zeros(
+            (config.batch_size, 1, config.sequence_length, config.sequence_length),
+            dtype=torch.bool,
+            device="cuda",
+        )
+        model = make_graphed_callables(
+            model,
+            generate_data(config, dtype, warmup=True),
+            sample_kwargs=dict(attention_mask=attn_mask),
+            allow_unused_input=True,
+        )
+
+    # Optimizer.
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+
+    # Training loop.
+    for _ in range(3):
+        optimizer.zero_grad(set_to_none=False)
+        for grad_accumulation_step in range(2):
+            inputs, grad_output = generate_data(config, dtype, return_grad_output=True)
+            attn_mask = torch.randint(
+                2,
+                (config.batch_size, 1, config.sequence_length, config.sequence_length),
+                dtype=torch.bool,
+                device="cuda",
+            )
+            output = model(*inputs, attention_mask=attn_mask)
+            output.backward(grad_output)
+        optimizer.step()
+
+    return get_outputs(model, output)
+
+
+def test_make_graphed_callables_with_kwargs(
+    dtype: torch.dtype = torch.float32,
+    model: str = "small",
+) -> None:
+    """Test CUDA graphs with keyword arguments."""
+    config = model_configs[model]
+    kwargs = dict(config=config, dtype=dtype)
+    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
+    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
+    assert_all_equal(outputs, graph_outputs)
+
+
+def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+    *,
+    config: ModelConfig,
+    dtype: torch.dtype,
+    with_graph: bool,
+) -> List[torch.Tensor]:
+    """Simulate Megatron-LM interleaved pipeline parallelism."""
+    reset_rng_states()
+
+    # Pipeline parallel configuration.
+    num_layers = 2
+    num_microbatches = 3
+    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]
+
+    # Initialize model.
+    model = torch.nn.ModuleList(
+        [
+            Linear(
+                config.hidden_size,
+                config.hidden_size,
+                params_dtype=dtype,
+            )
+            for _ in range(num_layers)
+        ]
+    )
+
+    # Initialize gradient buffers.
+    for param in model.parameters():
+        param.grad = torch.empty_like(param)
+
+    # Make graphed version of model if needed.
+    layer_forwards = {
+        (i % num_layers, i // num_layers): model[i % num_layers]
+        for i in range(num_layers * num_microbatches)
+    }
+    if with_graph:
+        sample_args = tuple(
+            generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches)
+        )
+        layer_forwards = make_graphed_callables(
+            tuple(model),
+            sample_args,
+            allow_unused_input=True,
+            _order=layer_order,
+        )
+        layer_forwards = {
+            (i // num_microbatches, i % num_microbatches): forward
+            for i, forward in enumerate(layer_forwards)
+        }
+
+    # Optimizer.
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+
+    # Training loop.
+    for _ in range(3):
+        optimizer.zero_grad(set_to_none=False)
+
+        # Generate data.
+        inputs = {}
+        grad_outputs = {}
+        for layer_idx in range(num_layers):
+            for microbatch_idx in range(num_microbatches):
+                x, dy = generate_data(config, dtype, return_grad_output=True)
+                idxs = (layer_idx, microbatch_idx)
+                inputs[idxs] = x[0]
+                grad_outputs[idxs] = dy
+
+        # Cache for layer outputs.
+        outputs = {}
+
+        def forward(layer_idx: int, microbatch_idx: int):
+            """Helper function for forward steps"""
+            idxs = (layer_idx, microbatch_idx)
+            outputs[idxs] = layer_forwards[idxs](inputs[idxs])
+
+        def backward(layer_idx: int, microbatch_idx: int):
+            """Helper function for backward steps"""
+            outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx])
+
+        # Forward and backward steps.
+        forward(0, 0)
+        forward(1, 0)
+        forward(0, 1)
+        forward(1, 1)
+        backward(1, 0)
+        backward(0, 0)
+        forward(0, 2)
+        forward(1, 2)
+        backward(1, 1)
+        backward(0, 1)
+        backward(1, 2)
+        backward(0, 2)
+
+        # Optimizer step.
+        optimizer.step()
+
+    outputs = [y for _, y in sorted(outputs.items())]
+    return get_outputs(model, outputs)
+
+
+def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
+    dtype: torch.dtype = torch.float16,
+    model: str = "small",
+) -> None:
+    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
+    config = model_configs[model]
+    kwargs = dict(config=config, dtype=dtype)
+    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+        with_graph=False,
+        **kwargs,
+    )
+    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+        with_graph=True,
+        **kwargs,
+    )
+    assert_all_equal(outputs, graph_outputs)
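A note on the hand-written schedule in the test above: the sequence of forward/backward calls in the training loop is exactly the schedule encoded by `layer_order`, read entry by entry. This is an informal annotation, not part of the commit:

# layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]
# Positive k -> forward pass of model chunk k (1-indexed) on its next microbatch.
# Negative k -> backward pass of model chunk k on its oldest outstanding microbatch.
# With 2 chunks and 3 microbatches the list has 2 * 3 * 2 = 12 entries, matching
# forward(0, 0), forward(1, 0), forward(0, 1), forward(1, 1),
# backward(1, 0), backward(0, 0), forward(0, 2), forward(1, 2),
# backward(1, 1), backward(0, 1), backward(1, 2), backward(0, 2) in the test.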
@@ -3,11 +3,14 @@
 # See LICENSE for license information.

 """Functions for CUDA Graphs support in FP8"""
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
 import torch
 from torch.utils._pytree import tree_flatten as _tree_flatten
 from torch.utils._pytree import tree_unflatten as _tree_unflatten
 from torch._C import _graph_pool_handle

+from transformer_engine.common.recipe import DelayedScaling
+
 from .fp8 import (
     fp8_autocast,
     FP8GlobalStateManager,
@@ -22,6 +25,9 @@ __all__ = ["make_graphed_callables"]

 _IS_GRAPH_CAPTURING = False

+_T = TypeVar("_T")
+SingleOrTuple = Union[_T, Tuple[_T, ...]]
+

 def set_capture_start() -> None:
     """Record beginning of `make_graphed_callables`."""
@@ -48,13 +54,14 @@ def graph_pool_handle():


 def _make_graphed_callables(
-    callables,
-    sample_args,
-    num_warmup_iters=3,
-    allow_unused_input=False,
-    fp8_weight_caching=False,
-    _order=None,
-):
+    callables: SingleOrTuple[Callable],
+    sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    fp8_weight_caching: bool = False,
+    sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
+    _order: Optional[List[int]] = None,
+) -> SingleOrTuple[Callable]:
     """
     Helper method for `make_graphed_callables`
     """
@@ -65,16 +72,38 @@ def _make_graphed_callables(
             "caching. Please set `cache_enabled=False`."
         )

-    just_one_callable = False
+    # Default is to pass no kwargs to callables
+    if sample_kwargs is None:
+        if isinstance(callables, tuple):
+            sample_kwargs = tuple({} for _ in range(len(sample_args)))
+        else:
+            sample_kwargs = {}
+
+    # Canonicalize args as tuples
+    just_one_callable = False
     if not isinstance(callables, tuple):
         just_one_callable = True
         callables = (callables,)
         sample_args = (sample_args,)
+        sample_kwargs = (sample_kwargs,)

-    flatten_sample_args = []
-    if _order is not None:
-        # order is a list containing 1..model_chunk values in the order of microbatch schedule
+    # Check sizes of args
+    if _order is None:
+        assert len(sample_args) == len(callables)
+        assert len(sample_kwargs) == len(callables)
+    else:
+        # Custom logic for interleaved pipeline parallelism
+        # Note: This is tightly coupled with the Megatron-core
+        # implementation of interleaved pipeline parallelism at
+        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py.
+        # Note: The model is assumed to consist of layers
+        # (corresponding to callables) that are grouped into
+        # equally-sized model chunks. _order is a list of chunk
+        # indices (1-indexed) that indicates the order in which the
+        # layers are evaluated. Positive values indicate forward
+        # passes and negative values indicate backward passes. Each
+        # entry in sample_args corresponds to one of the forward
+        # passes.
         num_model_chunks = max(_order)
         num_microbatches = len(_order) // num_model_chunks // 2
         assert num_model_chunks * num_microbatches * 2 == len(_order)
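To make the `_order` convention above concrete, here is how the sizes are derived for the schedule used in the new unit test. This is an illustrative walk-through of the assertions in the code, not additional library code:

# _order from the new unit test: 2 model chunks, 3 microbatches.
_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]
num_model_chunks = max(_order)                            # 2
num_microbatches = len(_order) // num_model_chunks // 2   # 12 // 2 // 2 = 3
assert num_model_chunks * num_microbatches * 2 == len(_order)
# sample_kwargs (if given) must contain one dict per sample_args entry.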
@@ -90,10 +119,13 @@ def _make_graphed_callables(
             f"Expected {num_model_chunks * num_microbatches}"
             + f"args tuple, but got {len(sample_args)}."
         )
+        assert len(sample_kwargs) == len(sample_args)

     if fp8_weight_caching:
+        # Initialize flag that controls FP8 weight updates
         FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)

+    # Check callables
     for c in callables:
         if isinstance(c, torch.nn.Module):
             assert (
@@ -110,9 +142,14 @@ def _make_graphed_callables(
                 + ":func:`~make_graphed_callables`, only parameters may be trainable. "
                 + "All buffers must have ``requires_grad=False``."
             )
-    for args in sample_args:
+
+    # Flatten callable arguments
+    per_callable_kwargs_keys = [list(kwargs.keys()) for kwargs in sample_kwargs]
+    flatten_sample_args = []
+    for args, kwargs, kwargs_keys in zip(sample_args, sample_kwargs, per_callable_kwargs_keys):
         flatten_arg, _ = _tree_flatten(args)
-        flatten_sample_args.append(tuple(flatten_arg))
+        flatten_kwarg, _ = _tree_flatten([kwargs[key] for key in kwargs_keys])
+        flatten_sample_args.append(tuple(flatten_arg + flatten_kwarg))
         assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
             "In the beta API, sample_args "
             + "for each callable must contain only Tensors. Other types are not allowed."
@@ -120,6 +157,10 @@ def _make_graphed_callables(

     # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
     # passes to forward (ie, its sample_args) AND the module's parameter attributes.
+    # Note: These per_callable_* variables are not actually
+    # per-callable, but per-forward-pass (see description of _order).
+    # The names are kept for consistency with
+    # torch.cuda.make_graphed_callables.
     per_callable_len_user_args = [len(args) for args in flatten_sample_args]
     if _order is None:
         per_callable_module_params = [
@@ -144,6 +185,7 @@ def _make_graphed_callables(
     fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     graph_callables = [None for _ in range(len(flatten_sample_args))]
+
     # For cases with multiple active RNG states, e.g. TP.
     if graph_safe_rng_available():
         for _, state in get_all_rng_states().items():
@@ -158,11 +200,12 @@ def _make_graphed_callables(
     # from ending up in any captures.
     torch.cuda.synchronize()
     with torch.cuda.stream(torch.cuda.Stream()):
-        for c_i, func in enumerate(callables):
-            args = sample_args[c_i]
-            static_input_surface = per_callable_static_input_surfaces[c_i]
+        for func_idx, func in enumerate(callables):
+            args = sample_args[func_idx]
+            kwargs = sample_kwargs[func_idx]
+            static_input_surface = per_callable_static_input_surfaces[func_idx]
             for _ in range(num_warmup_iters):
-                outputs, _ = _tree_flatten(func(*args))
+                outputs, _ = _tree_flatten(func(*args, **kwargs))
                 grad_inputs = torch.autograd.grad(
                     outputs=tuple(o for o in outputs if o.requires_grad),
                     inputs=tuple(i for i in static_input_surface if i.requires_grad),
@@ -194,9 +237,10 @@ def _make_graphed_callables(
                         fwd_idx[m_chunk] * num_layers + l_no
                     )
                     args = sample_args[per_callable_fwd_idx]
+                    kwargs = sample_kwargs[per_callable_fwd_idx]
                     fwd_graph = fwd_graphs[per_callable_fwd_idx]
                     with torch.cuda.graph(fwd_graph, pool=mempool):
-                        outputs = func(*args)
+                        outputs = func(*args, **kwargs)
                     flatten_outputs, spec = _tree_flatten(outputs)
                     per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
                     per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec
@@ -245,9 +289,9 @@ def _make_graphed_callables(
         per_callable_static_outputs = []
         per_callable_output_unflatten_spec = []
         graph_id = 0
-        for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
+        for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs):
             with torch.cuda.graph(fwd_graph, pool=mempool):
-                outputs = func(*args)
+                outputs = func(*args, **kwargs)
             graph_callables[graph_id] = func
             graph_id += 1
@@ -300,6 +344,7 @@ def _make_graphed_callables(
         fwd_graph,
         bwd_graph,
         module_params,
+        kwargs_keys,
         len_user_args,
         output_unflatten_spec,
         static_input_surface,
@@ -312,14 +357,18 @@ def _make_graphed_callables(
         @staticmethod
         def forward(ctx, skip_fp8_weight_update, *inputs):
-            # At this stage, only the user args may (potentially) be new tensors.
+
+            # Set flag for whether to update FP8 weight updates
             ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
             if ctx.is_first_module and skip_fp8_weight_update is not None:
                 FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)

+            # Copy values from new tensors into static tensors
             for i in range(len_user_args):
                 if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                     static_input_surface[i].copy_(inputs[i])

+            # Replay forward graph
             fwd_graph.replay()
             assert isinstance(static_outputs, tuple)
             return tuple(o.detach() for o in static_outputs)
@@ -327,6 +376,8 @@ def _make_graphed_callables(
         @staticmethod
         @torch.autograd.function.once_differentiable
         def backward(ctx, *grads):
+
+            # Replay backward graph
             assert len(grads) == len(static_grad_outputs)
             for g, grad in zip(static_grad_outputs, grads):
                 if g is not None:
@@ -336,6 +387,7 @@ def _make_graphed_callables(
                         g.copy_(grad)
             bwd_graph.replay()

+            # Update FP8 scale factors if needed
             if ctx.is_first_module:
                 FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
@@ -346,10 +398,8 @@ def _make_graphed_callables(
            )

        def functionalized(*user_args, **user_kwargs):
-           # Runs the autograd function with inputs == all
-           # inputs to the graph that might require grad
-           # (explicit user args + module parameters)
-           # Assumes module params didn't change since capture.
+
+           # Decide whether to update FP8 weights
            skip_fp8_weight_update = None
            if fp8_weight_caching:
                assert "is_first_microbatch" in user_kwargs and isinstance(
@@ -358,8 +408,22 @@ def _make_graphed_callables(
                skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]

+           # Check that required kwargs are provided
+           for key in kwargs_keys:
+               if key not in user_kwargs:
+                   raise TypeError(
+                       f"Graphed callable was initialized with kwarg {key}, "
+                       "but it was not provided in graph replay"
+                   )
+
+           # Runs the autograd function with inputs == all inputs to
+           # the graph that might require grad (explicit user args +
+           # module parameters)
+           # Assumes module params didn't change since capture.
            flatten_user_args, _ = _tree_flatten(user_args)
-           out = Graphed.apply(skip_fp8_weight_update, *(tuple(flatten_user_args) + module_params))
+           flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys])
+           func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params
+           out = Graphed.apply(skip_fp8_weight_update, *func_args)
            return _tree_unflatten(out, output_unflatten_spec)

        return functionalized
@@ -371,6 +435,7 @@ def _make_graphed_callables(
            fwd_graphs[i],
            bwd_graphs[i],
            per_callable_module_params[i],
+           per_callable_kwargs_keys[i],
            per_callable_len_user_args[i],
            per_callable_output_unflatten_spec[i],
            per_callable_static_input_surfaces[i],
@@ -443,25 +508,42 @@ def restore_fp8_tensors(modules, fp8_tensors):


 def make_graphed_callables(
-    modules,
-    sample_args,
-    num_warmup_iters=3,
-    allow_unused_input=False,
-    fp8_enabled=False,
-    fp8_calibrating=False,
-    fp8_recipe=None,
-    fp8_weight_caching=False,
-    _order=None,
-):
+    modules: SingleOrTuple[Callable],
+    sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
+    fp8_enabled: bool = False,
+    fp8_calibrating: bool = False,
+    fp8_recipe: Optional[DelayedScaling] = None,
+    fp8_weight_caching: bool = False,
+    _order: Optional[List[int]] = None,
+) -> Union[Callable, Tuple[Callable, ...]]:
     """
-    A version of PyTorch's `make_graphed_callables` utility function with support for
-    TransformerEngine modules and FP8. Please see the original version in upstream PyTorch
-    `here <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
-    for extensive documentation. The documentation for additional parameters which are
-    specific to FP8 are given below.
-
-    FP8 specific parameters
-    -----------------------
+    Make CUDA graph version of Transformer Engine modules
+
+    A variation of PyTorch's `make_graphed_callables` utility function
+    with support for Transformer Engine modules and FP8. Please see the
+    `original PyTorch implementation <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
+    for more documentation.
+
+    Graphing parameters
+    -------------------
+    modules: (tuple of) callable
+             Callable or callables to graph.
+    sample_args: (tuple of) tuple of torch.Tensor
+                 Positional arguments to callable(s).
+    num_warmup_iters: int, default = 3
+                      Number of warmup iterations.
+    allow_unused_input: bool, default = `False`
+                        Whether to handle case where callable inputs
+                        and outputs are disconnected in compute graph.
+    sample_kwargs: (tuple of) dict, optional
+                   Keyword arguments to callable(s).
+
+    FP8-related parameters
+    ----------------------
     fp8_enabled: bool, default = `False`
                  whether or not to enable fp8
     fp8_calibrating: bool, default = `False`
@@ -478,6 +560,7 @@ def make_graphed_callables(
                        using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg
                        must be set to `False` if calculating weight transposes outside TE, e.g.,
                        in the optimizer step.
+
    """
    set_capture_start()
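A rough usage sketch for the FP8-related options documented above. Module choice, shapes, and the availability of an FP8-capable GPU are illustrative assumptions, not part of this commit:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Illustrative module and input; assumes parameters and data on an FP8-capable GPU.
layer = te.Linear(1024, 1024, device="cuda")
x = torch.randn(32, 1024, device="cuda", requires_grad=True)

graphed = te.make_graphed_callables(
    layer,
    (x,),
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(),
    fp8_weight_caching=True,
)

# With fp8_weight_caching=True, each replay must pass a boolean
# is_first_microbatch kwarg so cached FP8 weights can be reused
# across the microbatches of a training step.
for step_idx in range(4):
    y = graphed(x, is_first_microbatch=(step_idx == 0))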
@@ -532,6 +615,7 @@ def make_graphed_callables(
        num_warmup_iters=num_warmup_iters,
        allow_unused_input=allow_unused_input,
        fp8_weight_caching=fp8_weight_caching,
+       sample_kwargs=sample_kwargs,
        _order=_order,
    )