"projects/vscode:/vscode.git/clone" did not exist on "a9bc082ce95f5b2b976d17acea9afd0b425a3c04"
Unverified commit 4cc220c9 authored by 李金梁, committed by GitHub

Fix bug in attention backward for non-causal models with context parallelism enabled (#1031)



This bug causes the run to crash with `[ERROR] failed (exitcode: -11) local_rank: 0 (pid: 1761020) of binary: ~/megatron/bin/python`.

The cause is that the backward pass misses the `rng_states` required for attention recompute (for dropout), and the resulting segfault provides no hint about it.

It was extremely difficult to trace and cost me two weeks.
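To see why the missing RNG state matters: dropout draws its mask from the CUDA RNG, so replaying the forward pass during backward only reproduces the original activations if the RNG state saved at forward time is restored first. Below is a minimal standalone sketch of that invariant using plain `torch.nn.functional.dropout` and `torch.cuda.get_rng_state`/`set_rng_state`; it is not the TransformerEngine code path, only the underlying principle.

```python
# Minimal sketch (requires a CUDA device). Not the TransformerEngine code:
# it only demonstrates that dropout recompute is deterministic if and only
# if the RNG state saved at forward time is restored before replaying.
import torch
import torch.nn.functional as F

x = torch.randn(64, 64, device="cuda")

# Forward: capture the CUDA RNG state *before* the dropout draw
# (analogous to the per-step states AttnFuncWithCP keeps in rng_states).
saved_rng_state = torch.cuda.get_rng_state()
out_forward = F.dropout(x, p=0.5)

# Recompute WITHOUT restoring the state: a fresh mask is drawn, so the
# recomputed activations (and hence the gradients) no longer match.
out_bad = F.dropout(x, p=0.5)
print(torch.equal(out_forward, out_bad))   # False (with overwhelming probability)

# Recompute AFTER restoring the state: the identical mask is drawn.
torch.cuda.set_rng_state(saved_rng_state)
out_good = F.dropout(x, p=0.5)
print(torch.equal(out_forward, out_good))  # True
```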

```text
[before the start of training step] datetime: 2024-07-22 18:26:45
[2024-07-22 18:27:00,941] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -11) local_rank: 0 (pid: 1761020) of binary: /home//miniconda3/envs/megatron/bin/python
Traceback (most recent call last):
  File "/home//miniconda3/envs/megatron/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.2.1+cu121', 'console_scripts', 'torchrun')())
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
```
Signed-off-by: 李金梁 <975761915@qq.com>
parent 1aaf1cc8
```diff
@@ -2132,6 +2132,7 @@ class AttnFuncWithCP(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             False,
+            rng_state=rng_states[cp_size - i - 1],
             **fa_optional_backward_kwargs,
         )
```
...
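For readers tracing the index: the context-parallel attention forward appears to save one RNG state per communication step, and the backward walks those steps in reverse order, which would explain the `rng_states[cp_size - i - 1]` indexing. (This reading is inferred from the diff, not from the surrounding source.)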