"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "b0ad8ef016862d044d7d13926dffbd45240bf581"
Unverified commit 095b27d0, authored by Tim Moon, committed by GitHub

[PyTorch] Userbuffers support in operation-based API (#1142)



* Add Userbuffers support for column TP linear layer

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add Userbuffers support for row TP linear layer

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Interpret linear+RS as row TP linear (see the usage sketch below)

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add Userbuffers support for FP8 row TP linear layer

Assumes FP8 RS, which is not a good assumption.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug with incorrect bias pointers in UB GEMM

Bias pointers were not properly offset for different data chunks. Also removed logic for FP8 RS.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add Userbuffers support for linear dgrad

Test passes with row TP, fails with column TP.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add Userbuffers support for linear wgrad

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for grad bias

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add fused cast-transpose-dbias

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support case where wgrad is optional

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expand documentation

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warnings

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use recently added convenience functions in Float8Tensor

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Respect autograd dtype

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix missing imports

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Respect PyTorch autocast dtype in bprop

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 77c37d49
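
For orientation, the following is the usage pattern this PR enables: a minimal sketch, assuming torch.distributed is initialized with a NCCL process group and Userbuffers communicators have been registered via te.module.base.initialize_ub() (as the test below does); sizes are illustrative. A BasicLinear followed by a ReduceScatter is interpreted as a row tensor-parallel linear and fused into a single Userbuffers GEMM+RS operation.

import torch
import transformer_engine.pytorch.ops as te_ops

process_group = torch.distributed.group.WORLD
world_size = torch.distributed.get_world_size(process_group)
hidden = 1024  # illustrative size

# Linear + reduce-scatter is recognized as a row TP linear; the
# userbuffers_options dict selects a registered communicator.
model = te_ops.Sequential(
    te_ops.BasicLinear(
        hidden // world_size,  # local in_features (input dim is sharded)
        hidden,
        userbuffers_options={"comm_name": "proj"},
    ),
    te_ops.ReduceScatter(process_group),
)

x = torch.randn(32, hidden // world_size, device="cuda", requires_grad=True)
y = model(x)  # forward runs as a fused UserbuffersForwardLinear
y.backward(torch.ones_like(y))  # backward runs as UserbuffersBackwardLinear
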
@@ -10,4 +10,5 @@ pip install pytest==8.2.1
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from __future__ import annotations

import argparse
import dataclasses
import functools
import itertools
import os
import pathlib
import subprocess
import sys
from collections.abc import Iterable

import pytest
import torch

import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import (
    UserbuffersBackwardLinear,
    UserbuffersForwardLinear,
)
from transformer_engine.pytorch.utils import is_bf16_compatible

# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, str_to_dtype

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

# Check if there are multiple GPUs
if torch.cuda.device_count() < 2:
    pytest.skip("Userbuffers requires at least 2 GPUs.", allow_module_level=True)
@dataclasses.dataclass
class ModelConfig:
    """Tensor dimensions in Transformer model"""

    sequence_length: int
    batch_size: int
    num_heads: int
    head_dim: int
    dtype: torch.dtype
    fp8: bool

    @property
    def hidden_size(self):
        return self.num_heads * self.head_dim
@functools.cache
def launcher() -> str:
    """Launcher for current parallel job"""
    if "OMPI_COMM_WORLD_SIZE" in os.environ:
        return "ompi"
    if "TORCHELASTIC_RUN_ID" in os.environ:
        return "torchrun"
    raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`")


@functools.cache
def world_group() -> torch.distributed.ProcessGroup:
    """Get NCCL process group, initializing if needed"""

    # Get launch config from environment
    if launcher() == "ompi":
        # OpenMPI
        world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE"))
        rank = int(os.getenv("OMPI_COMM_WORLD_RANK"))
        local_size = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE"))
        local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK"))
    elif launcher() == "torchrun":
        # torchrun
        world_size = int(os.getenv("WORLD_SIZE"))
        rank = int(os.getenv("RANK"))
        local_size = int(os.getenv("LOCAL_WORLD_SIZE"))
        local_rank = int(os.getenv("LOCAL_RANK"))
    else:
        raise RuntimeError(f"Unexpected launcher ({launcher()})")

    # Construct communicator
    assert local_size == world_size
    torch.cuda.set_device(local_rank)
    group = torch.distributed.init_process_group(
        "nccl",
        init_method="file:///tmp/rdzv",
        world_size=world_size,
        rank=rank,
        device_id=torch.device(f"cuda:{local_rank}"),
    )
    return group
def reset_rng(seed: int = 1234) -> None:
    """Reset random number generators"""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device = "cpu",
    test_dtype: torch.dtype = torch.float32,
    test_device: torch.device = "cuda",
    test_is_fp8: bool = False,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Construct tensors with the same values

    The reference tensor is intended for use in plain PyTorch
    operations in high precision. The test tensor is intended for use
    in Transformer Engine operations.

    """

    # Random data
    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)

    # Make copy of tensor
    if test_is_fp8:
        test = Float8Tensor.to_float8(ref)
    else:
        test = ref.to(device=test_device, dtype=test_dtype)
        if test.data_ptr() == ref.data_ptr():
            test = test.clone()

    # Make sure reference and test tensors represent exact same values
    ref.copy_(test)

    # Return reference and test tensors
    ref.requires_grad_(requires_grad)
    test.requires_grad_(requires_grad)
    return ref, test
def _test_linear(
    *,
    model_config: ModelConfig,
    bias: bool = False,
    device: torch.device = "cuda",
    tensor_parallel_mode: str = "column",
    sequence_parallel: bool = True,
    weight_requires_grad: bool = True,
) -> None:
    dtype = model_config.dtype
    fp8_compute = model_config.fp8

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    out_features = model_config.hidden_size
    in_features = model_config.hidden_size
    batch_size = model_config.sequence_length * model_config.batch_size
    in_shape = [batch_size, in_features]
    out_shape = [batch_size, out_features]

    # Random data
    reset_rng()
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        test_dtype=dtype,
        test_device=device,
        test_is_fp8=fp8_compute,
    )
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        test_dtype=dtype,
        test_device=device,
        test_is_fp8=fp8_compute,
    )
    b_ref, b_test = None, None
    if bias:
        if tensor_parallel_mode == "row":
            bias_shape = [world_size, out_features]
        else:
            bias_shape = [out_features]
        b_ref, b_test = make_reference_and_test_tensors(
            bias_shape,
            test_dtype=dtype,
            test_device=device,
        )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        test_dtype=dtype,
        test_device=device,
        test_is_fp8=fp8_compute,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.linear(x_ref, w_ref)
    if bias:
        if tensor_parallel_mode == "row":
            y_ref += b_ref.sum(dim=0)
        else:
            y_ref += b_ref
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        dw_ref = w_ref.grad
        db_ref = b_ref.grad if bias else None
        dx_ref = x_ref.grad
        if tensor_parallel_mode == "column":
            local_out_features = out_features // world_size
            local_slice = slice(
                rank * local_out_features,
                (rank + 1) * local_out_features,
            )
            w_ref = w_ref[local_slice, :]
            dw_ref = dw_ref[local_slice, :]
            w_test = w_test[local_slice, :]
            if bias:
                b_ref = b_ref[local_slice]
                db_ref = db_ref[local_slice]
                b_test = b_test[local_slice]
            y_ref = y_ref[..., local_slice]
            dy_ref = dy_ref[..., local_slice]
            dy_test = dy_test[..., local_slice].clone()
        elif tensor_parallel_mode == "row":
            local_in_features = in_features // world_size
            local_slice = slice(
                rank * local_in_features,
                (rank + 1) * local_in_features,
            )
            w_ref = w_ref[:, local_slice]
            dw_ref = dw_ref[:, local_slice]
            w_test = w_test[:, local_slice]
            if bias:
                b_ref = b_ref[rank, :]
                db_ref = db_ref[rank, :]
                b_test = b_test[rank, :]
            x_ref = x_ref[..., local_slice]
            dx_ref = dx_ref[..., local_slice]
            x_test = x_test[..., local_slice].clone()
        if sequence_parallel:
            local_batch_size = batch_size // world_size
            local_slice = slice(
                rank * local_batch_size,
                (rank + 1) * local_batch_size,
            )
            if tensor_parallel_mode == "column":
                x_ref = x_ref[local_slice, ...]
                dx_ref = dx_ref[local_slice, ...]
                x_test = x_test[local_slice, ...].clone()
            elif tensor_parallel_mode == "row":
                y_ref = y_ref[local_slice, ...]
                dy_ref = dy_ref[local_slice, ...]
                dy_test = dy_test[local_slice, ...].clone()
    x_test.requires_grad_()
    # Implementation with fusible operation
    with te.fp8_model_init(enabled=fp8_compute):
        ops = []
        linear_op = None
        bias_op = None
        if tensor_parallel_mode == "column":
            userbuffers_options = {}
            if not weight_requires_grad:
                if fp8_compute:
                    userbuffers_options["comm_name"] = "fc1"
                else:
                    # There is a correctness bug with overlapping
                    # dgrad reduce-scatter with dgrad GEMM. Fall back
                    # to overlapping dgrad reduce-scatter with wgrad
                    # GEMM, even though wgrad isn't needed.
                    userbuffers_options["comm_name"] = "qkv"
            else:
                userbuffers_options["comm_name"] = "qkv"
            linear_op = te_ops.BasicLinear(
                in_features,
                out_features,
                device=device,
                dtype=dtype,
                tensor_parallel_mode=tensor_parallel_mode,
                tensor_parallel_group=process_group,
                sequence_parallel=sequence_parallel,
                userbuffers_options=userbuffers_options,
            )
            ops.append(linear_op)
            if bias:
                bias_op = te_ops.Bias(
                    out_features // world_size,
                    device=device,
                    dtype=dtype,
                )
                ops.append(bias_op)
        elif tensor_parallel_mode == "row":
            userbuffers_options = dict(comm_name="proj")
            linear_op = te_ops.BasicLinear(
                in_features // world_size,
                out_features,
                device=device,
                dtype=dtype,
                userbuffers_options=userbuffers_options,
            )
            ops.append(linear_op)
            if bias:
                bias_op = te_ops.Bias(out_features, device=device, dtype=dtype)
                ops.append(bias_op)
            ops.append(te_ops.ReduceScatter(process_group))
        model = te_ops.Sequential(*ops)
    with torch.no_grad():
        linear_op.weight.copy_(w_test)
        linear_op.weight.requires_grad_(requires_grad=weight_requires_grad)
        if bias:
            bias_op.bias.copy_(b_test)
        del w_test
        del b_test
    with te.fp8_autocast(enabled=fp8_compute):
        y_test = model(x_test)
    y_test.backward(dy_test)

    # Check that forward operations have been fused
    forward_ops = model._module_groups[0]._forward_ops
    backward_ops = model._module_groups[0]._backward_ops
    assert len(forward_ops) == 1
    assert len(backward_ops) == 1
    assert isinstance(forward_ops[0][0], UserbuffersForwardLinear)
    assert isinstance(backward_ops[0][0], UserbuffersBackwardLinear)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if fp8_compute:
        tols = dtype_tols(
            model[0].weight._fp8_dtype
            if is_float8_tensor(model[0].weight)
            else tex.DType.kFloat8E4M3
        )

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, dx_ref, **tols)
    if weight_requires_grad:
        dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(dw_test, dw_ref, **tols)
    if bias:
        db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(db_test, db_ref, **tols)
def run_parallel_tests(model_config: ModelConfig) -> None:
    """Run parallel tests"""

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Linear op
    for test_config in itertools.product(
        (False, True),  # bias
        ("column", "row"),  # tensor_parallel_mode
        (False, True),  # weight_requires_grad
    ):
        if rank == 0:
            print(f"Running _test_linear with {test_config=}")
        bias, tensor_parallel_mode, weight_requires_grad = test_config
        _test_linear(
            model_config=model_config,
            bias=bias,
            tensor_parallel_mode=tensor_parallel_mode,
            weight_requires_grad=weight_requires_grad,
        )
# Parallel job sizes
_world_sizes = []
if torch.cuda.device_count() > 1:
    _world_sizes.append(torch.cuda.device_count())


@pytest.mark.parametrize("world_size", _world_sizes)
@pytest.mark.parametrize("fp8", (False, True))
def test_fuser_ops_with_userbuffers(
    *,
    world_size: int,
    dtype: torch.dtype = torch.bfloat16,
    fp8: bool,
) -> None:
    """Launch parallel job and run tests"""

    # Skip invalid configurations
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Parallel job launcher
    command = []
    if tex.ubuf_built_with_mpi():
        python_exe = pathlib.Path(sys.executable).resolve()
        command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe))
    else:
        command.extend(("torchrun", f"--nproc_per_node={world_size}"))

    # Script invocation
    command.extend(
        (
            _current_file,
            "--parallel",
            "--batch-size",
            str(world_size),
            "--num-heads",
            str(world_size),
            "--dtype",
            str(dtype),
        )
    )
    if fp8:
        command.append("--fp8")

    # Environment
    env = dict(os.environ)
    if not tex.device_supports_multicast():
        env["UB_SKIPMC"] = "1"
    env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
    env["PYTORCH_JIT"] = "0"
    env["NVTE_TORCH_COMPILE"] = "0"
    env["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

    # Launch parallel job
    result = subprocess.run(command, check=True, env=env)
def main() -> None:

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
    parser.add_argument("--sequence-length", type=int, default=32)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-heads", type=int, default=16)
    parser.add_argument("--head-dim", type=int, default=32)
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--fp8", action="store_true")
    args = parser.parse_args()

    # Run parallel tests if needed
    if args.parallel:

        # Model config
        model_config = ModelConfig(
            sequence_length=args.sequence_length,
            batch_size=args.batch_size,
            num_heads=args.num_heads,
            head_dim=args.head_dim,
            dtype=str_to_dtype(args.dtype),
            fp8=args.fp8,
        )

        # Initialize Userbuffers
        group = world_group()  # Initialize NCCL
        bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
        userbuffer_configs = {
            "fc1_dgrad": {"method": "pipeline"},  # Overlap dgrad RS with dgrad GEMM
        }
        te.module.base.initialize_ub(
            [
                model_config.sequence_length * model_config.batch_size,
                model_config.num_heads * model_config.head_dim,
            ],
            torch.distributed.get_world_size(group),
            use_fp8=model_config.fp8,
            dtype=model_config.dtype,
            bootstrap_backend=bootstrap_backend,
            ub_cfgs=userbuffer_configs,
        )

        # Run tests
        run_parallel_tests(model_config)

        # Clean up
        te.module.base.destroy_ub()


if __name__ == "__main__":
    main()
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from __future__ import annotations

import torch
import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine_torch as tex


def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
    """Convert type name to PyTorch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    name = str(dtype).strip().lower()
    if name.startswith("torch."):
        name = name.replace("torch.", "", 1)
    if name.startswith("fp"):
        name = name.replace("fp", "float", 1)
    dtype = dict(
        float32=torch.float32,
        float=torch.float32,
        float64=torch.float64,
        double=torch.float64,
        float16=torch.float16,
        half=torch.float16,
        bfloat16=torch.bfloat16,
        bf16=torch.bfloat16,
        float8_e4m3fn=torch.float8_e4m3fn,
        float8_e4m3=torch.float8_e4m3fn,
        float8e4m3=torch.float8_e4m3fn,
        float8=torch.float8_e4m3fn,
        float8_e5m2=torch.float8_e5m2,
        float8e5m2=torch.float8_e5m2,
        uint8=torch.uint8,
        byte=torch.uint8,
        int8=torch.int8,
        char=torch.int8,
        int16=torch.int16,
        short=torch.int16,
        int32=torch.int32,
        int=torch.int32,
        int64=torch.int64,
        long=torch.int64,
        bool=torch.bool,
    )[name]
    return dtype
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    """

    # Transformer Engine dtypes
    if isinstance(dtype, tex.DType):
        dtype = {
            tex.DType.kByte: torch.uint8,
            tex.DType.kInt32: torch.int32,
            tex.DType.kFloat32: torch.float32,
            tex.DType.kFloat16: torch.half,
            tex.DType.kBFloat16: torch.bfloat16,
            tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
            tex.DType.kFloat8E5M2: torch.float8_e5m2,
        }[dtype]

    # PyTorch dtypes
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float64:
        return dict(rtol=1e-7, atol=1e-7)
    if dtype == torch.float8_e4m3fn:
        return dict(rtol=0.125, atol=0.0675)  # epsilon = 0.0625
    if dtype == torch.float8_e5m2:
        return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152
    raise ValueError(f"Unsupported dtype ({dtype})")
@@ -314,11 +314,13 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
   size_t m_chunk = m / _num_splits;
   size_t input_a_chunk_size = m_chunk * k;
   size_t output_chunk_size = n * m_chunk;
+  size_t bias_chunk_size = m_chunk;
   size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
 
   // Get input, output, and workspace data pointers
   char *input_a_chunk_ptr = reinterpret_cast<char *>(A.dptr());
   char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.dptr());
+  char *bias_chunk_ptr = reinterpret_cast<char *>(bias.dptr());
   char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
   char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
@@ -337,16 +339,21 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
         TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv());
     auto output_chunk =
         TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
+    auto bias_chunk =
+        TensorWrapper(bias.dptr(), {m_chunk}, bias.dtype(), nullptr, nullptr, nullptr);
     auto workspace_chunk = TensorWrapper(
         workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
-    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                      use_split_accumulator, _math_sms, _stream_compute[0]);
 
     for (int i = 1; i < _num_splits; i++) {
       input_a_chunk_ptr += input_a_chunk_size * B.element_size();
       output_buf_chunk_ptr += output_chunk_size * D.element_size();
+      if (bias_chunk_ptr != nullptr) {
+        bias_chunk_ptr += bias_chunk_size * bias.element_size();
+      }
       char *workspace_chunk_ptr =
           workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
@@ -354,10 +361,12 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
                                      A.dtype(), nullptr, nullptr, A.scale_inv());
       output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr), {n, m_chunk},
                                    D.dtype(), D.amax(), D.scale(), nullptr);
+      bias_chunk = TensorWrapper(reinterpret_cast<void *>(bias_chunk_ptr), {m_chunk}, bias.dtype(),
+                                 nullptr, nullptr, nullptr);
       workspace_chunk = TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
                                       std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                        pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                        accumulate, use_split_accumulator, _math_sms,
                        _stream_compute[i % _stream_compute.size()]);
@@ -409,11 +418,13 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
                                           A.dtype(), nullptr, nullptr, A.scale_inv());
       auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr),
                                         {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
+      auto bias_chunk = TensorWrapper(reinterpret_cast<void *>(bias_chunk_ptr), {m_chunk},
+                                      bias.dtype(), nullptr, nullptr, nullptr);
       auto workspace_chunk =
           TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
                         std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                        pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                        accumulate, use_split_accumulator, _math_sms,
                        _stream_compute[i % _stream_compute.size()]);
@@ -440,6 +451,9 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
       rs_output_ptr += m_chunk * rs_output.element_size();
       input_a_chunk_ptr += input_a_chunk_size * B.element_size();
       output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
+      if (bias_chunk_ptr != nullptr) {
+        bias_chunk_ptr += bias_chunk_size * bias.element_size();
+      }
     }
   }
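
The C++ change above implements the bias-pointer fix described in the commit message: when the GEMM is split into _num_splits chunks along the output dimension, the bias pointer must advance together with the output chunk. A rough PyTorch model of that chunking logic (illustrative only, not the actual implementation):

import torch

def chunked_linear(x, w, bias, num_splits):
    """Compute x @ w.T + bias in num_splits chunks of output features."""
    m = w.size(0)              # total output features
    m_chunk = m // num_splits  # per-chunk output features ("bias_chunk_size")
    out = []
    for i in range(num_splits):
        rows = slice(i * m_chunk, (i + 1) * m_chunk)
        # Before the fix, every chunk reused the base bias pointer; the
        # slice offset below is what the bias_chunk_ptr increments implement.
        out.append(x @ w[rows].T + bias[rows])
    return torch.cat(out, dim=-1)

x = torch.randn(8, 16)
w = torch.randn(32, 16)
b = torch.randn(32)
ref = torch.nn.functional.linear(x, w, b)
assert torch.allclose(chunked_linear(x, w, b, num_splits=4), ref, atol=1e-5)
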
@@ -9,6 +9,8 @@ from typing import Any, Iterable, Optional
 
 import torch
+from transformer_engine_torch import FP8TensorMeta
 
+from ..fp8 import FP8GlobalStateManager
 from ..tensor import Float8Tensor
 from ..utils import (
     canonicalize_device,  # pylint: disable=unused-import
@@ -134,3 +136,25 @@ def maybe_autocast_dtype(
     if torch.is_autocast_enabled(device_type):
         return torch.get_autocast_dtype(device_type)
     return canonicalize_dtype(default_dtype)
+
+
+def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, int]:
+    """Get FP8TensorMeta object and index corresponding to Float8Tensor
+
+    Constructs FP8TensorMeta if needed.
+
+    """
+
+    # Check if tensor already has FP8 metadata
+    if tensor._fp8_meta is not None:
+        key = FP8GlobalStateManager.get_meta_tensor_key(
+            forward=tensor._fp8_meta_forward,
+        )
+        return tensor._fp8_meta[key], tensor._fp8_meta_index
+
+    # Construct FP8TensorMeta from the tensor's scaling state
+    fp8_meta = FP8TensorMeta()
+    fp8_meta.scale = tensor._scale_inv.reciprocal()
+    fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=tensor.device)
+    fp8_meta.scale_inv = tensor._scale_inv
+    return fp8_meta, 0
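
The synthesized metadata relies on the invariant that a Float8Tensor's dequantization scale (_scale_inv) is the reciprocal of the quantization scale stored in FP8TensorMeta. A minimal numeric sketch of that round trip, ignoring FP8 rounding:

import torch

scale = torch.tensor(448.0)     # quantization scale, e.g. derived from an amax
scale_inv = scale.reciprocal()  # what the Float8Tensor stores for dequantization
x = torch.tensor(0.01)
x_quant = x * scale             # an FP8 cast would additionally round this
assert torch.allclose(x_quant * scale_inv, x)
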
@@ -83,6 +83,10 @@ class BasicLinear(BasicOperation):
         autograd. The weight's `main_grad` must be set externally and
         there is no guarantee that `grad` will be set or be
         meaningful.
+    userbuffers_options: dict, optional
+        Options for overlapping tensor-parallel communication with
+        compute using Userbuffers. This feature is highly
+        experimental.
 
     """
@@ -98,6 +102,7 @@ class BasicLinear(BasicOperation):
         sequence_parallel: bool = False,
         rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
         accumulate_into_main_grad: bool = False,
+        userbuffers_options: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__()
@@ -143,7 +148,7 @@ class BasicLinear(BasicOperation):
         )
 
         # Whether weight tensor is natively in FP8
-        self._with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
+        self._with_fp8_parameters: bool = FP8GlobalStateManager.with_fp8_parameters()
         if self._with_fp8_parameters:
             self._fp8_metas = self._make_fp8_metas()
@@ -163,7 +168,10 @@ class BasicLinear(BasicOperation):
             self.reset_parameters()
 
         # Whether to accumulate weight gradient into main_grad
-        self._accumulate_into_main_grad = accumulate_into_main_grad
+        self._accumulate_into_main_grad: bool = accumulate_into_main_grad
+
+        # Userbuffers options
+        self._userbuffers_options: Optional[dict[str, Any]] = userbuffers_options
 
     @classmethod
     def _canonicalize_tensor_parallelism(
@@ -707,7 +715,7 @@ class BasicLinear(BasicOperation):
             FP8 metadata for casting loss gradient w.r.t. output
             tensor to FP8. Required if output grad is not already in
             FP8.
-        grad_output_fp8_meta: dict, optional
+        grad_input_fp8_meta: dict, optional
             FP8 metadata for casting loss gradient w.r.t. input
             tensor to FP8
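
A usage sketch for the argument documented above: a column tensor-parallel BasicLinear with Userbuffers overlap enabled. The comm_name value must match a communicator registered through te.module.base.initialize_ub(); "qkv" follows the test script's choice, and the sizes are illustrative.

import torch
import transformer_engine.pytorch.ops as te_ops

linear = te_ops.BasicLinear(
    1024,
    1024,
    tensor_parallel_mode="column",
    tensor_parallel_group=torch.distributed.group.WORLD,
    sequence_parallel=True,
    userbuffers_options={"comm_name": "qkv"},  # experimental, per the docstring
)
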
@@ -16,3 +16,11 @@ from .forward_linear_bias_add import (
     ForwardLinearBiasAdd,
     fuse_forward_linear_bias_add,
 )
+from .userbuffers_backward_linear import (
+    UserbuffersBackwardLinear,
+    fuse_userbuffers_backward_linear,
+)
+from .userbuffers_forward_linear import (
+    UserbuffersForwardLinear,
+    fuse_userbuffers_forward_linear,
+)
@@ -20,6 +20,8 @@ from transformer_engine.pytorch.ops.fused import (
     fuse_backward_linear_add,
     fuse_forward_linear_bias_activation,
     fuse_forward_linear_bias_add,
+    fuse_userbuffers_backward_linear,
+    fuse_userbuffers_forward_linear,
 )
@@ -345,6 +347,7 @@ class OperationFuser:
         ops: list[tuple[FusibleOperation, list[int]]],
     ) -> list[tuple[FusibleOperation, list[int]]]:
         """Attempt to fuse operations in forward pass"""
+        ops = fuse_userbuffers_forward_linear(ops)
         ops = fuse_forward_linear_bias_add(ops)
         ops = fuse_forward_linear_bias_activation(ops)
         return ops
@@ -355,6 +358,7 @@ class OperationFuser:
         ops: list[tuple[FusibleOperation, list[int]]],
     ) -> list[tuple[FusibleOperation, list[int]]]:
         """Attempt to fuse operations in backward pass"""
+        ops = fuse_userbuffers_backward_linear(ops)
         ops = fuse_backward_linear_add(ops)
         return ops
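
The Userbuffers fusions run before the pre-existing fusions, so they get first claim on eligible linear ops. Schematically, a fusion pass rewrites the list of (op, basic_op_indices) pairs by collapsing matching windows; the sketch below is a simplified, hypothetical illustration of that pattern (matches and make_fused stand in for the real pattern checks, which inspect op types and Userbuffers options):

def fuse_pattern(ops, matches, make_fused):
    """Collapse adjacent (op, indices) pairs that match into a fused op."""
    out, i = [], 0
    while i < len(ops):
        if i + 1 < len(ops) and matches(ops[i][0], ops[i + 1][0]):
            fused = make_fused(ops[i][0], ops[i + 1][0])
            out.append((fused, ops[i][1] + ops[i + 1][1]))  # merge index lists
            i += 2
        else:
            out.append(ops[i])
            i += 1
    return out
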
@@ -109,10 +109,12 @@ class _ToFloat8Func(torch.autograd.Function):
         fp8_meta_forward: bool = True,
         fp8_meta_index: Optional[int] = None,
         fp8_dtype: TE_DType = TE_DType.kFloat8E4M3,
+        data: Optional[torch.Tensor] = None,
         scale: Optional[torch.Tensor] = None,
         amax: Optional[torch.Tensor] = None,
         scale_inv: Optional[torch.Tensor] = None,
         with_transpose_cache: bool = False,
+        data_transpose: Optional[torch.Tensor] = None,
     ) -> Float8Tensor:
         # pylint: disable=missing-function-docstring
@@ -125,7 +127,8 @@ class _ToFloat8Func(torch.autograd.Function):
             device = torch.device("cuda")
 
         # FP8 data buffer
-        data = torch.empty(tensor.size(), dtype=torch.uint8, device=device)
+        if data is None:
+            data = torch.empty(tensor.size(), dtype=torch.uint8, device=device)
 
         # Check scale
         if scale is None and fp8_meta is None:
@@ -140,8 +143,7 @@ class _ToFloat8Func(torch.autograd.Function):
             scale_inv = scale_inv.to(device=device, dtype=torch.float32)
 
         # Transpose cache
-        data_transpose = None
-        if with_transpose_cache:
+        if data_transpose is None and with_transpose_cache:
             data_transpose = torch.empty(
                 (data.size(-1), data.numel() // data.size(-1)),
                 dtype=torch.uint8,
@@ -172,7 +174,7 @@ class _ToFloat8Func(torch.autograd.Function):
     ) -> Tuple[Optional[torch.Tensor], ...]:
         # pylint: disable=missing-function-docstring
         # Assume that we want gradients in full precision
-        return grad, None, None, None, None, None, None, None
+        return grad, None, None, None, None, None, None, None, None, None
 
 
 class _IdentityFunc(torch.autograd.Function):
@@ -688,10 +690,12 @@ class Float8Tensor(QuantizedTensor):
         fp8_meta_forward: bool = True,
         fp8_meta_index: Optional[int] = None,
         fp8_dtype: TE_DType = TE_DType.kFloat8E4M3,
+        data: Optional[torch.Tensor] = None,
         scale: Optional[torch.Tensor] = None,
         amax: Optional[torch.Tensor] = None,
         scale_inv: Optional[torch.Tensor] = None,
         with_transpose_cache: bool = False,
+        data_transpose: Optional[torch.Tensor] = None,
     ):
         """Construct Float8Tensor from plain PyTorch tensor"""
         return _ToFloat8Func.apply(
@@ -700,10 +704,12 @@ class Float8Tensor(QuantizedTensor):
             fp8_meta_forward,
             fp8_meta_index,
             fp8_dtype,
+            data,
             scale,
             amax,
             scale_inv,
             with_transpose_cache,
+            data_transpose,
         )
 
     def detach(self) -> Float8Tensor:
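
With the new data and data_transpose arguments, a caller can quantize directly into preallocated buffers (e.g. Userbuffers-registered memory) instead of having to_float8 allocate fresh ones. A minimal usage sketch, assuming a CUDA device and default scaling factors:

import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

x = torch.randn(128, 256, device="cuda")
buf = torch.empty(128, 256, dtype=torch.uint8, device="cuda")    # preallocated FP8 data
buf_t = torch.empty(256, 128, dtype=torch.uint8, device="cuda")  # preallocated transpose cache
y = Float8Tensor.to_float8(
    x,
    data=buf,                  # reused instead of a fresh allocation
    with_transpose_cache=True,
    data_transpose=buf_t,      # reused for the transpose cache
)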