"vscode:/vscode.git/clone" did not exist on "64ca424e7518d1176fff032b76d85e49a4fc936a"
Unverified Commit b3c25057 authored by Pingtian Li's avatar Pingtian Li Committed by GitHub
Browse files

[Pytorch] Fix backward_dw cuda graph order (#2376)



* fix backward_dw cuda graph order
Signed-off-by: default avatarPingtian Li <pingtianl@nvidia.com>

* add validation for num_layers_per_chunk
Signed-off-by: default avatarPingtian Li <pingtianl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarPingtian Li <pingtianl@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d52ed471
......@@ -7,6 +7,7 @@ from collections.abc import Iterable
import contextlib
import gc
import warnings
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch
......@@ -127,6 +128,8 @@ def _make_graphed_callables(
)
# Check sizes of args
_order_without_wgrad = None
delay_wgrad_compute = False
if _order is None:
assert len(sample_args) == len(callables)
assert len(sample_kwargs) == len(callables)
......@@ -145,17 +148,34 @@ def _make_graphed_callables(
# values indicate backward passes. Each
# entry in sample_args corresponds to one of the forward
# passes.
num_model_chunks = max(_order)
num_microbatches = len(_order) // num_model_chunks // 2
assert num_model_chunks * num_microbatches * 2 == len(_order)
_order_without_wgrad = []
for c_id in _order:
if ceil(c_id) != c_id:
delay_wgrad_compute = True
continue
_order_without_wgrad.append(c_id)
num_model_chunks = max(_order_without_wgrad)
num_microbatches = len(_order_without_wgrad) // num_model_chunks // 2
assert num_model_chunks * num_microbatches * 2 == len(_order_without_wgrad)
# When delay_wgrad_compute is enabled, each layer is treated as a model chunk, which
# allows for fine-grained graph capture order.
if delay_wgrad_compute:
assert (
_num_layers_per_chunk is not None
), "'_num_layers_per_chunk' must be provided when delay_wgrad_compute is True."
for num_layers in _num_layers_per_chunk:
assert (
num_layers == 1
), "Each model chunk must have only one layer when delay_wgrad_compute is True."
# Determine number of layers in each model chunk.
if _num_layers_per_chunk is None:
assert len(sample_args) * 2 >= len(_order) and (
len(sample_args) * 2 % len(_order) == 0
assert len(sample_args) * 2 >= len(_order_without_wgrad) and (
len(sample_args) * 2 % len(_order_without_wgrad) == 0
), (
f"{len(sample_args)} * 2 >= {len(_order)} and {len(sample_args)} * 2 %"
f" {len(_order)} == 0"
f"{len(sample_args)} * 2 >= {len(_order_without_wgrad)} and {len(sample_args)} * 2"
f" % {len(_order_without_wgrad)} == 0"
)
num_layers = len(sample_args) // num_model_chunks // num_microbatches
_num_layers_per_chunk = [num_layers] * num_model_chunks
......@@ -175,7 +195,7 @@ def _make_graphed_callables(
+ f"entries when order input is provided but got {len(callables)}."
)
assert len(sample_args) == total_num_layers * num_microbatches, (
f"Expected {total_num_layers * num_microbatches}"
f"Expected {total_num_layers * num_microbatches} "
+ f"args tuple, but got {len(sample_args)}."
)
......@@ -214,7 +234,7 @@ def _make_graphed_callables(
consumed_sample_q = {}
fwd_idx = [0] * num_model_chunks
for c_id in _order:
m_chunk = abs(c_id) - 1
m_chunk = abs(ceil(c_id)) - 1
if c_id > 0:
sample_start_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
......@@ -241,6 +261,8 @@ def _make_graphed_callables(
sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx]
sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx]
fwd_idx[m_chunk] += 1
elif ceil(c_id) != c_id:
continue
else:
num_consumed_samples = min(
len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk]
......@@ -477,9 +499,11 @@ def _make_graphed_callables(
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
static_grad_outputs_dict = {}
wgrad_validation_list = [None] * len(_order)
previous_chunk_last_callable_bwd_idx = None
for c_id in _order:
for i, c_id in enumerate(_order):
if c_id > 0:
assert isinstance(c_id, int), "Forward order value must be an integer."
# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
m_chunk = c_id - 1
for l_no in range(_num_layers_per_chunk[m_chunk]):
......@@ -499,12 +523,65 @@ def _make_graphed_callables(
fwd_idx[m_chunk] += 1
else:
# Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
m_chunk = -c_id - 1
m_chunk = -ceil(c_id) - 1
previous_per_callable_bwd_idx = None
for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
)
if ceil(c_id) == c_id and need_bwd_dw_graph[per_callable_bwd_idx]:
# Check if bwd graph has corresponding wgrad graph:
# Number of dgrad backward graphs should be equal to number of
# wgrad backward graphs.
# Note: For MCore, the validation rule is more strict (the next backward
# of dgrad graph must be corresponding wgrad graph).
if wgrad_validation_list[i] is None:
same_bwd_c_id_list = [i]
num_wgrad_c_id = 0
for idx in range(i + 1, len(_order)):
if _order[idx] > 0:
continue
if _order[idx] == c_id:
same_bwd_c_id_list.append(idx)
if _order[idx] + 0.5 == c_id:
num_wgrad_c_id += 1
if len(same_bwd_c_id_list) == num_wgrad_c_id:
for same_c_id_idx in same_bwd_c_id_list:
wgrad_validation_list[same_c_id_idx] = True
break
if len(same_bwd_c_id_list) < num_wgrad_c_id:
# It's impossible to have more wgrad than dgrad.
wgrad_validation_list[i] = False
break
if wgrad_validation_list[i] is None:
wgrad_validation_list[i] = False
assert wgrad_validation_list[i], (
f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
)
elif ceil(c_id) != c_id:
per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
assert is_training, "Only training mode supports backward_dw."
# If no one module needs the backward_dw, the bwd_dw_graph will be empty.
# So skip capturing it. For backward_dw, the order value is c_id - 0.5 to indicate
# the specific order of backward_dw.
assert ceil(c_id) - c_id == 0.5, (
"The order diff of wgrad and dgrad must be 0.5, "
f"get {ceil(c_id) - c_id}."
)
assert need_bwd_dw_graph[
per_callable_bwd_idx
], "No module needs wgrad computation but get float in order"
bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[per_callable_bwd_idx]:
if (
hasattr(module, "need_backward_dw")
and module.need_backward_dw()
):
module.backward_dw()
continue
static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]
static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx]
......@@ -537,17 +614,6 @@ def _make_graphed_callables(
allow_unused=allow_unused_input,
retain_graph=retain_graph_in_backward,
)
# If no one module needs the backward_dw, the bwd_dw_graph will be empty.
# So skip capturing it.
if need_bwd_dw_graph[per_callable_bwd_idx]:
bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[per_callable_bwd_idx]:
if (
hasattr(module, "need_backward_dw")
and module.need_backward_dw()
):
module.backward_dw()
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
......@@ -596,8 +662,8 @@ def _make_graphed_callables(
per_callable_static_grad_inputs[idx]
)
previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx
bwd_idx[m_chunk] += 1
if ceil(c_id) == c_id:
bwd_idx[m_chunk] += 1
else:
# Capture forward graphs
per_callable_static_outputs = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment