Unverified commit 24024061, authored by Alp Dener and committed by GitHub

[PyTorch] Adding TP overlap support for `te.Linear` with `parallel_mode="column"` (#1343)



* support AG overlap in sequence-parallel Linear forward and RS overlap in sequence-parallel Linear backward
Signed-off-by: Alp Dener <adener@nvidia.com>

* implemented TP overlap support for column-parallel te.Linear
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed backward pass for te.Linear column-parallel with TP overlap, updated unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* improved error messages for internal failure to infer TP overlap options in te.Linear
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect TP overlap option asserts
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cbc46531
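
For context, a minimal usage sketch of the column-parallel layout this PR enables. The te.Linear flags and the initialize_ub call mirror the diff below; the sizes, process-group setup, and the assumption that the tensor-parallel group spans all ranks are illustrative placeholders, not part of the change.

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Illustrative sizes only
seq_length, batch_size, hidden_size = 512, 2, 768

dist.init_process_group(backend="nccl")
tp_group = dist.new_group()  # assumed: TP group spans all ranks
tp_size = dist.get_world_size(tp_group)

# Userbuffers must be initialized before constructing overlap-enabled layers
# (call signature as used in the updated test script below).
te.module.base.initialize_ub(
    [seq_length * batch_size, hidden_size],
    tp_size,
    use_fp8=False,
    dtype=torch.bfloat16,
)

# Column-parallel te.Linear with comm+GEMM overlap:
#   ub_overlap_ag               -> all-gather of the sequence-sharded input overlapped with the fprop GEMM
#   ub_bulk_dgrad/ub_bulk_wgrad -> bulk overlaps of the input AG with DGRAD and the dgrad RS with WGRAD
#   (ub_overlap_rs instead fuses a pipelined reduce-scatter with the DGRAD GEMM)
qkv = te.Linear(
    hidden_size,
    3 * hidden_size,
    parallel_mode="column",
    sequence_parallel=True,
    tp_group=tp_group,
    tp_size=tp_size,
    ub_overlap_ag=True,
    ub_bulk_dgrad=True,
    ub_bulk_wgrad=True,
    ub_name="qkv",
)
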
@@ -51,15 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
kwargs["ub_overlap_ag"] = not reference
if config.layer_type is te.Linear:
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["parallel_mode"] = "row"
kwargs["ub_overlap_rs"] = not reference
kwargs["ub_name"] = "proj"
if config.linear_parallel_mode == "row":
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["ub_overlap_rs"] = not reference
elif config.linear_parallel_mode == "column":
input_shape[0] = config.seq_length // tp_size
args.append(3 * hidden_size)
kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["parallel_mode"] = config.linear_parallel_mode
kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv"
else:
input_shape[0] = config.seq_length // tp_size
kwargs["ub_bulk_wgrad"] = not reference
kwargs["ub_bulk_dgrad"] = not reference
kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
if config.layer_type is te.LayerNormLinear:
args.append(3 * hidden_size)
kwargs["parallel_mode"] = "column"
@@ -125,6 +133,19 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
)
parser.add_argument(
"--linear-parallel-mode",
type=str.lower,
default="row",
choices=["row", "column"],
help="Parallel mode for te.Linear.",
)
parser.add_argument(
"--overlap-rs-dgrad",
action="store_true",
default=False,
help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.",
)
parser.add_argument(
"--debug",
action="store_true",
@@ -230,12 +251,19 @@ def _train(opts):
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
# Initialize userbuffers
ub_cfgs = None
if opts.overlap_rs_dgrad:
ub_cfgs = {
"proj_dgrad": {"method": "ring_exchange"},
"qkv_dgrad": {"method": "ring_exchange"},
}
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
WORLD_SIZE,
use_fp8=opts.fp8,
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
ub_cfgs=ub_cfgs,
)
# Initialize the Transformer Engine layer with overlap
@@ -314,27 +342,29 @@ def _train(opts):
ref_grads.append(ref_param.grad)
# Make sure we have the same number of gradients
numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
if len(test_grads) != len(ref_grads):
numerics_failed[0] = 1
num_grads_failed[0] = 1
numerics_info = (
"NUMERICAL CHECK FAILED: Incorrect number of gradients, "
+ f"expected {len(ref_grads)} but got {len(test_grads)}."
)
dist_print(numerics_info, src=WORLD_RANK, error=True)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world)
# Now validate accuracy
if not bool(numerics_failed.item()):
numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda")
if not bool(num_grads_failed.item()):
for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
rtol = 0.125 if opts.fp8 else 0.025
atol = 0.0625 if opts.fp8 else 0.00125
grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
numerics_failed[0] = int(grad_failed)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
if bool(numerics_failed.item()):
break
numerics_failed[i] = int(grad_failed)
return_code = torch.max(numerics_failed)
dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world)
else:
return_code = num_grads_failed
te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True)
@@ -344,7 +374,7 @@ def _train(opts):
if opts.debug and WORLD_RANK == 0:
print("Exiting...\n", end="", flush=True)
return numerics_failed[0].item()
return return_code.item()
if __name__ == "__main__":
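
To restate the branching in _get_layer_args above in one place: for the column-parallel te.Linear case, the test driver resolves the CLI options into the constructor kwargs sketched below. This is an illustrative condensation, not code from the diff; the helper name is hypothetical.

def _column_parallel_linear_kwargs(overlap_rs_dgrad: bool, reference: bool = False) -> dict:
    """Illustrative condensation of the te.Linear column-parallel branch in _get_layer_args."""
    return {
        "parallel_mode": "column",
        "ub_name": "qkv",
        # All-gather of the sequence-sharded input overlapped with the fprop GEMM
        "ub_overlap_ag": not reference,
        # Pipelined reduce-scatter fused with the DGRAD GEMM (enabled by --overlap-rs-dgrad)
        "ub_overlap_rs": overlap_rs_dgrad and not reference,
        # Otherwise, bulk overlaps: input AG during DGRAD and dgrad RS during WGRAD
        "ub_bulk_dgrad": not overlap_rs_dgrad and not reference,
        "ub_bulk_wgrad": not overlap_rs_dgrad and not reference,
    }
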
@@ -21,8 +21,10 @@ SEQ_LENGTH: int = 512
BATCH_SIZE: int = 2
NUM_HEADS: int = 12
HEAD_DIM: int = 64
# NOTE: te.Linear is intentionally omitted here and manually added later for testing both
# row and column parallel layouts.
TE_LAYERS = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
@@ -86,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg
raise AssertionError(result.stderr.decode())
def _run_layer_with_overlap(layer_type, fp8, fp8_init):
def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
test_path = TEST_ROOT / "run_layer_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
@@ -97,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init):
f"--head-dim={HEAD_DIM}",
f"--layer-type={layer_type}",
]
if layer_type == te.Linear.__name__:
test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}")
if fp8:
if not fp8_available:
@@ -245,9 +249,15 @@ def test_bulk_overlaps(comm_type, fp8, connections):
@pytest.mark.parametrize(
"layer_type",
[layer.__name__ for layer in TE_LAYERS],
ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS],
"layer_type,linear_parallel_mode",
(
[(te.Linear.__name__, "row"), (te.Linear.__name__, "column")]
+ list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))]))
),
ids=(
[f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "]
+ [(" " + layer.__name__ + " ") for layer in TE_LAYERS]
),
)
@pytest.mark.parametrize(
"fp8,fp8_init",
@@ -262,8 +272,8 @@ def test_bulk_overlaps(comm_type, fp8, connections):
" FP8 GEMM - FP8 PARAMS ",
],
)
def test_layers_with_overlap(layer_type, fp8, fp8_init):
def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, fp8, fp8_init)
_run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init)
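
The reworked parametrization keeps a single test body but expands te.Linear into both layouts, while the remaining TE_LAYERS entries receive linear_parallel_mode=None. Roughly, the generated (layer_type, linear_parallel_mode) pairs are as sketched below (illustrative only; the exact list follows TE_LAYERS).

# Approximate expansion of the new layer_type/linear_parallel_mode parametrization
cases = [
    ("Linear", "row"),
    ("Linear", "column"),
    ("LayerNormLinear", None),
    ("LayerNormMLP", None),
    ("MultiheadAttention", None),
    # ... plus any remaining TE_LAYERS entries, each paired with None
]
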
@@ -3,6 +3,8 @@
# See LICENSE for license information.
"""Linear API"""
from functools import reduce
from operator import mul as multiply_op
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
@@ -43,7 +45,7 @@ from ..cpp_extensions import (
fp8_cast_transpose_fused,
cast_to_fp8,
)
from ..constants import GemmParallelModes, dist_group_type
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
@@ -80,8 +82,12 @@ class _Linear(torch.autograd.Function):
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
ub_overlap_rs: bool,
ub_overlap_ag: bool,
ub_overlap_rs_fprop: bool,
ub_overlap_ag_dgrad: bool,
ub_overlap_ag_fprop: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_dgrad: bool,
ub_bulk_wgrad: bool,
ub_name: str,
fp8_output: bool,
fsdp_group: Union[dist_group_type, None],
@@ -99,7 +105,8 @@ class _Linear(torch.autograd.Function):
assert_dim_for_fp8_exec(weight)
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
ub_overlap_ag_fprop = False if tp_world_size == 1 else ub_overlap_ag_fprop
ub_overlap_rs_fprop = False if tp_world_size == 1 else ub_overlap_rs_fprop
# Cast input to expected dtype
inputmat = cast_if_needed(inputmat, activation_dtype)
@@ -150,10 +157,11 @@ class _Linear(torch.autograd.Function):
inputmat_scale_inv.fill_(inputmat_scale_inv.item())
# Column Parallel Linear
if parallel_mode == "column" and sequence_parallel:
if parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
inputmat_total = inputmat
if fp8:
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
@@ -165,75 +173,92 @@ class _Linear(torch.autograd.Function):
assert isinstance(weight_fp8, Float8Tensor)
if fp8_output:
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
out_index, meta_tensor, out_tedtype, out_pttype = (
tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_meta["scaling_fwd"],
fp8_dtype_forward,
torch.uint8,
)
else:
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
out_index, meta_tensor, out_tedtype, out_pttype = (
None,
None,
None,
activation_dtype,
)
ub_obj = None
ub_algo = None
rs_out = None
if ub_overlap_rs:
ub_obj_projout = get_ub(ub_name + "_fprop")
out = ub_obj_projout.get_ubuf_output(1)
inputmat_data = (
inputmat_total._data if isinstance(inputmat_total, Float8Tensor) else inputmat_total
)
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
out = ub_obj.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = out_features
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj_projout.is_p2p_overlap():
if ub_obj_projout.is_atomic_gemm():
if ub_obj.is_p2p_overlap():
if ub_obj.is_atomic_gemm():
ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
else:
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
if ub_obj_projout.is_atomic_gemm():
if ub_obj.is_atomic_gemm():
ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
else:
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
if ub_obj_projout.is_fp8_ubuf():
proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
if ub_obj.is_fp8_ubuf():
out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
meta_tensor = fp8_meta["scaling_fwd"]
proj_out_tetype = fp8_dtype_forward
proj_out_pttype = torch.uint8
ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
out_tedtype = fp8_dtype_forward
out_pttype = torch.uint8
ub_obj.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM requires FP8 buffer."
ub_obj.copy_input_to_ubuf(inputmat_data, True)
ub_obj.set_ubuf_scale_inv(inputmat_scale_inv)
if ub_obj.is_atomic_gemm():
ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
out_tedtype = TE_DType[activation_dtype]
out_pttype = activation_dtype
dim_size = list(inputmat_total.size())
dim_size[0] *= tp_size
dim_size[1] = out_features
out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device)
else:
dim_size = list(inputmat_total.size())
dim_size[1] = out_features
out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)
out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device)
_ = fp8_gemm(
weight_fp8._data,
weight_fp8._scale_inv,
0,
weight_fp8._fp8_dtype,
(
inputmat_total._data
if isinstance(inputmat_total, Float8Tensor)
else inputmat_total
),
inputmat_data,
inputmat_scale_inv,
0,
fp8_dtype_forward,
proj_out_pttype,
out_pttype,
get_workspace(),
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
out=out,
ub_algo=ub_algo if ub_overlap_rs else None,
ub=ub_obj_projout if ub_overlap_rs else None,
extra_output_tensor=rs_out if ub_overlap_rs else None,
out_index=proj_out_index,
ub_algo=ub_algo,
ub=ub_obj,
extra_output_tensor=rs_out,
out_index=out_index,
fp8_meta_tensor=meta_tensor,
D_dtype=proj_out_tetype,
D_dtype=out_tedtype,
)
if fp8_output:
out = Float8Tensor(
@@ -261,17 +286,30 @@ class _Linear(torch.autograd.Function):
-amin, amax
).float()
if ub_overlap_rs:
ub_obj_projout = get_ub(ub_name + "_fprop")
out = ub_obj_projout.get_ubuf_output(1)
ub_obj = None
ub_algo = None
rs_out = None
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
out = ub_obj.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group)
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = out_features
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj_projout.is_p2p_overlap():
if ub_obj.is_p2p_overlap():
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_obj.copy_input_to_ubuf(inputmat_total, True)
dim_size = list(inputmat_total.size())
dim_size[0] *= tp_size # all-gathered sequence length
dim_size[1] = out_features
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
else:
dim_size = list(inputmat_total.size())
dim_size[1] = out_features
@@ -285,9 +323,9 @@ class _Linear(torch.autograd.Function):
bias=bias,
use_bias=use_bias,
out=out,
ub_algo=ub_algo if ub_overlap_rs else None,
ub=ub_obj_projout if ub_overlap_rs else None,
extra_output_tensor=rs_out if ub_overlap_rs else None,
ub_algo=ub_algo,
ub=ub_obj,
extra_output_tensor=rs_out,
)
if is_grad_enabled:
@@ -343,7 +381,10 @@ class _Linear(torch.autograd.Function):
ctx.inp_shape = inp_shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.ub_overlap_ag = ub_overlap_ag
ctx.ub_overlap_ag = ub_overlap_ag_dgrad
ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_name = ub_name
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
@@ -356,12 +397,13 @@ class _Linear(torch.autograd.Function):
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
# Row Parallel Linear
if ub_overlap_rs:
out = rs_out
elif parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
if parallel_mode == "row":
if ub_overlap_rs_fprop:
out = rs_out
elif sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp_shape[1:-1], out_features)
@@ -401,15 +443,75 @@ class _Linear(torch.autograd.Function):
tp_world_size = get_distributed_world_size(ctx.tp_group)
ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
ub_algo = None
ctx.ub_overlap_rs_dgrad = False if tp_world_size == 1 else ctx.ub_overlap_rs_dgrad
ctx.ub_bulk_dgrad = False if tp_world_size == 1 else ctx.ub_bulk_dgrad
ctx.ub_bulk_wgrad = False if tp_world_size == 1 else ctx.ub_bulk_wgrad
ctx.ub_obj_gradout = None
ub_obj_wgrad = None
ub_algo_wgrad = None
ub_algo_dgrad = None
rs_out = None
dgrad = None
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
if ctx.ub_overlap_ag:
dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * tp_world_size
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
if ctx.ub_obj_gradout.is_atomic_gemm():
ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
dgrad = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
)
elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
dgrad = ctx.ub_obj_gradout.get_ubuf_output(1)
if ctx.ub_obj_gradout.is_p2p_overlap():
if ctx.ub_obj_gradout.is_atomic_gemm():
ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
else:
ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
if ctx.ub_obj_gradout.is_atomic_gemm():
ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
else:
ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
)
ctx.ub_bulk_dgrad = False
ctx.ub_bulk_wgrad = False
else:
if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute
ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
inputmat_data = (
inputmat._data if isinstance(inputmat, Float8Tensor) else inputmat
)
ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True)
inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1)
if isinstance(inputmat, Float8Tensor):
inputmat._data = inputmat_ubuf
else:
inputmat = inputmat_ubuf
if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute
ub_algo_wgrad = tex.CommOverlapAlgo.BULK_OVERLAP_RS
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
dgrad = ub_obj_wgrad.get_ubuf_output(1)
if dgrad is None:
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
dgrad_shape[0] = dgrad_shape[0] * tp_world_size
dgrad = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
)
(
grad_output,
@@ -420,13 +522,17 @@ class _Linear(torch.autograd.Function):
ctx, grad_output, ctx.parallel_mode == "row"
)
# Column Parallel Linear
# Overlap input AG with dgrad
# Overlap inputmat AG with dgrad via NCCL async comms (no TP overlap via Userbuffers)
inputmat_total = None
inputmat_t_total = None
handle = None
if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
inputmat_total, handle = gather_along_first_dim(
inputmat_gather_handle = None
if (
weight.requires_grad
and ctx.parallel_mode == "column"
and ctx.sequence_parallel
and not ctx.ub_bulk_dgrad
):
inputmat_total, inputmat_gather_handle = gather_along_first_dim(
inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
)
else:
@@ -446,13 +552,17 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad:
if ctx.fp8:
if ctx.is_input_fp8:
if ctx.is_input_fp8 or (
ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf()
):
out_index, meta_tensor, output_te_dtype, output_dtype = (
tex.FP8BwdTensors.GRAD_INPUT1,
ctx.fp8_meta["scaling_bwd"],
fp8_dtype_backward,
torch.uint8,
)
if ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf():
ctx.ub_obj_gradout.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
else:
out_index, meta_tensor, output_te_dtype, output_dtype = (
None,
@@ -460,7 +570,7 @@ class _Linear(torch.autograd.Function):
None,
ctx.activation_dtype,
)
dgrad, _ = fp8_gemm(
_ = fp8_gemm(
weight_fp8.transpose_2d(),
weight_fp8._scale_inv,
0,
@@ -472,12 +582,18 @@ class _Linear(torch.autograd.Function):
output_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=ub_algo if ctx.ub_overlap_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
ub_algo=ub_algo_dgrad,
ub=ctx.ub_obj_gradout,
out=dgrad,
out_index=out_index,
fp8_meta_tensor=meta_tensor,
D_dtype=output_te_dtype,
extra_output_tensor=rs_out,
)
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out
if output_dtype == torch.uint8:
dgrad = Float8Tensor(
data=dgrad,
@@ -488,30 +604,34 @@ class _Linear(torch.autograd.Function):
dtype=ctx.activation_dtype,
)
else:
dgrad, _, _ = gemm(
_ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
ub_algo=(
tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
if ctx.ub_overlap_ag
else None
),
ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
ub_algo=ub_algo_dgrad,
ub=ctx.ub_obj_gradout,
out=dgrad,
extra_output_tensor=rs_out,
)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if handle is not None:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out
if inputmat_gather_handle is not None:
inputmat_gather_handle.wait()
# Overlap dgrad RS/AR with wgrad via NCCL async comms (no TP overlap via Userbuffers)
dgrad_reduce_handle = None
if ctx.requires_dgrad and ctx.parallel_mode == "column":
if ctx.sequence_parallel and not (ctx.ub_overlap_rs_dgrad or ctx.ub_bulk_wgrad):
dgrad, dgrad_reduce_handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
elif ctx.tensor_parallel and not ctx.sequence_parallel:
dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True)
wgrad = None
if weight.requires_grad:
@@ -548,6 +668,8 @@ class _Linear(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
ub=ub_obj_wgrad,
ub_algo=ub_algo_wgrad,
)
else:
wgrad, _, _ = gemm(
@@ -559,6 +681,8 @@ class _Linear(torch.autograd.Function):
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub=ub_obj_wgrad,
ub_algo=ub_algo_wgrad,
)
else:
# WGRAD
@@ -572,15 +696,20 @@ class _Linear(torch.autograd.Function):
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub=ub_obj_wgrad,
ub_algo=ub_algo_wgrad,
)
if ctx.ub_bulk_wgrad:
dgrad = ub_obj_wgrad.get_ubuf_output(0)
# Deallocate input tensor
clear_tensor_data(inputmat_total)
clear_tensor_data(inputmat_t_total)
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
# Wait for dgrad reduce-scatter or all-reduce
if dgrad_reduce_handle is not None:
dgrad_reduce_handle.wait()
if not ctx.use_bias:
grad_bias = None
@@ -634,8 +763,12 @@ class _Linear(torch.autograd.Function):
None, # activation_dtype
None, # parallel_mode
None, # is_grad_enabled
None, # ub_overlap_rs
None, # ub_overlap_ag
None, # ub_overlap_rs_fprop
None, # ub_overlap_ag_dgrad
None, # ub_overlap_ag_fprop
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fp8_output
None, # fsdp_group
@@ -729,8 +862,10 @@ class Linear(TransformerEngineBaseModule):
parallel_mode: Optional[str] = None,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
device: Union[torch.device, str] = "cuda",
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
ub_overlap_rs: bool = False,
ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False,
ub_name: Optional[str] = None,
) -> None:
super().__init__()
@@ -742,13 +877,6 @@ class Linear(TransformerEngineBaseModule):
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag
if ub_overlap_rs or ub_overlap_ag:
assert ub_name is not None, "Userbuffer name [string] is not set."
self.ub_name = ub_name
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
if device == "meta":
assert parameters_split is None, "Cannot split module parameters on 'meta' device."
@@ -773,6 +901,45 @@ class Linear(TransformerEngineBaseModule):
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
# Column parallel TP overlap options
self.ub_overlap_ag_fprop = parallel_mode == "column" and sequence_parallel and ub_overlap_ag
self.ub_overlap_rs_dgrad = parallel_mode == "column" and sequence_parallel and ub_overlap_rs
self.ub_bulk_dgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_dgrad
self.ub_bulk_wgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_wgrad
if self.ub_overlap_rs_dgrad:
self.ub_bulk_dgrad = False
self.ub_bulk_wgrad = False
# Row parallel TP overlap options
self.ub_overlap_rs_fprop = parallel_mode == "row" and sequence_parallel and ub_overlap_rs
self.ub_overlap_ag_dgrad = parallel_mode == "row" and sequence_parallel and ub_overlap_ag
if any(
[
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
]
):
assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized."
self.ub_name = ub_name
assert not (
self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop
), "Cannot enable AG+GEMM and GEMM+RS overlaps at the same time."
assert not (
self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad
), "Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time."
assert not (
self.ub_overlap_ag_dgrad and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad)
), "Cannot enable AG+DGRAD and DGRAD+RS or bulk DGRAD overlaps at the same time."
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
# Initialize params in FP8
with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
@@ -1017,8 +1184,12 @@ class Linear(TransformerEngineBaseModule):
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
self.ub_overlap_rs,
self.ub_overlap_ag,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
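
Putting the constructor changes together, the two public knobs ub_overlap_ag and ub_overlap_rs now resolve to different internal overlap modes depending on parallel_mode. The sketch below restates that mapping from Linear.__init__ as a standalone helper; the function name is hypothetical and the logic follows the diff above.

def resolve_overlap_modes(parallel_mode, sequence_parallel,
                          ub_overlap_ag, ub_overlap_rs,
                          ub_bulk_dgrad, ub_bulk_wgrad):
    """Illustrative restatement of the overlap-flag resolution in te.Linear.__init__."""
    is_col = parallel_mode == "column" and sequence_parallel
    is_row = parallel_mode == "row" and sequence_parallel
    # Column parallel: AG+GEMM in fprop, GEMM+RS (or bulk overlaps) in bprop
    ag_fprop = is_col and ub_overlap_ag      # input all-gather overlapped with the fprop GEMM
    rs_dgrad = is_col and ub_overlap_rs      # pipelined reduce-scatter fused with the DGRAD GEMM
    bulk_dgrad = is_col and ub_bulk_dgrad and not rs_dgrad  # bulk input-AG overlap during DGRAD
    bulk_wgrad = is_col and ub_bulk_wgrad and not rs_dgrad  # bulk dgrad-RS overlap during WGRAD
    # Row parallel: GEMM+RS in fprop, grad_output AG+GEMM in bprop
    rs_fprop = is_row and ub_overlap_rs
    ag_dgrad = is_row and ub_overlap_ag
    return ag_fprop, rs_dgrad, bulk_dgrad, bulk_wgrad, rs_fprop, ag_dgrad
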