"cacheflow/vscode:/vscode.git/clone" did not exist on "3be29a1104e15c3bb30ed1d42eda476d2bf9f04e"
Unverified Commit b3c25057 authored by Pingtian Li's avatar Pingtian Li Committed by GitHub
Browse files

[Pytorch] Fix backward_dw cuda graph order (#2376)



* fix backward_dw cuda graph order
Signed-off-by: default avatarPingtian Li <pingtianl@nvidia.com>

* add validation for num_layers_per_chunk
Signed-off-by: default avatarPingtian Li <pingtianl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarPingtian Li <pingtianl@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d52ed471
...@@ -7,6 +7,7 @@ from collections.abc import Iterable ...@@ -7,6 +7,7 @@ from collections.abc import Iterable
import contextlib import contextlib
import gc import gc
import warnings import warnings
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch import torch
...@@ -127,6 +128,8 @@ def _make_graphed_callables( ...@@ -127,6 +128,8 @@ def _make_graphed_callables(
) )
# Check sizes of args # Check sizes of args
_order_without_wgrad = None
delay_wgrad_compute = False
if _order is None: if _order is None:
assert len(sample_args) == len(callables) assert len(sample_args) == len(callables)
assert len(sample_kwargs) == len(callables) assert len(sample_kwargs) == len(callables)
...@@ -145,17 +148,34 @@ def _make_graphed_callables( ...@@ -145,17 +148,34 @@ def _make_graphed_callables(
# values indicate backward passes. Each # values indicate backward passes. Each
# entry in sample_args corresponds to one of the forward # entry in sample_args corresponds to one of the forward
# passes. # passes.
num_model_chunks = max(_order) _order_without_wgrad = []
num_microbatches = len(_order) // num_model_chunks // 2 for c_id in _order:
assert num_model_chunks * num_microbatches * 2 == len(_order) if ceil(c_id) != c_id:
delay_wgrad_compute = True
continue
_order_without_wgrad.append(c_id)
num_model_chunks = max(_order_without_wgrad)
num_microbatches = len(_order_without_wgrad) // num_model_chunks // 2
assert num_model_chunks * num_microbatches * 2 == len(_order_without_wgrad)
# When delay_wgrad_compute is enabled, each layer is treated as a model chunk, which
# allows for fine-grained graph capture order.
if delay_wgrad_compute:
assert (
_num_layers_per_chunk is not None
), "'_num_layers_per_chunk' must be provided when delay_wgrad_compute is True."
for num_layers in _num_layers_per_chunk:
assert (
num_layers == 1
), "Each model chunk must have only one layer when delay_wgrad_compute is True."
# Determine number of layers in each model chunk. # Determine number of layers in each model chunk.
if _num_layers_per_chunk is None: if _num_layers_per_chunk is None:
assert len(sample_args) * 2 >= len(_order) and ( assert len(sample_args) * 2 >= len(_order_without_wgrad) and (
len(sample_args) * 2 % len(_order) == 0 len(sample_args) * 2 % len(_order_without_wgrad) == 0
), ( ), (
f"{len(sample_args)} * 2 >= {len(_order)} and {len(sample_args)} * 2 %" f"{len(sample_args)} * 2 >= {len(_order_without_wgrad)} and {len(sample_args)} * 2"
f" {len(_order)} == 0" f" % {len(_order_without_wgrad)} == 0"
) )
num_layers = len(sample_args) // num_model_chunks // num_microbatches num_layers = len(sample_args) // num_model_chunks // num_microbatches
_num_layers_per_chunk = [num_layers] * num_model_chunks _num_layers_per_chunk = [num_layers] * num_model_chunks
...@@ -175,7 +195,7 @@ def _make_graphed_callables( ...@@ -175,7 +195,7 @@ def _make_graphed_callables(
+ f"entries when order input is provided but got {len(callables)}." + f"entries when order input is provided but got {len(callables)}."
) )
assert len(sample_args) == total_num_layers * num_microbatches, ( assert len(sample_args) == total_num_layers * num_microbatches, (
f"Expected {total_num_layers * num_microbatches}" f"Expected {total_num_layers * num_microbatches} "
+ f"args tuple, but got {len(sample_args)}." + f"args tuple, but got {len(sample_args)}."
) )
...@@ -214,7 +234,7 @@ def _make_graphed_callables( ...@@ -214,7 +234,7 @@ def _make_graphed_callables(
consumed_sample_q = {} consumed_sample_q = {}
fwd_idx = [0] * num_model_chunks fwd_idx = [0] * num_model_chunks
for c_id in _order: for c_id in _order:
m_chunk = abs(c_id) - 1 m_chunk = abs(ceil(c_id)) - 1
if c_id > 0: if c_id > 0:
sample_start_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( sample_start_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
...@@ -241,6 +261,8 @@ def _make_graphed_callables( ...@@ -241,6 +261,8 @@ def _make_graphed_callables(
sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx] sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx]
sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx] sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx]
fwd_idx[m_chunk] += 1 fwd_idx[m_chunk] += 1
elif ceil(c_id) != c_id:
continue
else: else:
num_consumed_samples = min( num_consumed_samples = min(
len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk] len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk]
...@@ -477,9 +499,11 @@ def _make_graphed_callables( ...@@ -477,9 +499,11 @@ def _make_graphed_callables(
fwd_idx = [0] * num_model_chunks fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks
static_grad_outputs_dict = {} static_grad_outputs_dict = {}
wgrad_validation_list = [None] * len(_order)
previous_chunk_last_callable_bwd_idx = None previous_chunk_last_callable_bwd_idx = None
for c_id in _order: for i, c_id in enumerate(_order):
if c_id > 0: if c_id > 0:
assert isinstance(c_id, int), "Forward order value must be an integer."
# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
m_chunk = c_id - 1 m_chunk = c_id - 1
for l_no in range(_num_layers_per_chunk[m_chunk]): for l_no in range(_num_layers_per_chunk[m_chunk]):
...@@ -499,12 +523,65 @@ def _make_graphed_callables( ...@@ -499,12 +523,65 @@ def _make_graphed_callables(
fwd_idx[m_chunk] += 1 fwd_idx[m_chunk] += 1
else: else:
# Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
m_chunk = -c_id - 1 m_chunk = -ceil(c_id) - 1
previous_per_callable_bwd_idx = None previous_per_callable_bwd_idx = None
for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))): for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
) )
if ceil(c_id) == c_id and need_bwd_dw_graph[per_callable_bwd_idx]:
# Check if bwd graph has corresponding wgrad graph:
# Number of dgrad backward graphs should be equal to number of
# wgrad backward graphs.
# Note: For MCore, the validation rule is more strict (the next backward
# of dgrad graph must be corresponding wgrad graph).
if wgrad_validation_list[i] is None:
same_bwd_c_id_list = [i]
num_wgrad_c_id = 0
for idx in range(i + 1, len(_order)):
if _order[idx] > 0:
continue
if _order[idx] == c_id:
same_bwd_c_id_list.append(idx)
if _order[idx] + 0.5 == c_id:
num_wgrad_c_id += 1
if len(same_bwd_c_id_list) == num_wgrad_c_id:
for same_c_id_idx in same_bwd_c_id_list:
wgrad_validation_list[same_c_id_idx] = True
break
if len(same_bwd_c_id_list) < num_wgrad_c_id:
# It's impossible to have more wgrad than dgrad.
wgrad_validation_list[i] = False
break
if wgrad_validation_list[i] is None:
wgrad_validation_list[i] = False
assert wgrad_validation_list[i], (
f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
)
elif ceil(c_id) != c_id:
per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
assert is_training, "Only training mode supports backward_dw."
# If no one module needs the backward_dw, the bwd_dw_graph will be empty.
# So skip capturing it. For backward_dw, the order value is c_id - 0.5 to indicate
# the specific order of backward_dw.
assert ceil(c_id) - c_id == 0.5, (
"The order diff of wgrad and dgrad must be 0.5, "
f"get {ceil(c_id) - c_id}."
)
assert need_bwd_dw_graph[
per_callable_bwd_idx
], "No module needs wgrad computation but get float in order"
bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[per_callable_bwd_idx]:
if (
hasattr(module, "need_backward_dw")
and module.need_backward_dw()
):
module.backward_dw()
continue
static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx] static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]
static_outputs = per_callable_static_outputs[per_callable_bwd_idx] static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx] bwd_graph = bwd_graphs[per_callable_bwd_idx]
...@@ -537,17 +614,6 @@ def _make_graphed_callables( ...@@ -537,17 +614,6 @@ def _make_graphed_callables(
allow_unused=allow_unused_input, allow_unused=allow_unused_input,
retain_graph=retain_graph_in_backward, retain_graph=retain_graph_in_backward,
) )
# If no one module needs the backward_dw, the bwd_dw_graph will be empty.
# So skip capturing it.
if need_bwd_dw_graph[per_callable_bwd_idx]:
bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[per_callable_bwd_idx]:
if (
hasattr(module, "need_backward_dw")
and module.need_backward_dw()
):
module.backward_dw()
# Constructs a tuple suitable for returning from Graphed.backward: # Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs # Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern. # that don't require grad. I couldn't think of a one-liner for this pattern.
...@@ -596,8 +662,8 @@ def _make_graphed_callables( ...@@ -596,8 +662,8 @@ def _make_graphed_callables(
per_callable_static_grad_inputs[idx] per_callable_static_grad_inputs[idx]
) )
previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx
if ceil(c_id) == c_id:
bwd_idx[m_chunk] += 1 bwd_idx[m_chunk] += 1
else: else:
# Capture forward graphs # Capture forward graphs
per_callable_static_outputs = [] per_callable_static_outputs = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment