Unverified Commit 06947e87 authored by buptzyb's avatar buptzyb Committed by GitHub
Browse files

[PyTorch] Fix cudagraph static_input and static_grad_input reuse (#2018)



* fix graph static grad input reuse
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent c3f8a9f5
...@@ -179,24 +179,17 @@ def _make_graphed_callables( ...@@ -179,24 +179,17 @@ def _make_graphed_callables(
assert isinstance( assert isinstance(
sample_args, list sample_args, list
), "sample_args must be a list for _reuse_graph_input_output_buffers." ), "sample_args must be a list for _reuse_graph_input_output_buffers."
len_args = len(sample_args[0])
for i, arg in enumerate(sample_args):
assert len_args == len(
arg
), "Arguments must have same length and shape for `_reuse_graph_input_output_buffers`."
len_kwargs = len(sample_kwargs[0])
assert isinstance(
sample_kwargs, list
), "sample_kwargs must be a list for _reuse_graph_input_output_buffers."
for i, kwarg in enumerate(sample_kwargs):
assert len_kwargs == len(kwarg), (
"Keyword arguments must have same length and shape for"
" `_reuse_graph_input_output_buffers`."
)
# Reorganize args and kwargs for input tensor reuse. # Reorganize args and kwargs for input tensor reuse.
# fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples.
# Each tuple contains the sample key signature and its fwd_idx. When we finish a backward
# chunk, we pop the corresponding fwd_idx and push to the consumed_sample_q.
# consumed_sample_q is keyed by the sample key signature. The value is a queue of the
# fwd_idx whose backward has been called so that we can reuse the same static buffers.
# In this way, we can reuse the same static input buffers for the non-overlapping samples
# with the same input signature.
fwd_sample_qs = {} fwd_sample_qs = {}
consumed_sample_q = [] consumed_sample_q = {}
fwd_idx = [0] * num_model_chunks fwd_idx = [0] * num_model_chunks
for c_id in _order: for c_id in _order:
m_chunk = abs(c_id) - 1 m_chunk = abs(c_id) - 1
...@@ -208,10 +201,21 @@ def _make_graphed_callables( ...@@ -208,10 +201,21 @@ def _make_graphed_callables(
fwd_sample_idx = [ fwd_sample_idx = [
sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk]) sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk])
] ]
fwd_sample_qs[m_chunk] = fwd_sample_qs.get(m_chunk, []) + fwd_sample_idx if m_chunk not in fwd_sample_qs:
fwd_sample_qs[m_chunk] = []
for per_callable_fwd_idx in fwd_sample_idx: for per_callable_fwd_idx in fwd_sample_idx:
if consumed_sample_q: sample_args_keys = tuple(
reuse_fwd_idx = consumed_sample_q.pop(0) (t.shape, t.dtype, t.layout) for t in sample_args[per_callable_fwd_idx]
)
sample_kwargs_keys = tuple(
(k, v.shape, v.dtype, v.layout)
for k, v in sorted(sample_kwargs[per_callable_fwd_idx].items())
)
sample_keys = sample_args_keys + sample_kwargs_keys
fwd_sample_qs[m_chunk].append((sample_keys, per_callable_fwd_idx))
if consumed_sample_q.get(sample_keys, []):
reuse_fwd_idx = consumed_sample_q[sample_keys].pop(0)
sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx] sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx]
sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx] sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx]
fwd_idx[m_chunk] += 1 fwd_idx[m_chunk] += 1
...@@ -219,7 +223,12 @@ def _make_graphed_callables( ...@@ -219,7 +223,12 @@ def _make_graphed_callables(
num_consumed_samples = min( num_consumed_samples = min(
len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk] len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk]
) )
consumed_sample_q += fwd_sample_qs[m_chunk][:num_consumed_samples] for sample_keys, per_callable_fwd_idx in fwd_sample_qs[m_chunk][
:num_consumed_samples
]:
if sample_keys not in consumed_sample_q:
consumed_sample_q[sample_keys] = []
consumed_sample_q[sample_keys].append(per_callable_fwd_idx)
fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:] fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:]
if fp8_weight_caching: if fp8_weight_caching:
...@@ -423,7 +432,7 @@ def _make_graphed_callables( ...@@ -423,7 +432,7 @@ def _make_graphed_callables(
fwd_idx = [0] * num_model_chunks fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks
static_grad_outputs_dict = {} static_grad_outputs_dict = {}
previous_per_callable_bwd_idx = None previous_chunk_last_callable_bwd_idx = None
for c_id in _order: for c_id in _order:
if c_id > 0: if c_id > 0:
# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
...@@ -446,6 +455,7 @@ def _make_graphed_callables( ...@@ -446,6 +455,7 @@ def _make_graphed_callables(
else: else:
# Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
m_chunk = -c_id - 1 m_chunk = -c_id - 1
previous_per_callable_bwd_idx = None
for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))): for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
...@@ -508,19 +518,29 @@ def _make_graphed_callables( ...@@ -508,19 +518,29 @@ def _make_graphed_callables(
per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref( per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref(
static_outputs static_outputs
) )
# Weak ref the static grad inputs of the previous backward pass.
# Note: After a backward pass, we assume Mcore will send the # Weak ref the static grad inputs of the previous backward pass within the
# grad input to another pipeline parallel rank and that the # same chunk.
# communication is finished before the end of the next backward
# pass.
if previous_per_callable_bwd_idx is not None: if previous_per_callable_bwd_idx is not None:
per_callable_static_grad_inputs[previous_per_callable_bwd_idx] = ( idx = previous_per_callable_bwd_idx
make_weak_ref( per_callable_static_grad_inputs[idx] = make_weak_ref(
per_callable_static_grad_inputs[previous_per_callable_bwd_idx] per_callable_static_grad_inputs[idx]
)
) )
previous_per_callable_bwd_idx = per_callable_bwd_idx previous_per_callable_bwd_idx = per_callable_bwd_idx
# Weak ref the static grad inputs of the previous chunk's last backward
# pass.
# Note: After a chunk's backward pass, we assume Mcore will send the grad
# input to another pipeline parallel rank and that the communication is
# finished before the end of the next chunk's backward pass.
if l_no == 0:
if previous_chunk_last_callable_bwd_idx is not None:
idx = previous_chunk_last_callable_bwd_idx
per_callable_static_grad_inputs[idx] = make_weak_ref(
per_callable_static_grad_inputs[idx]
)
previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx
bwd_idx[m_chunk] += 1 bwd_idx[m_chunk] += 1
else: else:
# Capture forward graphs # Capture forward graphs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment