"vscode:/vscode.git/clone" did not exist on "15f5632365a98fd43ea42e4948a995aa399e99b5"
Unverified commit 095b27d0 authored by Tim Moon, committed by GitHub

[PyTorch] Userbuffers support in operation-based API (#1142)



* Add Userbuffers support for column TP linear layer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add Userbuffers support for row TP linear layer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Interpret linear+RS as row TP linear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add Userbuffers support for FP8 row TP linear layer

Assumes FP8 RS, which is not a good assumption.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug with incorrect bias pointers in UB GEMM

Bias pointers were not properly offset for different data chunks. Also removed logic for FP8 RS.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add Userbuffers support for linear dgrad

Test passes with row TP, fails with column TP.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add Userbuffers support for linear wgrad
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for grad bias
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fused cast-transpose-dbias
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support case where wgrad is optional
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expand documentation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use recently added convenience functions in Float8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Respect autograd dtype
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

* Fix missing imports
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Respect PyTorch autocast dtype in bprop
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Resolve merge conflicts
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
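
For orientation, the test added in this PR exercises the new `userbuffers_options` argument roughly as sketched below. This is an illustrative sketch, not part of the diff: it assumes torch.distributed is initialized with NCCL, that the Userbuffers communicators (here the "qkv_*" ones) have already been set up with te.module.base.initialize_ub as in the test's main(), and that the tensor sizes match that setup.

# Illustrative sketch (not part of this commit) of the operation-based API with
# Userbuffers overlap, following tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py.
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

process_group = torch.distributed.group.WORLD  # tensor-parallel group (NCCL assumed initialized)
hidden_size = 512  # illustrative; must match the shape passed to initialize_ub

model = te_ops.Sequential(
    te_ops.BasicLinear(
        hidden_size,
        hidden_size,
        tensor_parallel_mode="column",
        tensor_parallel_group=process_group,
        sequence_parallel=True,
        userbuffers_options={"comm_name": "qkv"},  # selects the qkv_fprop/qkv_dgrad/qkv_wgrad communicators
    ),
)

x = torch.randn(32, hidden_size, device="cuda", requires_grad=True)  # local sequence-parallel shard
with te.fp8_autocast(enabled=False):
    y = model(x)
y.sum().backward()
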
parent 77c37d49
@@ -10,4 +10,5 @@ pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
import argparse
from collections.abc import Iterable
import dataclasses
import functools
import itertools
import os
import pathlib
import subprocess
import sys
import pytest
import torch
import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
)
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, str_to_dtype
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
# Check if there are multiple GPUs
if torch.cuda.device_count() < 2:
pytest.skip("Userbuffers requires at least 2 GPUs.", allow_module_level=True)
@dataclasses.dataclass
class ModelConfig:
"""Tensor dimensions in Transformer model"""
sequence_length: int
batch_size: int
num_heads: int
head_dim: int
dtype: torch.dtype
fp8: bool
@property
def hidden_size(self):
return self.num_heads * self.head_dim
@functools.cache
def launcher() -> str:
"""Launcher for current parallel job"""
if "OMPI_COMM_WORLD_SIZE" in os.environ:
return "ompi"
if "TORCHELASTIC_RUN_ID" in os.environ:
return "torchrun"
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`")
@functools.cache
def world_group() -> torch.distributed.ProcessGroup:
"""Get NCCL process group, initializing if needed"""
# Get launch config from environment
if launcher() == "ompi":
# OpenMPI
world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE"))
rank = int(os.getenv("OMPI_COMM_WORLD_RANK"))
local_size = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE"))
local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK"))
elif launcher() == "torchrun":
# torchrun
world_size = int(os.getenv("WORLD_SIZE"))
rank = int(os.getenv("RANK"))
local_size = int(os.getenv("LOCAL_WORLD_SIZE"))
local_rank = int(os.getenv("LOCAL_RANK"))
else:
raise RuntimeError("Unexpected launcher ({launcher()})")
# Construct communicator
assert local_size == world_size
torch.cuda.set_device(local_rank)
group = torch.distributed.init_process_group(
"nccl",
init_method="file:///tmp/rdzv",
world_size=world_size,
rank=rank,
device_id=torch.device(f"cuda:{local_rank}"),
)
return group
def reset_rng(seed: int = 1234) -> None:
"""Reset random number generators"""
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
The reference tensor is intended for use in plain PyTorch
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
"""
# Random data
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor
if test_is_fp8:
test = Float8Tensor.to_float8(ref)
else:
test = ref.to(device=test_device, dtype=test_dtype)
if test.data_ptr() == ref.data_ptr():
test = test.clone()
# Make sure reference and test tensors represent exact same values
ref.copy_(test)
# Return reference and test tensors
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def _test_linear(
*,
model_config: ModelConfig,
bias: bool = False,
device: torch.device = "cuda",
tensor_parallel_mode: str = "column",
sequence_parallel: bool = True,
weight_requires_grad: bool = True,
) -> None:
dtype = model_config.dtype
fp8_compute = model_config.fp8
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
out_features = model_config.hidden_size
in_features = model_config.hidden_size
batch_size = model_config.sequence_length * model_config.batch_size
in_shape = [batch_size, in_features]
out_shape = [batch_size, out_features]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
)
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
bias_shape = [world_size, out_features]
else:
bias_shape = [out_features]
b_ref, b_test = make_reference_and_test_tensors(
bias_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
if bias:
if tensor_parallel_mode == "row":
y_ref += b_ref.sum(dim=0)
else:
y_ref += b_ref
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dw_ref = w_ref.grad
db_ref = b_ref.grad if bias else None
dx_ref = x_ref.grad
if tensor_parallel_mode == "column":
local_out_features = out_features // world_size
local_slice = slice(
rank * local_out_features,
(rank + 1) * local_out_features,
)
w_ref = w_ref[local_slice, :]
dw_ref = dw_ref[local_slice, :]
w_test = w_test[local_slice, :]
if bias:
b_ref = b_ref[local_slice]
db_ref = db_ref[local_slice]
b_test = b_test[local_slice]
y_ref = y_ref[..., local_slice]
dy_ref = dy_ref[..., local_slice]
dy_test = dy_test[..., local_slice].clone()
elif tensor_parallel_mode == "row":
local_in_features = in_features // world_size
local_slice = slice(
rank * local_in_features,
(rank + 1) * local_in_features,
)
w_ref = w_ref[:, local_slice]
dw_ref = dw_ref[:, local_slice]
w_test = w_test[:, local_slice]
if bias:
b_ref = b_ref[rank, :]
db_ref = db_ref[rank, :]
b_test = b_test[rank, :]
x_ref = x_ref[..., local_slice]
dx_ref = dx_ref[..., local_slice]
x_test = x_test[..., local_slice].clone()
if sequence_parallel:
local_batch_size = batch_size // world_size
local_slice = slice(
rank * local_batch_size,
(rank + 1) * local_batch_size,
)
if tensor_parallel_mode == "column":
x_ref = x_ref[local_slice, ...]
dx_ref = dx_ref[local_slice, ...]
x_test = x_test[local_slice, ...].clone()
elif tensor_parallel_mode == "row":
y_ref = y_ref[local_slice, ...]
dy_ref = dy_ref[local_slice, ...]
dy_test = dy_test[local_slice, ...].clone()
x_test.requires_grad_()
# Implementation with fusible operation
with te.fp8_model_init(enabled=fp8_compute):
ops = []
linear_op = None
bias_op = None
if tensor_parallel_mode == "column":
userbuffers_options = {}
if not weight_requires_grad:
if fp8_compute:
userbuffers_options["comm_name"] = "fc1"
else:
# There is a correctness bug with overlapping
# dgrad reduce-scatter with dgrad GEMM. Fall back
# to overlapping dgrad reduce-scatter with wgrad
# GEMM, even though wgrad isn't needed.
userbuffers_options["comm_name"] = "qkv"
else:
userbuffers_options["comm_name"] = "qkv"
linear_op = te_ops.BasicLinear(
in_features,
out_features,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=process_group,
sequence_parallel=sequence_parallel,
userbuffers_options=userbuffers_options,
)
ops.append(linear_op)
if bias:
bias_op = te_ops.Bias(
out_features // world_size,
device=device,
dtype=dtype,
)
ops.append(bias_op)
elif tensor_parallel_mode == "row":
userbuffers_options = dict(comm_name="proj")
linear_op = te_ops.BasicLinear(
in_features // world_size,
out_features,
device=device,
dtype=dtype,
userbuffers_options=userbuffers_options,
)
ops.append(linear_op)
if bias:
bias_op = te_ops.Bias(out_features, device=device, dtype=dtype)
ops.append(bias_op)
ops.append(te_ops.ReduceScatter(process_group))
model = te_ops.Sequential(*ops)
with torch.no_grad():
linear_op.weight.copy_(w_test)
linear_op.weight.requires_grad_(requires_grad=weight_requires_grad)
if bias:
bias_op.bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=fp8_compute):
y_test = model(x_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
backward_ops = model._module_groups[0]._backward_ops
assert len(forward_ops) == 1
assert len(backward_ops) == 1
assert isinstance(forward_ops[0][0], UserbuffersForwardLinear)
assert isinstance(backward_ops[0][0], UserbuffersBackwardLinear)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
tols = dtype_tols(
model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight)
else tex.DType.kFloat8E4M3
)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
if weight_requires_grad:
dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dw_test, dw_ref, **tols)
if bias:
db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, db_ref, **tols)
def run_parallel_tests(model_config: ModelConfig) -> None:
"""Run parallel tests"""
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Linear op
for test_config in itertools.product(
(False, True), # bias
("column", "row"), # tensor_parallel_mode
(False, True), # weight_requires_grad
):
if rank == 0:
print(f"Running _test_linear with {test_config=}")
bias, tensor_parallel_mode, weight_requires_grad = test_config
_test_linear(
model_config=model_config,
bias=bias,
tensor_parallel_mode=tensor_parallel_mode,
weight_requires_grad=weight_requires_grad,
)
# Parallel job sizes
_world_sizes = []
if torch.cuda.device_count() > 1:
_world_sizes.append(torch.cuda.device_count())
@pytest.mark.parametrize("world_size", _world_sizes)
@pytest.mark.parametrize("fp8", (False, True))
def test_fuser_ops_with_userbuffers(
*,
world_size: int,
dtype: torch.dtype = torch.bfloat16,
fp8: bool,
) -> None:
"""Launch parallel job and run tests"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Parallel job launcher
command = []
if tex.ubuf_built_with_mpi():
python_exe = pathlib.Path(sys.executable).resolve()
command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe))
else:
command.extend(("torchrun", f"--nproc_per_node={world_size}"))
# Script invocation
command.extend(
(
_current_file,
"--parallel",
"--batch-size",
str(world_size),
"--num-heads",
str(world_size),
"--dtype",
str(dtype),
)
)
if fp8:
command.append("--fp8")
# Environment
env = dict(os.environ)
if not tex.device_supports_multicast():
env["UB_SKIPMC"] = "1"
env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
env["PYTORCH_JIT"] = "0"
env["NVTE_TORCH_COMPILE"] = "0"
env["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
# Launch parallel job
result = subprocess.run(command, check=True, env=env)
def main() -> None:
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
parser.add_argument("--sequence-length", type=int, default=32)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--num-heads", type=int, default=16)
parser.add_argument("--head-dim", type=int, default=32)
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--fp8", action="store_true")
args = parser.parse_args()
# Run parallel tests if needed
if args.parallel:
# Model config
model_config = ModelConfig(
sequence_length=args.sequence_length,
batch_size=args.batch_size,
num_heads=args.num_heads,
head_dim=args.head_dim,
dtype=str_to_dtype(args.dtype),
fp8=args.fp8,
)
# Initialize Userbuffers
group = world_group() # Initialize NCCL
bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
userbuffer_configs = {
"fc1_dgrad": {"method": "pipeline"}, # Overlap dgrad RS with dgrad GEMM
}
te.module.base.initialize_ub(
[
model_config.sequence_length * model_config.batch_size,
model_config.num_heads * model_config.head_dim,
],
torch.distributed.get_world_size(group),
use_fp8=model_config.fp8,
dtype=model_config.dtype,
bootstrap_backend=bootstrap_backend,
ub_cfgs=userbuffer_configs,
)
# Run tests
run_parallel_tests(model_config)
# Clean up
te.module.base.destroy_ub()
if __name__ == "__main__":
main()
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
import torch
import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
"""Convert type name to PyTorch dtype"""
if isinstance(dtype, torch.dtype):
return dtype
name = str(dtype).strip().lower()
if name.startswith("torch."):
name = name.replace("torch.", "", 1)
if name.startswith("fp"):
name = name.replace("fp", "float", 1)
dtype = dict(
float32=torch.float32,
float=torch.float32,
float64=torch.float64,
double=torch.float64,
float16=torch.float16,
half=torch.float16,
bfloat16=torch.bfloat16,
bf16=torch.bfloat16,
float8_e4m3fn=torch.float8_e4m3fn,
float8_e4m3=torch.float8_e4m3fn,
float8e4m3=torch.float8_e4m3fn,
float8=torch.float8_e4m3fn,
float8_e5m2=torch.float8_e5m2,
float8e5m2=torch.float8_e5m2,
uint8=torch.uint8,
byte=torch.uint8,
int8=torch.int8,
char=torch.int8,
int16=torch.int16,
short=torch.int16,
int32=torch.int32,
int=torch.int32,
int64=torch.int64,
long=torch.int64,
bool=torch.bool,
)[name]
return dtype
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
tex.DType.kFloat8E5M2: torch.float8_e5m2,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
if dtype == torch.float8_e4m3fn:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == torch.float8_e5m2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
raise ValueError(f"Unsupported dtype ({dtype})")
@@ -314,11 +314,13 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
size_t m_chunk = m / _num_splits;
size_t input_a_chunk_size = m_chunk * k;
size_t output_chunk_size = n * m_chunk;
size_t bias_chunk_size = m_chunk;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.dptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.dptr());
char *bias_chunk_ptr = reinterpret_cast<char *>(bias.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
@@ -337,16 +339,21 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv());
auto output_chunk =
TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
auto bias_chunk =
TensorWrapper(bias.dptr(), {m_chunk}, bias.dtype(), nullptr, nullptr, nullptr);
auto workspace_chunk = TensorWrapper(
workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * D.element_size();
if (bias_chunk_ptr != nullptr) {
bias_chunk_ptr += bias_chunk_size * bias.element_size();
}
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
@@ -354,10 +361,12 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
A.dtype(), nullptr, nullptr, A.scale_inv());
output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr), {n, m_chunk},
D.dtype(), D.amax(), D.scale(), nullptr);
bias_chunk = TensorWrapper(reinterpret_cast<void *>(bias_chunk_ptr), {m_chunk}, bias.dtype(),
nullptr, nullptr, nullptr);
workspace_chunk = TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
@@ -409,11 +418,13 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
A.dtype(), nullptr, nullptr, A.scale_inv());
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr),
{n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
auto bias_chunk = TensorWrapper(reinterpret_cast<void *>(bias_chunk_ptr), {m_chunk},
bias.dtype(), nullptr, nullptr, nullptr);
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
@@ -440,6 +451,9 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
rs_output_ptr += m_chunk * rs_output.element_size();
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
if (bias_chunk_ptr != nullptr) {
bias_chunk_ptr += bias_chunk_size * bias.element_size();
}
}
}
@@ -9,6 +9,8 @@ from typing import Any, Iterable, Optional
import torch
from transformer_engine_torch import FP8TensorMeta
from ..fp8 import FP8GlobalStateManager
from ..tensor import Float8Tensor
from ..utils import (
canonicalize_device, # pylint: disable=unused-import
@@ -134,3 +136,25 @@ def maybe_autocast_dtype(
if torch.is_autocast_enabled(device_type):
return torch.get_autocast_dtype(device_type)
return canonicalize_dtype(default_dtype)
def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, int]:
"""Get FP8TensorMeta object and index corresponding to Float8Tensor
Constructs FP8TensorMeta if needed.
"""
# Check if tensor already has FP8 metadata
if tensor._fp8_meta is not None:
key = FP8GlobalStateManager.get_meta_tensor_key(
forward=tensor._fp8_meta_forward,
)
return tensor._fp8_meta[key], tensor._fp8_meta_index
# Create FP8TensorMeta class
fp8_meta = FP8TensorMeta()
fp8_meta.scale = tensor._scale_inv.reciprocal()
fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=tensor.device)
fp8_meta.scale_inv = tensor._scale_inv
return fp8_meta, 0
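
For reference, the fused Userbuffers ops below use this helper to route FP8 GEMM output metadata into fp8_gemm. A small standalone sketch (illustrative only, requires a CUDA device with FP8 support):

# Sketch (not part of the diff): a Float8Tensor created outside of an
# fp8_autocast/fp8_model_init region has no fp8_meta dict, so the helper
# synthesizes an FP8TensorMeta from the tensor's scale_inv.
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.ops._common import get_fp8_meta_from_fp8_tensor

t = Float8Tensor.to_float8(torch.randn(16, 16, device="cuda"))
fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(t)
# fp8_meta.scale_inv aliases t._scale_inv and fp8_meta_index is 0 for the
# synthesized metadata; both are passed to fp8_gemm as fp8_meta_tensor /
# out_index in the userbuffers_*_linear.py files below.
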
@@ -83,6 +83,10 @@ class BasicLinear(BasicOperation):
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful.
userbuffers_options: dict, optional
Options for overlapping tensor-parallel communication with
compute using Userbuffers. This feature is highly
experimental.
"""
@@ -98,6 +102,7 @@
sequence_parallel: bool = False,
rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
accumulate_into_main_grad: bool = False,
userbuffers_options: Optional[dict[str, Any]] = None,
) -> None:
super().__init__()
@@ -143,7 +148,7 @@
)
# Whether weight tensor is natively in FP8
self._with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self._with_fp8_parameters: bool = FP8GlobalStateManager.with_fp8_parameters()
if self._with_fp8_parameters:
self._fp8_metas = self._make_fp8_metas()
@@ -163,7 +168,10 @@
self.reset_parameters()
# Whether to accumulate weight gradient into main_grad
self._accumulate_into_main_grad = accumulate_into_main_grad
self._accumulate_into_main_grad: bool = accumulate_into_main_grad
# Userbuffers options
self._userbuffers_options: Optional[dict[str, Any]] = userbuffers_options
@classmethod
def _canonicalize_tensor_parallelism(
@@ -707,7 +715,7 @@
FP8 metadata for casting loss gradient w.r.t. output
tensor to FP8. Required if output grad is not already in
FP8.
grad_output_fp8_meta: dict, optional
grad_input_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. input
tensor to FP8
@@ -16,3 +16,11 @@ from .forward_linear_bias_add import (
ForwardLinearBiasAdd,
fuse_forward_linear_bias_add,
)
from .userbuffers_backward_linear import (
UserbuffersBackwardLinear,
fuse_userbuffers_backward_linear,
)
from .userbuffers_forward_linear import (
UserbuffersForwardLinear,
fuse_userbuffers_forward_linear,
)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear layer backward with Userbuffers communication."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import warnings
import torch
from transformer_engine_torch import CommOverlapAlgo
from ...cpp_extensions import (
fp8_cast_transpose_bgrad_fused,
fp8_gemm,
gemm,
)
from ...distributed import get_distributed_world_size
from ...float8_tensor import Float8Tensor
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...module.base import get_ub, get_workspace
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import FusedOperation, FusibleOperation, OperationContext
from .._common import (
convert_tensor,
get_fp8_meta_from_fp8_tensor,
is_float8_tensor,
reshape,
)
class UserbuffersBackwardLinear(FusedOperation):
"""Linear backward implementation using Userbuffers
This operation is equivalent to a linear operation's backward
pass, but it uses Userbuffers to overlap tensor-parallel
communication with compute.
"""
def __init__(
self,
*,
linear: BasicLinear,
bias: Optional[Bias],
reduce_scatter: Optional[ReduceScatter],
) -> None:
# Basic operations that comprise this fused operation
op_idxs = {"linear": None, "bias": None, "reduce_scatter": None}
ops = []
if reduce_scatter is not None:
op_idxs["reduce_scatter"] = len(ops)
ops.append(reduce_scatter)
if bias is not None:
op_idxs["bias"] = len(ops)
ops.append(bias)
op_idxs["linear"] = len(ops)
ops.append(linear)
# Initialize base class
super().__init__(ops)
# Index of each basic operation
self._op_idxs: dict[str, Optional[int]] = op_idxs
# Tensor parallelism configuration
self.tensor_parallel_mode: Optional[str]
self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup]
self.tensor_parallel_size: int
self.sequence_parallel: bool
if reduce_scatter is None:
self.tensor_parallel_mode = linear.tensor_parallel_mode
self.tensor_parallel_group = linear.tensor_parallel_group
self.tensor_parallel_size = linear.tensor_parallel_size
self.sequence_parallel = linear.sequence_parallel
else:
self.tensor_parallel_mode = "row"
self.tensor_parallel_group = reduce_scatter.process_group
self.tensor_parallel_size = reduce_scatter.process_group_size
self.sequence_parallel = True
@staticmethod
def _functional_backward(
grad_output: torch.Tensor,
input: Optional[torch.Tensor], # pylint: disable=redefined-builtin
weight: Optional[torch.Tensor],
input_dims: Iterable[int],
weight_dims: Iterable[int],
*,
weight_requires_grad: bool = True,
bias_requires_grad: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
grad_weight: Optional[torch.Tensor] = None,
accumulate_into_grad_weight: bool = False,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
tensor_parallel_size: Optional[int] = None,
sequence_parallel: bool = False,
with_fp8_compute: bool = False,
input_fp8_meta: Optional[dict[str, Any]] = None,
weight_fp8_meta: Optional[dict[str, Any]] = None,
grad_output_fp8_meta: Optional[dict[str, Any]] = None,
grad_input_fp8_meta: Optional[dict[str, Any]] = None,
ub_comm_name: str,
) -> tuple[torch.Tensor, Optional[torch.Tensor], dict]:
"""Functional API for backward pass
Parameters
----------
grad_output: torch.Tensor
Loss gradient w.r.t. output tensor
input: torch.Tensor, optional
Input tensor. Required to compute loss gradient w.r.t.
weight.
weight: torch.Tensor, optional
Weight tensor. Required to compute loss gradient w.r.t.
input.
input_dims: iterable of int
Input tensor dimensions
weight_dims: iterable of int
Weight tensor dimensions
weight_requires_grad: bool
Whether to compute loss gradient w.r.t. weight tensor
bias_requires_grad: bool
Whether to compute loss gradient w.r.t. bias tensor
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
grad_weight: torch.Tensor, optional
Loss gradient w.r.t. weight tensor
accumulate_into_grad_weight: bool, default = `False`
Add result to weight grad instead of overwriting
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_fp8_compute: bool, default = `False`
Whether to perform compute in FP8
input_fp8_meta: dict, optional
FP8 metadata for casting input tensor to FP8. Required for
FP8 compute if input is not already in FP8.
weight_fp8_meta: dict, optional
FP8 metadata for casting weight tensor to FP8. Required for
FP8 compute if weight is not already in FP8.
grad_output_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. output
tensor to FP8. Required if output grad is not already in
FP8.
grad_input_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. input
tensor to FP8
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
(e.g. "qkv_dgrad", "qkv_wgrad").
Returns
-------
torch.Tensor
Loss gradient w.r.t. input tensor
torch.Tensor
Loss gradient w.r.t. weight tensor
dict
Extra output tensors. "grad_bias" is loss gradient w.r.t.
the bias tensor.
"""
# Configuration-specific outputs
extra_outputs = {}
# Check device
if device is None:
device = weight.device
device = canonicalize_device(device)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
# Check datatype
if dtype is None:
dtype = weight.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Input tensor dims
output_dims = tuple(grad_output.size())
input_dims = tuple(input_dims)
weight_dims = tuple(weight_dims)
if len(weight_dims) != 2:
raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
if weight_dims[0] != output_dims[-1]:
raise ValueError(
f"Grad output tensor (shape={output_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
# Check tensor parallel group
if tensor_parallel_size is None:
tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
if tensor_parallel_size == 1:
tensor_parallel_mode = None
if tensor_parallel_mode not in ("column", "row"):
raise RuntimeError(
"Invalid configuration for Userbuffers "
f"({tensor_parallel_size=}, {tensor_parallel_mode=})"
)
if not sequence_parallel:
raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")
# Check if FP8 is enabled
if with_fp8_compute:
if grad_output_fp8_meta is None and not is_float8_tensor(grad_output):
raise ValueError("No FP8 metadata was provided for casting output gradient to FP8")
else:
input_fp8_meta = None
weight_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
with_fp8_grad_input = (
with_fp8_compute
and tensor_parallel_mode != "column"
and grad_input_fp8_meta is not None
)
# Get Userbuffers communicators and algorithms
# Note: communication patterns are (1) overlap dy all-gather
# with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM
# and dx reduce-scatter with wgrad GEMM, (3) overlap dx
# reduce-scatter with dgrad GEMM.
with_ub_all_gather_dy = False
with_ub_reduce_scatter_dx = False
with_ub_all_gather_x = False
ub_comm_dy = None
ub_comm_dx = None
ub_comm_x = None
ub_algo_dy = None
ub_algo_dx = None
ub_algo_x = None
if tensor_parallel_mode == "row":
with_ub_all_gather_dy = True
ub_comm_dy = get_ub(ub_comm_name + "_dgrad")
if with_fp8_compute and ub_comm_dy.is_atomic_gemm():
ub_algo_dy = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo_dy = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
elif tensor_parallel_mode == "column":
with_ub_reduce_scatter_dx = True
if weight_requires_grad:
with_ub_all_gather_x = True
ub_comm_dx = get_ub(ub_comm_name + "_wgrad")
ub_comm_x = get_ub(ub_comm_name + "_dgrad")
ub_algo_dx = CommOverlapAlgo.BULK_OVERLAP_RS
ub_algo_x = CommOverlapAlgo.BULK_OVERLAP_AG
else:
with_ub_all_gather_x = False
ub_comm_dx = get_ub(ub_comm_name + "_dgrad")
is_atomic_gemm = with_fp8_compute and ub_comm_dx.is_atomic_gemm()
ub_algo_dx = {
(True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P,
(True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P,
(False, True): CommOverlapAlgo.ATOMIC_GEMM_RS,
(False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS,
}[(ub_comm_dx.is_p2p_overlap(), is_atomic_gemm)]
# Check grad output tensor
# Note: Possibly fuse cast with computing grad bias
dy_local = reshape(
grad_output,
(-1, output_dims[-1]),
device=device,
dtype=dtype,
)
db = None
db_async = None
if bias_requires_grad and with_fp8_compute and with_ub_all_gather_dy:
# We don't have a grad bias impl that takes FP8 input. For
# cases where we cast to FP8 and all-gather, it's better
# to compute the grad bias on ungathered, non-FP8 values.
db = dy_local.sum(dim=0)
db_async = torch.distributed.all_reduce(
db,
group=tensor_parallel_group,
async_op=True,
)
if with_fp8_compute and not is_float8_tensor(dy_local):
fp8_dtype = get_fp8_te_dtype(
grad_output_fp8_meta["recipe"],
fprop_tensor=False,
)
if bias_requires_grad and db is None:
# Fused cast-transpose-bgrad
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device)
db, data, data_transpose = fp8_cast_transpose_bgrad_fused(
dy_local,
grad_output_fp8_meta[fp8_meta_key],
0,
fp8_dtype,
scale_inv=fp8_scale_inv,
)
if with_ub_all_gather_dy:
data = ub_comm_dy.get_ubuf_output(0).copy_(data)
dy_local = Float8Tensor(
data=data,
fp8_meta=grad_output_fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
dtype=dtype,
data_transpose=data_transpose,
)
else:
dy_local = Float8Tensor.to_float8(
dy_local,
fp8_meta=grad_output_fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
data=(ub_comm_dy.get_ubuf_output(0) if with_ub_all_gather_dy else None),
with_transpose_cache=(not with_ub_all_gather_dy),
)
elif not with_fp8_compute and is_float8_tensor(dy_local):
if with_ub_all_gather_dy:
ub_local_buffer = ub_comm_dy.get_ubuf_output(0)
dy_local = ub_local_buffer.copy_(dy_local)
else:
dy_local = dy_local.dequantize()
if bias_requires_grad and db is None and with_fp8_compute and with_ub_all_gather_dy:
# We don't have a fused grad bias impl that takes FP8
# input. For cases where we cast to FP8 and all-gather,
# it's better to compute the grad bias on ungathered,
# non-FP8 values.
db = dy_local.sum(dim=0)
db_async = torch.distributed.all_reduce(
db,
group=tensor_parallel_group,
async_op=True,
)
# Check input tensor
x_local = None
if weight_requires_grad:
x_local = reshape(
input,
(-1, input_dims[-1]),
device=device,
dtype=dtype,
)
if with_fp8_compute and not is_float8_tensor(x_local):
fp8_dtype = get_fp8_te_dtype(
input_fp8_meta["recipe"],
fprop_tensor=True,
)
x_local = Float8Tensor.to_float8(
x_local,
fp8_meta=input_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
data=(ub_comm_x.get_ubuf_output(0) if with_ub_all_gather_x else None),
with_transpose_cache=(not with_ub_all_gather_x),
)
elif not with_fp8_compute and is_float8_tensor(x_local):
if with_ub_all_gather_x:
ub_local_buffer = ub_comm_x.get_ubuf_output(0)
x_local = ub_local_buffer.copy_(x_local)
else:
x_local = x_local.dequantize()
# Check weight tensor
w = convert_tensor(
weight,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
if with_fp8_compute and not is_float8_tensor(w):
fp8_dtype = get_fp8_te_dtype(
weight_fp8_meta["recipe"],
fprop_tensor=True,
)
w = Float8Tensor.to_float8(
w,
fp8_meta=weight_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
with_transpose_cache=True,
)
elif not with_fp8_compute and is_float8_tensor(w):
w = w.dequantize()
# Initialize buffers for UB all-gather if needed
dy = dy_local
x = x_local
if with_ub_all_gather_dy:
ub_local_buffer = ub_comm_dy.get_ubuf_output(0)
ub_global_buffer = ub_comm_dy.get_ubuf_output(1)
if with_fp8_compute:
dy = Float8Tensor.make_like(dy_local, data=ub_global_buffer)
if dy_local._data.data_ptr() != ub_local_buffer.data_ptr():
ub_local_buffer.copy_(dy_local._data)
else:
dy = ub_global_buffer
if dy_local.data_ptr() != ub_local_buffer.data_ptr():
ub_local_buffer.copy_(dy_local)
if with_ub_all_gather_x:
ub_local_buffer = ub_comm_x.get_ubuf_output(0)
ub_global_buffer = ub_comm_x.get_ubuf_output(1)
if with_fp8_compute:
x = Float8Tensor.make_like(x_local, data=ub_global_buffer)
if x_local._data.data_ptr() != ub_local_buffer.data_ptr():
ub_local_buffer.copy_(x_local._data)
else:
x = ub_global_buffer
if x_local.data_ptr() != ub_local_buffer.data_ptr():
ub_local_buffer.copy_(x_local)
# Construct grad input tensor
dx = None
dx_local = None
if with_ub_reduce_scatter_dx:
# Initialize buffers for UB reduce-scatter
dx = ub_comm_dx.get_ubuf_output(1)
ub_local_buffer = ub_comm_dx.get_ubuf_output(0)
if with_ub_all_gather_x:
dx_local = ub_local_buffer
else:
dx_local = torch.empty_like(ub_local_buffer)
else:
# Allocate grad input tensor
if with_fp8_grad_input:
fp8_dtype = get_fp8_te_dtype(
grad_input_fp8_meta["recipe"],
fprop_tensor=False,
)
data = torch.empty(
(dy.size(0), w.size(-1)),
dtype=torch.uint8,
device=device,
)
dx = Float8Tensor(
data=data,
fp8_meta=grad_input_fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
dtype=dtype,
)
else:
dx = torch.empty(
(dy.size(0), w.size(-1)),
dtype=dtype,
device=device,
)
dx_local = dx
# Allocate grad weight tensor
if grad_weight is None:
if accumulate_into_grad_weight:
raise ValueError(
"Attempted to accumulate into grad weight bufferwithout providing grad weight"
)
grad_weight = torch.empty(
weight_dims,
dtype=dtype,
device=device,
memory_format=torch.contiguous_format,
)
# Perform dgrad GEMM
if with_fp8_compute:
kwargs = {"out": dx, "use_split_accumulator": True}
if with_ub_all_gather_dy:
kwargs["ub_algo"] = ub_algo_dy
kwargs["ub"] = ub_comm_dy
elif with_ub_all_gather_x:
kwargs["ub_algo"] = ub_algo_x
kwargs["ub"] = ub_comm_x
elif with_ub_reduce_scatter_dx:
kwargs["ub_algo"] = ub_algo_dx
kwargs["ub"] = ub_comm_dx
kwargs["extra_output_tensor"] = dx_local
if with_fp8_grad_input:
fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(dx)
kwargs.update(
{
"out": dx._data,
"out_index": fp8_meta_index,
"fp8_meta_tensor": fp8_meta,
"D_dtype": dx._fp8_dtype,
}
)
fp8_gemm(
w.transpose_2d(),
w._scale_inv,
0,
w._fp8_dtype,
dy._data,
dy._scale_inv,
0,
dy._fp8_dtype,
dy.dtype,
get_workspace(),
**kwargs,
)
else:
kwargs = {"grad": True, "layout": "NN", "out": dx}
if with_ub_all_gather_dy:
kwargs["ub_algo"] = ub_algo_dy
kwargs["ub"] = ub_comm_dy
elif with_ub_all_gather_x:
kwargs["ub_algo"] = ub_algo_x
kwargs["ub"] = ub_comm_x
elif with_ub_reduce_scatter_dx:
kwargs["ub_algo"] = ub_algo_dx
kwargs["ub"] = ub_comm_dx
kwargs["extra_output_tensor"] = dx_local
gemm(w, dy, dx.dtype, get_workspace(), **kwargs)
grad_input = reshape(dx_local, input_dims)
# Perform wgrad GEMM
if not weight_requires_grad:
pass
elif with_fp8_compute:
kwargs = {
"accumulate": accumulate_into_grad_weight,
"out": grad_weight,
"use_split_accumulator": True,
}
if with_ub_reduce_scatter_dx:
kwargs["ub_algo"] = ub_algo_dx
kwargs["ub"] = ub_comm_dx
fp8_gemm(
x.transpose_2d(),
x._scale_inv,
0,
x._fp8_dtype,
dy.transpose_2d(),
dy._scale_inv,
0,
dy._fp8_dtype,
grad_weight.dtype,
get_workspace(),
**kwargs,
)
else:
kwargs = {
"accumulate": accumulate_into_grad_weight,
"layout": "NT",
"grad": True,
"use_bias": bias_requires_grad,
"out": grad_weight,
}
if with_ub_reduce_scatter_dx:
kwargs["ub_algo"] = ub_algo_dx
kwargs["ub"] = ub_comm_dx
grad_weight, db, _ = gemm(
x,
dy,
grad_weight.dtype,
get_workspace(),
**kwargs,
)
# Compute grad bias if needed
if db_async is not None:
db_async.wait()
if bias_requires_grad:
if db is None:
db = dy.sum(dim=0)
extra_outputs["grad_bias"] = db
return grad_input, grad_weight, extra_outputs
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
list[tuple[Optional[torch.Tensor], ...]],
list[tuple[()]],
]:
# Get basic operations
idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[idx]
bias_op = None
if self._op_idxs["bias"] is not None:
idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx]
# Saved tensors from forward pass
(x_local,) = linear_op_ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if not hasattr(linear_op.weight, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
grad_weight = linear_op.weight.main_grad.detach()
else:
accumulate_into_main_grad = False
# Hacky workaround for Userbuffers bug with non-FP8 dgrad
# reduce-scatter overlap
weight_requires_grad = linear_op_ctx.weight_requires_grad
if not linear_op_ctx.with_fp8_compute and not weight_requires_grad:
warnings.warn(
"There is a correctness bug when using Userbuffers "
"to overlap a dgrad reduce-scatter with a non-FP8 dgrad GEMM. "
"Hackily working around by overlapping dgrad reduce-scatter "
"with wgrad GEMM, even though wgrad isn't needed. "
"Please contact Transformer Engine team "
"if you encounter this use-case."
)
weight_requires_grad = True
# Linear backward pass
retval = UserbuffersBackwardLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=linear_op.weight,
input_dims=linear_op_ctx.input_dims,
weight_dims=linear_op.weight.size(),
weight_requires_grad=weight_requires_grad,
bias_requires_grad=(bias_op is not None),
device=linear_op.device,
dtype=linear_op_ctx.dtype,
grad_weight=grad_weight,
accumulate_into_grad_weight=accumulate_into_main_grad,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
sequence_parallel=self.sequence_parallel,
with_fp8_compute=linear_op_ctx.with_fp8_compute,
weight_fp8_meta=linear_op_ctx.weight_fp8_meta,
grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta,
grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta,
ub_comm_name=linear_op._userbuffers_options["comm_name"],
)
grad_input, grad_weight, extra_outputs = retval
grad_bias = None
if bias_op is not None:
grad_bias = extra_outputs["grad_bias"]
# Clear input tensor if possible
if linear_op_ctx.has_prev_op:
clear_tensor_data(x_local)
# Return gradients
grad_params = [() for _ in range(len(self.basic_ops))]
if accumulate_into_main_grad:
grad_weight = None
grad_params[self._op_idxs["linear"]] = (grad_weight,)
if bias_op is not None:
grad_params[self._op_idxs["bias"]] = (grad_bias,)
grad_extra_inputs = [() for _ in range(len(self.basic_ops))]
return grad_input, grad_params, grad_extra_inputs
def fuse_userbuffers_backward_linear(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Substitute linear operations with Userbuffers implementation
Parameters
----------
ops: list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated forward pass operations
"""
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Sliding window in list of ops
window = []
def peek_next_op() -> Optional[FusibleOperation]:
"""Get next op in list of ops"""
nonlocal ops
if not ops:
return None
return ops[-1][0]
def pop_next_op() -> FusibleOperation:
"""Remove next op from list of ops and add to sliding window"""
nonlocal ops, window
window.insert(0, ops[-1])
ops = ops[:-1]
return window[0][0]
# Scan through ops in reverse order, fusing if possible
out_reversed = []
while ops:
out_reversed.extend(reversed(window))
window.clear()
# Check if next op is linear
next_op = pop_next_op()
if not isinstance(next_op, BasicLinear):
continue
linear = next_op
if linear._userbuffers_options is None:
continue
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias):
bias = pop_next_op()
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter):
reduce_scatter = pop_next_op()
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
continue
if reduce_scatter.process_group_size == 1:
continue
# Replace window with fused op
op = UserbuffersBackwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out_reversed.extend(reversed(window))
out = out_reversed
out.reverse()
return out
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear layer forward with Userbuffers communication."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
from transformer_engine_torch import CommOverlapAlgo
from ...cpp_extensions import fp8_gemm, gemm
from ...distributed import get_distributed_world_size
from ...float8_tensor import Float8Tensor
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...module.base import get_ub, get_workspace
from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
BasicOperation,
FusedOperation,
FusibleOperation,
OperationContext,
)
from .._common import (
convert_tensor,
get_fp8_meta_from_fp8_tensor,
is_float8_tensor,
reshape,
)
class UserbuffersForwardLinear(FusedOperation):
"""Linear forward implementation using Userbuffers
This operation is equivalent to a linear operation's forward pass,
but it uses Userbuffers to overlap tensor-parallel communication
with compute.
"""
def __init__(
self,
*,
linear: BasicLinear,
bias: Optional[Bias],
reduce_scatter: Optional[ReduceScatter],
) -> None:
# Basic operations that comprise this fused operation
op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None}
ops = [linear]
if bias is not None:
op_idxs["bias"] = len(ops)
ops.append(bias)
if reduce_scatter is not None:
op_idxs["reduce_scatter"] = len(ops)
ops.append(reduce_scatter)
# Initialize base class
super().__init__(ops)
# Index of each basic operation
self._op_idxs: dict[str, Optional[int]] = op_idxs
# Tensor parallelism configuration
self.tensor_parallel_mode: Optional[str]
self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup]
self.tensor_parallel_size: int
self.sequence_parallel: bool
if reduce_scatter is None:
self.tensor_parallel_mode = linear.tensor_parallel_mode
self.tensor_parallel_group = linear.tensor_parallel_group
self.tensor_parallel_size = linear.tensor_parallel_size
self.sequence_parallel = linear.sequence_parallel
else:
self.tensor_parallel_mode = "row"
self.tensor_parallel_group = reduce_scatter.process_group
self.tensor_parallel_size = reduce_scatter.process_group_size
self.sequence_parallel = True
@staticmethod
def _functional_forward(
input: torch.Tensor, # pylint: disable=redefined-builtin
weight: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
tensor_parallel_size: Optional[int] = None,
sequence_parallel: bool = False,
with_fp8_compute: bool = False,
input_fp8_meta: Optional[dict[str, Any]] = None,
weight_fp8_meta: Optional[dict[str, Any]] = None,
output_fp8_meta: Optional[dict[str, Any]] = None,
ub_comm_name: str,
) -> tuple[torch.Tensor, dict]:
"""Functional API for forward pass
Parameters
----------
input: torch.Tensor
Input tensor
weight: torch.Tensor
Weight tensor
bias: torch.Tensor, optional
Bias tensor
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_fp8_compute: bool, default = `False`
Whether to perform compute in FP8
input_fp8_meta: dict, optional
FP8 metadata for casting input tensor to FP8. Required for
FP8 compute if input is not already in FP8.
weight_fp8_meta: dict, optional
FP8 metadata for casting weight tensor to FP8. Required for
FP8 compute if weight is not already in FP8.
output_fp8_meta: dict, optional
FP8 metadata for casting output tensor to FP8
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
(e.g. "qkv_fprop").
Returns
-------
torch.Tensor
Output tensor
dict
Extra output tensors. "input" is the input tensor,
possibly cast and reshaped from the provided input tensor.
"""
# Check device
if device is None:
device = weight.device
device = canonicalize_device(device)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
# Check datatype
if dtype is None:
dtype = weight.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Input tensor dims
input_dims = tuple(input.size())
weight_dims = tuple(weight.size())
if len(weight_dims) != 2:
raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
# Output tensor dims
output_dims = list(input_dims)
output_dims[0] = -1
output_dims[-1] = weight_dims[0]
# Check tensor parallel group
if tensor_parallel_size is None:
tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
if tensor_parallel_size == 1:
tensor_parallel_mode = None
if tensor_parallel_mode not in ("column", "row"):
raise RuntimeError(
"Invalid configuration for Userbuffers "
f"({tensor_parallel_size=}, {tensor_parallel_mode=})"
)
if not sequence_parallel:
raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")
# Check if FP8 is enabled
if with_fp8_compute:
if input_fp8_meta is None and not is_float8_tensor(input):
raise ValueError("No FP8 metadata was provided for casting input to FP8")
if weight_fp8_meta is None and not is_float8_tensor(weight):
raise ValueError("No FP8 metadata was provided for casting weight to FP8")
else:
input_fp8_meta = None
weight_fp8_meta = None
output_fp8_meta = None
with_fp8_output = (
with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None
)
# Get Userbuffers communicator
ub_comm = get_ub(ub_comm_name + "_fprop")
ub_local_buffer = ub_comm.get_ubuf_output(0)
ub_global_buffer = ub_comm.get_ubuf_output(1)
with_ub_all_gather = tensor_parallel_mode == "column"
with_ub_reduce_scatter = tensor_parallel_mode == "row"
# Choose Userbuffers communication algorithm
ub_algo = None
if with_ub_all_gather:
if with_fp8_compute and ub_comm.is_atomic_gemm():
ub_algo = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
elif with_ub_reduce_scatter:
is_atomic_gemm = with_fp8_compute and ub_comm.is_atomic_gemm()
ub_algo = {
(True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P,
(True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P,
(False, True): CommOverlapAlgo.ATOMIC_GEMM_RS,
(False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS,
}[(ub_comm.is_p2p_overlap(), is_atomic_gemm)]
else:
raise RuntimeError("Could not choose Userbuffers communication algorithm")
# Cast input tensor to correct dtype
x_local = reshape(
input,
(-1, input_dims[-1]),
device=device,
dtype=dtype,
)
if with_fp8_compute and not is_float8_tensor(x_local):
fp8_dtype = get_fp8_te_dtype(
input_fp8_meta["recipe"],
fprop_tensor=True,
)
with_transpose_cache = weight.requires_grad
if tensor_parallel_mode == "column" and sequence_parallel:
with_transpose_cache = False
x_local = Float8Tensor.to_float8(
x_local,
fp8_meta=input_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
data=(ub_local_buffer if with_ub_all_gather else None),
with_transpose_cache=with_transpose_cache,
)
elif not with_fp8_compute and is_float8_tensor(x_local):
if with_ub_all_gather:
x_local = ub_local_buffer.copy_(x_local)
else:
x_local = x_local.dequantize()
# Initialize buffers for UB all-gather if needed
x = x_local
if with_ub_all_gather:
if with_fp8_compute:
x = Float8Tensor.make_like(x_local, data=ub_global_buffer)
if x_local._data.data_ptr() != ub_local_buffer.data_ptr():
ub_local_buffer.copy_(x_local._data)
else:
x_local._data = torch.empty_like(x_local._data)
else:
x = ub_global_buffer
if x_local.data_ptr() != ub_local_buffer.data_ptr():
ub_local_buffer.copy_(x_local)
else:
x_local = torch.empty_like(x_local)
# Check weight tensor
w = convert_tensor(
weight,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
if with_fp8_compute and not is_float8_tensor(w):
fp8_dtype = get_fp8_te_dtype(
weight_fp8_meta["recipe"],
fprop_tensor=True,
)
w = Float8Tensor.to_float8(
w,
fp8_meta=weight_fp8_meta,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
)
elif not with_fp8_compute and is_float8_tensor(w):
w = w.dequantize()
# Check bias tensor
b = None
if bias is not None:
b = convert_tensor(
bias,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
# Construct output tensor
y = None
y_local = None
if with_ub_reduce_scatter:
# Initialize buffers for UB reduce-scatter
if with_fp8_output:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
fp8_dtype = get_fp8_te_dtype(
output_fp8_meta["recipe"],
fprop_tensor=True,
)
y = Float8Tensor(
data=ub_global_buffer,
fp8_meta=output_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=output_fp8_meta[fp8_meta_key].scale_inv[0],
dtype=dtype,
)
ub_comm.set_ubuf_scale_inv(y._scale_inv)
else:
y = ub_global_buffer
y_local = torch.empty(
(x.size(0) // tensor_parallel_size, weight_dims[0]),
dtype=dtype,
device=device,
)
else:
# Allocate output tensor
if with_fp8_output:
fp8_dtype = get_fp8_te_dtype(
output_fp8_meta["recipe"],
fprop_tensor=True,
)
data = torch.empty(
(x.size(0), weight_dims[0]),
dtype=torch.uint8,
device=device,
)
y = Float8Tensor(
data=data,
fp8_meta=output_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
dtype=dtype,
)
else:
y = torch.empty(
(x.size(0), weight_dims[0]),
dtype=dtype,
device=device,
)
y_local = y
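# Output buffer convention: with reduce-scatter overlap the GEMM writes the
# full partial output into the Userbuffers global buffer (`y`) and the
# reduce-scattered shard for this rank lands in `y_local` via
# `extra_output_tensor`; without reduce-scatter, `y_local` simply aliases the
# freshly allocated output `y`.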
# Perform GEMM
if with_fp8_compute:
kwargs = {
"out": y,
"bias": b,
"use_bias": (b is not None),
"use_split_accumulator": False,
"ub_algo": ub_algo,
"ub": ub_comm,
}
if with_ub_all_gather:
kwargs["extra_output_tensor"] = x_local._data
if with_ub_reduce_scatter:
kwargs["extra_output_tensor"] = y_local
if with_fp8_output:
fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(y)
kwargs.update(
{
"out": y._data,
"out_index": fp8_meta_index,
"fp8_meta_tensor": fp8_meta,
"D_dtype": y._fp8_dtype,
}
)
fp8_gemm(
w._data,
w._scale_inv,
0,
w._fp8_dtype,
x._data,
x._scale_inv,
0,
x._fp8_dtype,
y.dtype,
get_workspace(),
**kwargs,
)
else:
kwargs = {
"out": y,
"bias": b,
"use_bias": (b is not None),
"ub_algo": ub_algo,
"ub": ub_comm,
}
if with_ub_all_gather:
kwargs["extra_output_tensor"] = x_local
if with_ub_reduce_scatter:
kwargs["extra_output_tensor"] = y_local
gemm(w, x, y.dtype, get_workspace(), **kwargs)
# Reshape output tensor
out = reshape(y_local, output_dims)
# Return cast tensors
extra_outputs = {"input": x_local, "weight": w}
return out, extra_outputs
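# NOTE: `extra_outputs` exposes the cast input and weight so that
# `fuser_forward` below can save `x_local` for the backward pass.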
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
# Get basic operations
idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[idx]
bias_op = None
bias = None
if self._op_idxs["bias"] is not None:
idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx]
bias = bias_op.bias
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
# FP8 metadata
with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
input_fp8_meta = None
weight_fp8_meta = None
output_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
if with_fp8_compute:
input_fp8_meta = linear_op.get_fp8_meta("input")
weight_fp8_meta = linear_op.get_fp8_meta("param")
next_op = basic_op_next_ops[-1]
if next_op is not None and next_op.num_fp8_scales("input") > 0:
output_fp8_meta = next_op.get_fp8_meta("input")
grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output")
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")
# Get autocast dtype if needed
dtype = None
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
# Userbuffers options
if linear_op._userbuffers_options is None:
raise RuntimeError("Linear op is missing dict for Userbuffers options")
# Linear forward
output, extra_outputs = UserbuffersForwardLinear._functional_forward(
input=input_,
weight=linear_op.weight,
bias=bias,
device=linear_op.device,
dtype=dtype,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
tensor_parallel_size=self.tensor_parallel_size,
sequence_parallel=self.sequence_parallel,
with_fp8_compute=with_fp8_compute,
input_fp8_meta=input_fp8_meta,
weight_fp8_meta=weight_fp8_meta,
output_fp8_meta=output_fp8_meta,
ub_comm_name=linear_op._userbuffers_options["comm_name"],
)
x_local = extra_outputs["input"]
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.with_fp8_compute = with_fp8_compute
linear_op_ctx.weight_fp8_meta = weight_fp8_meta
linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
linear_op_ctx.dtype = dtype
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))]
def fuse_userbuffers_forward_linear(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Substitute linear operations with Userbuffers implementation
Parameters
----------
ops: list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated forward pass operations
"""
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Sliding window in list of ops
window = []
def peek_next_op() -> Optional[FusibleOperation]:
"""Get next op in list of ops"""
nonlocal ops
if not ops:
return None
return ops[0][0]
def pop_next_op() -> FusibleOperation:
"""Remove next op from list of ops and add to sliding window"""
nonlocal ops, window
window.append(ops[0])
ops = ops[1:]
return window[-1][0]
# Scan through ops, fusing if possible
out = []
while ops:
out.extend(window)
window.clear()
# Check if next op is linear
next_op = pop_next_op()
if not isinstance(next_op, BasicLinear):
continue
linear = next_op
if linear._userbuffers_options is None:
continue
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias):
bias = pop_next_op()
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter):
reduce_scatter = pop_next_op()
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
continue
if reduce_scatter.process_group_size == 1:
continue
# Replace window with fused op
op = UserbuffersForwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
return out
......@@ -20,6 +20,8 @@ from transformer_engine.pytorch.ops.fused import (
fuse_backward_linear_add,
fuse_forward_linear_bias_activation,
fuse_forward_linear_bias_add,
fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear,
)
......@@ -345,6 +347,7 @@ class OperationFuser:
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in forward pass"""
ops = fuse_userbuffers_forward_linear(ops)
ops = fuse_forward_linear_bias_add(ops)
ops = fuse_forward_linear_bias_activation(ops)
return ops
......@@ -355,6 +358,7 @@ class OperationFuser:
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in backward pass"""
ops = fuse_userbuffers_backward_linear(ops)
ops = fuse_backward_linear_add(ops)
return ops
......
......@@ -109,10 +109,12 @@ class _ToFloat8Func(torch.autograd.Function):
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: TE_DType = TE_DType.kFloat8E4M3,
data: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
with_transpose_cache: bool = False,
data_transpose: Optional[torch.Tensor] = None,
) -> Float8Tensor:
# pylint: disable=missing-function-docstring
......@@ -125,7 +127,8 @@ class _ToFloat8Func(torch.autograd.Function):
device = torch.device("cuda")
# FP8 data buffer
data = torch.empty(tensor.size(), dtype=torch.uint8, device=device)
if data is None:
data = torch.empty(tensor.size(), dtype=torch.uint8, device=device)
# Check scale
if scale is None and fp8_meta is None:
......@@ -140,8 +143,7 @@ class _ToFloat8Func(torch.autograd.Function):
scale_inv = scale_inv.to(device=device, dtype=torch.float32)
# Transpose cache
data_transpose = None
if with_transpose_cache:
if data_transpose is None and with_transpose_cache:
data_transpose = torch.empty(
(data.size(-1), data.numel() // data.size(-1)),
dtype=torch.uint8,
......@@ -172,7 +174,7 @@ class _ToFloat8Func(torch.autograd.Function):
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
return grad, None, None, None, None, None, None, None
return grad, None, None, None, None, None, None, None, None, None
class _IdentityFunc(torch.autograd.Function):
......@@ -688,10 +690,12 @@ class Float8Tensor(QuantizedTensor):
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: TE_DType = TE_DType.kFloat8E4M3,
data: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
with_transpose_cache: bool = False,
data_transpose: Optional[torch.Tensor] = None,
):
"""Construct Float8Tensor from plain PyTorch tensor"""
return _ToFloat8Func.apply(
......@@ -700,10 +704,12 @@ class Float8Tensor(QuantizedTensor):
fp8_meta_forward,
fp8_meta_index,
fp8_dtype,
data,
scale,
amax,
scale_inv,
with_transpose_cache,
data_transpose,
)
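# The new `data` and `data_transpose` arguments let callers quantize directly
# into preallocated buffers. For example, the Userbuffers forward linear above
# casts the input with
#     Float8Tensor.to_float8(x, fp8_meta=..., fp8_dtype=...,
#                            data=ub_local_buffer)
# so the FP8 data is written straight into the communication workspace instead
# of a freshly allocated uint8 tensor.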
def detach(self) -> Float8Tensor:
......