Unverified Commit 94f54d71 authored by Xiaowei Ren, committed by GitHub

[PyTorch] upgrade context parallelism implementations (#572)



* try to use cuDNN fused attention for context parallelism
Signed-off-by: xren <xren@nvidia.com>

* assert CP is only supported with NVTE_F16_arbitrary_seqlen
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* port fused attn api to context parallelism
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add one more assert
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert CP does not support padded tokens
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add qkv_format into CP implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove qkv_format from CP function
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix qkv_format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bwd error with FA v2
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make CP implementation support non-causal masking
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant asserts for CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor assert information change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert that core attn bias is not supported with CP yet
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make CP work with window_sizes of [-1, -1] and [-1, 0]
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add draft code for FA test with CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* move fused attn test to a specific folder
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add assert_close to flash attn CP test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add more tests for CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add optional arguments for FA v2.4+
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add skip condition for CP test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* class and function naming fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* docstring fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not use fused attn if the backend does not work with CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* create a separate folder for the CP test as it needs multiple GPUs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add attn_mask_type check in attn_forward_func_with_cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code format fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------

Signed-off-by: xren <xren@nvidia.com>
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent bb759adc
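The CP implementation exercised here splits each sequence into 2*cp_size chunks along the sequence dimension and assigns rank i the chunk pair (i, 2*cp_size-1-i), so that causal masking gives every rank a roughly equal share of the attention work; the tests below shard their inputs the same way. A minimal standalone sketch of that sharding (shard_for_cp is an illustrative name, not a Transformer Engine API):

import torch

def shard_for_cp(x, seq_dim, rank, cp_size):
    """Keep sequence chunks [rank, 2*cp_size-1-rank] of x (illustrative helper)."""
    # split the sequence dimension into 2*cp_size equal chunks
    x = x.view(*x.shape[:seq_dim], 2 * cp_size, x.shape[seq_dim] // (2 * cp_size),
               *x.shape[seq_dim + 1:])
    # take one chunk from the front and its mirror chunk from the back
    idx = torch.tensor([rank, 2 * cp_size - rank - 1], device=x.device)
    x = x.index_select(seq_dim, idx)
    # merge the two chunks back into one contiguous local sequence
    return x.view(*x.shape[:seq_dim], -1, *x.shape[seq_dim + 2:])

# with seqlen=16 and cp_size=2, rank 0 holds tokens [0..3] and [12..15]
tokens = torch.arange(16).view(1, 16, 1, 1)  # bshd layout: b=1, s=16, h=1, d=1
print(shard_for_cp(tokens, seq_dim=1, rank=0, cp_size=2).flatten().tolist())
# -> [0, 1, 2, 3, 12, 13, 14, 15]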
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from test_fused_attn_with_cp import model_configs
dtypes = {'fp16': torch.float16, 'bf16': torch.bfloat16}

def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend='FlashAttention'):
    """Test DotProductAttention module with context parallelism."""
    # select the attention backend under test via TE environment variables
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if kernel_backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if kernel_backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))
    if dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
    else:
        device_count = torch.cuda.device_count()
        device = rank % device_count
        torch.cuda.set_device(device)
        print(f"[INFO] world_size:{world_size}, rank:{rank}")
        dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)

    # create communication group for CP
    cp_comm_ranks = range(world_size)
    assert rank in cp_comm_ranks
    cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl')

    config = model_configs[model]
    assert config.attn_mask_type in ['causal', 'no_mask'], \
        f"{config.attn_mask_type} is an unsupported attention mask type!"
    # instantiate core attn module
    core_attn = DotProductAttention(config.num_heads,
                                    config.head_dim,
                                    num_gqa_groups=config.num_gqa_groups,
                                    attention_dropout=config.dropout_p,
                                    qkv_format=qkv_format,
                                    attn_mask_type=config.attn_mask_type)
    core_attn = core_attn.cuda()
    # create flash attn inputs
    if qkv_format == "bshd":
        q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
        kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim)
        attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim)
    elif qkv_format == "sbhd":
        q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
        kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim)
        attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads * config.head_dim)
    else:
        assert False, f"{qkv_format} is an unsupported qkv_format!"
    q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
    k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
    v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
    dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
    # make sure all GPU ranks have the same inputs
    for x in [q, k, v, dout]:
        dist.broadcast(x, 0, group=cp_comm_group)

    # run core_attn without CP
    for x in [q, k, v]:
        x.requires_grad = True
    out = core_attn(q, k, v)
    out.backward(dout)

    # run core_attn with CP
    q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
    seq_dim = qkv_format.index('s')
    # shard the sequence dimension: split into 2*world_size chunks and keep
    # chunks [rank, 2*world_size-1-rank] to balance work under causal masking
    q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2 * world_size,
                                x.shape[seq_dim] // (2 * world_size),
                                *x.shape[(seq_dim + 1):])
                         for x in [q_, k_, v_, dout_]]
    seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
    q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
    q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2):])
                         for x in [q_, k_, v_, dout_]]
    q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
    core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
    out_ = core_attn(q_, k_, v_)
    out_.backward(dout_)
    # sanity-check the CP results before comparing against the reference
    for x in [out_, q_.grad, k_.grad, v_.grad]:
        assert torch.all(~torch.isnan(x))
        assert torch.all(~torch.isinf(x))
    # compare results with and without CP; each rank holds two sequence
    # chunks, so compare each chunk against the same slice of the reference
    tols = dict(atol=5e-3, rtol=5e-3)
    if dtype == 'bf16':
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2 * world_size,
                              x.shape[seq_dim] // (2 * world_size),
                              *x.shape[(seq_dim + 1):])
                       for x in [q.grad, k.grad, v.grad, out]]
    dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
    dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2,
                                  *x.shape[(seq_dim + 1):])
                           for x in [q_.grad, k_.grad, v_.grad, out_]]
    if qkv_format == "bshd":
        torch.testing.assert_close(out_[:, 0], out[:, 0], **tols)
        torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols)
        torch.testing.assert_close(dk_[:, 0], dk[:, 0], **tols)
        torch.testing.assert_close(dv_[:, 0], dv[:, 0], **tols)
        torch.testing.assert_close(out_[:, 1], out[:, 1], **tols)
        torch.testing.assert_close(dq_[:, 1], dq[:, 1], **tols)
        torch.testing.assert_close(dk_[:, 1], dk[:, 1], **tols)
        torch.testing.assert_close(dv_[:, 1], dv[:, 1], **tols)
    elif qkv_format == "sbhd":
        torch.testing.assert_close(out_[0], out[0], **tols)
        torch.testing.assert_close(dq_[0], dq[0], **tols)
        torch.testing.assert_close(dk_[0], dk[0], **tols)
        torch.testing.assert_close(dv_[0], dv[0], **tols)
        torch.testing.assert_close(out_[1], out[1], **tols)
        torch.testing.assert_close(dq_[1], dq[1], **tols)
        torch.testing.assert_close(dk_[1], dk[1], **tols)
        torch.testing.assert_close(dv_[1], dv[1], **tols)
    else:
        assert False, f"{qkv_format} is an unsupported qkv_format!"
def main(**kwargs):
    run_dpa_with_cp(**kwargs)

if __name__ == "__main__":
    # sys.argv[1] is the local-rank argument injected by torch.distributed.launch,
    # so only the trailing key=value pairs are parsed into kwargs
    kwargs = dict(arg.split('=') for arg in sys.argv[2:])
    main(**kwargs)
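Note that this script is not collected by pytest directly: the test file below launches it through torch.distributed.launch with two processes and passes dtype, model, qkv_format, and kernel_backend as trailing key=value arguments.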
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import pytest
import subprocess
from test_fused_attn import (
    ModelConfig,
    _is_flash_attention_2_available,
    _cudnn_version,
)

model_configs = {
    # test:   b,  h, hg,   d,    sq,   skv,   p,     mask,     bias
    "cp_1_0": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "causal", "no_bias"),   # MHA
    "cp_1_1": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "no_mask", "no_bias"),  # MHA
    "cp_2_0": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "causal", "no_bias"),    # GQA
    "cp_2_1": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "no_mask", "no_bias"),   # GQA
}
def get_bash_arguments(**kwargs):
    # launch two processes (one per GPU) and pass test options as key=value args
    args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"]
    te_path = os.getenv("TE_PATH", "/opt/transformerengine")
    script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
    args.append(script_path)
    for k, v in kwargs.items():
        args.append(f"{k}={v}")
    return args
@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
def test_cp_with_flash_attention(dtype, model, qkv_format):
    subprocess.run(
        get_bash_arguments(
            dtype=dtype,
            model=model,
            qkv_format=qkv_format,
            kernel_backend='FlashAttention',
        ),
        check=True,
    )
@pytest.mark.skipif(_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
def test_cp_with_fused_attention(dtype, model, qkv_format):
    subprocess.run(
        get_bash_arguments(
            dtype=dtype,
            model=model,
            qkv_format=qkv_format,
            kernel_backend='FusedAttention',
        ),
        check=True,
    )
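Both test functions shell out to torch.distributed.launch with --nproc-per-node=2, so at least two GPUs are required, and subprocess.run(check=True) turns a failure on any rank into a test failure.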