Unverified Commit a3df1d73 authored by Tim Moon, committed by GitHub

[PyTorch] Prototype for operation-based API (#707)



* Add basic infrastructure for Sequential module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add FP8 support in linear op

Runs, but needs validation. Runtime errors occur with non-FP8 params and FP8 compute, or with FP8 params and non-FP8 compute.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add reshape op and unit test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add unfused linear op

Test does not pass with FP8.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug unfused linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add test for linear+bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add separate abstract classes for unfused and fused ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Consolidate unfused ops in submodule
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add linear-bias fused op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use fused cast-transpose in linear ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable GEMM+bias fusion with FP32 activations

Not supported by cuBLAS.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add parallel unit test for unfused linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor parallel tests to reduce job launches
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add all-reduce, all-gather, and reduce-scatter ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused file
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug multi-GPU FP8 test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for FP8 scale updates

Still need to implement amax reductions.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add license boilerplate
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fuse GEMM+bias in row TP

Add documentation for unfused ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename pipeline to fuser

Expand documentation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak documentation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Preserve cached FP8 transpose between ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add option for fused wgrad accumulation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Directly output FP8 from linear if needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix cuDNN front-end commit
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use updated FP8 tensor API for transpose caching
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use updated API for FP8 scale updates
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tests for non-default FP8 recipes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename UnfusedOperation to BasicOperation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add unit test to check amax reduction with fusable op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Operator autograd state no longer needs to be initialized
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initial functional implementation of linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug fused linear+bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove autograd context from functional linear impl
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use functional linear impl in fused linear+bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename subdirectory from "fuser" to "ops"

Avoid confusion with kernel fusers and graph compilers.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update with Float8Tensor changes in #820
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unnecessary CPU overheads
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Correctly pass FP8 metadata from next op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter errors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add convenience functions to manipulate Sequential class
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update name of PyTorch extensions module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Clear saved tensor data in linear op after bprop
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix Pylint error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update name of PyTorch extensions module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix test name in QA script
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update name of PyTorch extensions module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Run distributed tests even when only 1 GPU is available
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Only run distributed tests with 2 GPUs if there are >=2 GPUs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review suggestions from @sudhakarsingh27 and @ksivaman

Fix spelling of "fusible". Avoid "input" name in internal APIs.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update transformer_engine/pytorch/ops/__init__.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 05977f44
@@ -22,3 +22,5 @@ pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
 pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
 pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
+pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
+pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
 
-from typing import Optional
+from typing import Iterable, Optional
 
 import pytest
 import torch
@@ -15,6 +15,8 @@ from transformer_engine.pytorch.fp8 import (
     _amax_and_scale_update,
     get_default_fp8_recipe,
 )
+import transformer_engine.pytorch.ops as te_ops
+import transformer_engine_torch as tex
 
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -33,7 +35,7 @@ class TestFP8Recipe:
     @pytest.mark.parametrize("amax_history_len", [31, 1024])
     @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
     @pytest.mark.parametrize("is_first_microbatch", [None, True, False])
-    def test_amax_and_scale_update(
+    def test_fp8_scale_update_with_linear_module(
         self,
         amax_history_len: int,
         amax_compute_algo: str,
@@ -49,7 +51,7 @@
             amax_history_len=amax_history_len,
             amax_compute_algo=amax_compute_algo,
         )
-        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+        with te.fp8_autocast(fp8_recipe=recipe):
             module = te.Linear(16, 16)
             y = module(
                 torch.randn([16, 16], device="cuda"),
@@ -162,6 +164,130 @@
             ref_scale_inv_backward[0],
         )
 
+    @pytest.mark.parametrize("amax_history_len", [31, 1024])
+    @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
+    def test_fp8_scale_update_with_linear_fuser_op(
+        self,
+        amax_history_len: int,
+        amax_compute_algo: str,
+        margin: float = 2,
+        num_steps: int = 4,
+        in_shape: tuple[int] = (16, 16),
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = "cuda",
+    ):
+
+        # Construct linear op
+        op = te_ops.BasicLinear(in_shape[-1], in_shape[-1])
+
+        # Get FP8 meta tensors
+        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
+        backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
+        x_fp8_meta = op.get_fp8_meta("input")[forward_key]
+        w_fp8_meta = op.get_fp8_meta("param")[forward_key]
+        dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key]
+
+        # Perform training steps
+        x_history = []
+        w_history = []
+        dy_history = []
+        for step in range(num_steps):
+
+            # Fill tensors with known values
+            x_history.append(step + 0.25)
+            w_history.append(step + 0.5)
+            dy_history.append(step + 0.75)
+            x = torch.full(
+                in_shape,
+                x_history[-1],
+                dtype=dtype,
+                device=device,
+                requires_grad=True,
+            )
+            dy = torch.full(
+                in_shape,
+                dy_history[-1],
+                dtype=dtype,
+                device=device,
+            )
+            with torch.no_grad():
+                op.weight.fill_(w_history[-1])
+
+            # Forward and backward pass
+            fp8_format = transformer_engine.common.recipe.Format.HYBRID
+            recipe = transformer_engine.common.recipe.DelayedScaling(
+                margin=margin,
+                interval=1,
+                fp8_format=fp8_format,
+                amax_history_len=amax_history_len,
+                amax_compute_algo=amax_compute_algo,
+            )
+            with te.fp8_autocast(fp8_recipe=recipe):
+                y = op(x)
+            y.backward(dy)
+
+            def check_amax_history(
+                fp8_meta: dict,
+                ref_amax_history: Iterable[float],
+            ) -> None:
+                """Check that amax history matches expected values"""
+                if len(ref_amax_history) > amax_history_len:
+                    ref_amax_history = ref_amax_history[-amax_history_len:]
+                ref_amax_history = torch.tensor(
+                    ref_amax_history,
+                    dtype=torch.float32,
+                    device=device,
+                )
+                test_amax_history = fp8_meta.amax_history[:, 0]
+                tols = dict(rtol=0, atol=0)
+                torch.testing.assert_close(
+                    test_amax_history[-(step + 1) :],
+                    ref_amax_history[: (step + 1)],
+                    **tols,
+                )
+
+            def check_scale(
+                fp8_meta: dict,
+                ref_amax_history: Iterable[float],
+                stage: str,
+            ):
+                """Check that scale and scale reciprocal match expected values"""
+
+                # Compute amax
+                if len(ref_amax_history) > amax_history_len:
+                    ref_amax_history = ref_amax_history[-(amax_history_len + 1) :]
+                if amax_compute_algo == "max":
+                    ref_amax = max(ref_amax_history)
+                elif amax_compute_algo == "most_recent":
+                    ref_amax = ref_amax_history[-1]
+                else:
+                    raise RuntimeError(f"{amax_compute_algo=} is not supported")
+
+                # Compute scale
+                max_val = {
+                    "forward": 448.0,
+                    "backward": 57344.0,
+                }[stage]
+                ref_scale = (max_val / ref_amax) / (2**margin)
+
+                # Check values in FP8 meta tensors
+                torch.testing.assert_close(
+                    fp8_meta.scale.item(),
+                    ref_scale,
+                )
+                torch.testing.assert_close(
+                    fp8_meta.scale_inv.item(),
+                    1 / ref_scale,
+                )
+
+            # Check that results match expected values
+            check_amax_history(x_fp8_meta, x_history)
+            check_amax_history(w_fp8_meta, w_history)
+            check_amax_history(dy_fp8_meta, dy_history)
+            check_scale(x_fp8_meta, x_history, "forward")
+            check_scale(w_fp8_meta, w_history, "forward")
+            check_scale(dy_fp8_meta, dy_history, "backward")
+
     @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
     @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
     @pytest.mark.parametrize(
@@ -191,7 +317,7 @@
 
         # Setup fp8_meta dictionary
         def setup_fp8_meta():
-            with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+            with te.fp8_autocast(fp8_recipe=recipe):
                 module = te.Linear(16, 16)
                 y = module(torch.zeros([16, 16], device="cuda"))
                 y.backward(torch.zeros_like(y))
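For reference, the delayed-scaling update that check_scale verifies reduces to a single formula. A minimal sketch (the function name and signature below are illustrative, not part of the Transformer Engine API):

def delayed_scaling_scale(amax_history, margin, amax_compute_algo, stage):
    """Recompute an FP8 scale from an amax history (illustrative only)."""
    # "max" reduces over the whole history; "most_recent" takes the last entry
    amax = max(amax_history) if amax_compute_algo == "max" else amax_history[-1]
    # HYBRID recipe: E4M3 max (448.0) forward, E5M2 max (57344.0) backward
    fp8_max = {"forward": 448.0, "backward": 57344.0}[stage]
    return (fp8_max / amax) / (2**margin)

# Example: margin=2 and a most-recent amax of 0.75 in the backward pass give
# (57344.0 / 0.75) / 4 ≈ 19114.67, matching the reference value checked above.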
@@ -3,9 +3,11 @@
 # See LICENSE for license information.
 
 """Methods needed for distributed training (DP/TP)."""
-import warnings
+from __future__ import annotations
+
 from contextlib import contextmanager, AbstractContextManager, ContextDecorator
-from typing import Any, Dict, Union, Optional, Callable, Tuple, List
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import warnings
 
 import torch
 from torch.cuda import _lazy_call, _lazy_init
@@ -829,23 +831,48 @@ def reduce_scatter_along_first_dim(
 
 
 def gather_along_first_dim(
-    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Gather tensors and concatinate along the first dimension."""
+    input_: torch.Tensor,
+    process_group: dist_group_type,
+    async_op: bool = False,
+) -> tuple[torch.Tensor, Any]:
+    """All-gather tensors and concatenate along first dimension."""
 
-    world_size = get_distributed_world_size(tp_group)
-    # Bypass the function if we are using only 1 GPU.
+    # Return immediately if no communication is required
+    world_size = get_distributed_world_size(process_group)
     if world_size == 1:
         return input_, None
 
-    dim_size = list(input_.size())
-    dim_size[0] = dim_size[0] * world_size
-
-    output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
-
+    # Allocate output tensor
+    output_shape = list(input_.size())
+    output_shape[0] *= world_size
+    if isinstance(input_, Float8Tensor):
+        output = Float8Tensor.make_like(
+            input_,
+            data=torch.empty(
+                output_shape,
+                dtype=torch.uint8,
+                device=input_.device,
+            ),
+        )
+        src = input_._data.contiguous()
+        dst = output._data
+    else:
+        output = torch.empty(
+            output_shape,
+            dtype=input_.dtype,
+            device=input_.device,
+            memory_format=torch.contiguous_format,
+        )
+        src = input_.contiguous()
+        dst = output
+
+    # Launch all-gather
     handle = torch.distributed.all_gather_into_tensor(
-        output, input_.contiguous(), group=tp_group, async_op=async_op
+        dst,
+        src,
+        group=process_group,
+        async_op=async_op,
     )
 
     return output, handle
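With the new signature, the helper also accepts Float8Tensor inputs, gathering the raw uint8 buffer and rewrapping it with the original FP8 metadata. A usage sketch, assuming torch.distributed is already initialized with a CUDA-capable backend:

import torch
import torch.distributed
from transformer_engine.pytorch.distributed import gather_along_first_dim

# Each rank holds a [4, 16] shard; the gathered result is [4 * world_size, 16]
shard = torch.randn(4, 16, device="cuda")
gathered, _ = gather_along_first_dim(shard, torch.distributed.group.WORLD)

# With async_op=True, the second return value is a work handle to wait on
gathered, handle = gather_along_first_dim(
    shard, torch.distributed.group.WORLD, async_op=True
)
handle.wait()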
@@ -563,6 +563,23 @@
             return _IdentityFunc.apply(self)
         return super().expand_as(other)
 
+    def contiguous(
+        self,
+        *,
+        memory_format: torch.memory_format = torch.contiguous_format,
+    ) -> Float8Tensor:
+        """Returns tensor with data in provided memory format
+
+        Returns `self` if data is already in correct memory format.
+
+        """
+        if self._data.is_contiguous(memory_format=memory_format):
+            return self
+        return _IdentityFunc.apply(
+            self,
+            {"data": self._data.detach().contiguous(memory_format=memory_format)},
+        )
+
     def transpose_2d(
         self,
         *,
@@ -885,6 +902,22 @@
                 fp8_attrs=args[0]._fp8_attrs,
             )
 
+        # View op
+        if func == aten.view.default:
+            tensor = args[0]
+            data = tensor._data
+            data_view = data.__torch_dispatch__(
+                func,
+                types,
+                [data] + list(args[1:]),
+                kwargs,
+            )
+            return Float8Tensor.make_like(
+                tensor,
+                data=data_view,
+                fp8_attrs=tensor._fp8_attrs,
+            )
+
         def maybe_unwrap(t):
             if isinstance(t, Float8Tensor):
                 return t.from_float8()
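Together with the new contiguous method, the aten.view dispatch keeps reshapes in FP8 rather than falling through to the dequantizing maybe_unwrap path below it. A rough behavioral sketch (assuming x is an existing Float8Tensor of shape [32, 16]; construction details are omitted):

y = x.view(512)     # handled by the dispatch above: y is a Float8Tensor
                    # sharing x's FP8 attributes and viewing the uint8 data
z = x.contiguous()  # returns x itself when the data is already contiguous
assert y._data.dtype == torch.uint8
assert y.size() == (512,)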
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operations.
This operation-based API is experimental and subject to change.
"""
from transformer_engine.pytorch.ops.basic import (
AllGather,
AllReduce,
BasicLinear,
Bias,
Identity,
ReduceScatter,
Reshape,
)
from transformer_engine.pytorch.ops.linear import Linear
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.sequential import Sequential
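For orientation, the exported classes compose like torch.nn modules, with Sequential handing the whole pipeline to the operation fuser. A minimal usage sketch, assuming Sequential mirrors torch.nn.Sequential's varargs constructor (the API is experimental, as the docstring warns):

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Build a pipeline of fusible ops; adjacent ops may be fused automatically
model = te_ops.Sequential(
    te_ops.BasicLinear(16, 16),
    te_ops.Bias(16),
    te_ops.Reshape((-1, 16)),
)
x = torch.randn(16, 16, device="cuda", requires_grad=True)
with te.fp8_autocast():
    y = model(x)
y.backward(torch.ones_like(y))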
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Helper functions used in fusible operations."""
from __future__ import annotations
from typing import Any, Iterable, Optional
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
def canonicalize_device(device: Optional[torch.device | str]) -> torch.device:
"""Canonicalize PyTorch device
If `None`, then returns the default CUDA device.
"""
if device is None:
# Use default CUDA device
device = torch.get_default_device()
if device.type != "cuda":
device = torch.device("cuda", torch.cuda.current_device())
elif not isinstance(device, torch.device):
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device("cuda", torch.cuda.current_device())
return device
def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
"""Canonicalize PyTorch datatype
If `None`, then returns the default PyTorch datatype.
"""
if dtype is None:
# Use default dtype
dtype = torch.get_default_dtype()
return dtype
def devices_match(device1: torch.device, device2: torch.device) -> bool:
"""Whether two devices are the same"""
device1 = torch.device(device1)
device2 = torch.device(device2)
if device1.type != device2.type:
return False
if device1.type == "cuda":
index1 = device1.index
index2 = device2.index
if index1 is None:
index1 = torch.cuda.current_device()
if index2 is None:
index2 = torch.cuda.current_device()
return index1 == index2
return device1 == device2
def is_float8_tensor(tensor: Any) -> bool:
"""Check if object is a `Float8Tensor`"""
return isinstance(tensor, Float8Tensor)
def convert_tensor(
tensor: torch.Tensor | Float8Tensor,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor | Float8Tensor:
"""Convert tensor attributes, keeping same data if possible"""
# Default kwargs
if device is None:
device = tensor.device
device = canonicalize_device(device)
if dtype is None:
dtype = tensor.dtype
dtype = canonicalize_dtype(dtype)
# Make sure output is detached from autograd graph
tensor = tensor.detach()
# Return immediately if tensor already has desired attributes
if devices_match(device, tensor.device) and dtype == tensor.dtype:
if memory_format == torch.preserve_format or tensor.is_contiguous(
memory_format=memory_format
):
return tensor
# Convert FP8 tensor
if is_float8_tensor(tensor):
data = tensor._data.to(device=device, memory_format=memory_format)
return Float8Tensor.make_like(
tensor,
data=data,
fp8_attrs=tensor._fp8_attrs,
dtype=dtype,
)
# Convert standard PyTorch tensor
return tensor.to(device=device, dtype=dtype, memory_format=memory_format)
def reshape(
tensor: torch.Tensor | Float8Tensor,
shape: Iterable[int],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor | Float8Tensor:
"""Reshape tensor, keeping same data if possible
If the input is a Float8Tensor, this function attempts to preserve
the cached transpose if available and valid. If a cached transpose
is present, it is interpreted as the transpose of a 2D matrix
where the width matches the innermost tensor dimension.
"""
# Make sure tensor is in expected format
tensor = convert_tensor(
tensor,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
# Return immediately if tensor already has desired shape
shape = list(shape)
if len(shape) == tensor.dim():
if sum(1 for d in shape if d == -1) > 1:
raise ValueError(
"Attempted to reshape tensor with "
f"shape={tuple(tensor.size())} into shape={tuple(shape)}"
)
if all(d1 == d2 for d1, d2 in zip(shape, tensor.size()) if d1 != -1):
return tensor
# Reshape FP8 tensor
# Note: Preserve cached transpose if possible
if is_float8_tensor(tensor):
out = Float8Tensor.make_like(
tensor,
data=tensor._data.view(shape),
fp8_attrs=tensor._fp8_attrs,
)
return out
# Reshape standard PyTorch tensor
return tensor.view(shape)
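A short illustration of the helper's contract (shapes chosen arbitrarily):

import torch
from transformer_engine.pytorch.ops._common import reshape

x = torch.randn(4, 8, 16, device="cuda")
y = reshape(x, (-1, 16))    # flattens leading dims: y.shape == (32, 16)
z = reshape(x, (4, 8, 16))  # shape already matches: no data copy is made
assert z.data_ptr() == x.data_ptr()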
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Single tensor operations supported by the operation fuser."""
from .all_gather import AllGather
from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
from .identity import Identity
from .reduce_scatter import ReduceScatter
from .reshape import Reshape
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for all-gather."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import convert_tensor, is_float8_tensor
class AllGather(BasicOperation):
"""All-gather tensor along outer dimension
Equivalent to gathering tensors from all processes and
concatenating along the first dimension.
Parameters
----------
process_group: torch.distributed.ProcessGroup, default = world group
Process group for communication
"""
def __init__(
self,
process_group: Optional[torch.distributed.ProcessGroup] = None,
) -> None:
super().__init__()
self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
self.process_group_size: int = torch.distributed.get_world_size(process_group)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Trivial case
if self.process_group_size == 1:
return input_
# Tensor dimensions
input_dims = input_.size()
if not input_dims:
raise RuntimeError(
"Attempted to all-gather a tensor "
f"with shape={list(input_dims)} "
f"over {self.process_group_size} processes"
)
output_dims = list(input_dims)
output_dims[0] *= self.process_group_size
# Perform all-gather
x = convert_tensor(input_, memory_format=torch.contiguous_format)
y = None
if is_float8_tensor(x):
y = Float8Tensor.make_like(
x,
data=torch.empty(
output_dims,
dtype=torch.uint8,
device=x.device,
),
)
torch.distributed.all_gather_into_tensor(
y._data,
x._data,
group=self.process_group,
)
else:
y = torch.empty(output_dims, dtype=x.dtype, device=x.device)
torch.distributed.all_gather_into_tensor(
y,
x,
group=self.process_group,
)
return y
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Trivial case
if self.process_group_size == 1:
return grad_output, ()
# Tensor dimensions
output_dims = grad_output.size()
if not output_dims or output_dims[0] % self.process_group_size != 0:
raise RuntimeError(
"Attempted to reduce-scatter a tensor "
f"with shape={list(output_dims)} "
f"over {self.process_group_size} processes"
)
input_dims = list(output_dims)
input_dims[0] //= self.process_group_size
# Check output gradient tensor
dy = grad_output
if is_float8_tensor(dy):
dy = dy.from_float8()
dy = dy.contiguous()
# Perform reduce-scatter
dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device)
torch.distributed.reduce_scatter_tensor(
dx,
dy,
group=self.process_group,
)
return dx, ()
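The forward/backward pairing above (all-gather in the forward pass, reduce-scatter in the backward pass) is what makes the op safe to differentiate through. A hedged usage sketch, assuming torch.distributed is already initialized:

import torch
import transformer_engine.pytorch.ops as te_ops

op = te_ops.AllGather()           # defaults to the world process group
x = torch.randn(4, 16, device="cuda", requires_grad=True)
y = op(x)                         # shape [4 * world_size, 16]
y.backward(torch.ones_like(y))    # gradient is reduce-scattered back
assert x.grad.shape == x.shape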
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for all-reduce."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import is_float8_tensor
class AllReduce(BasicOperation):
"""All-reduce tensor
Equivalent to summing tensors from all processes. It is assumed
that the output is used in operations that are redundantly
computed on all processes, and hence that gradients are identical
between processes.
Parameters
----------
process_group: torch.distributed.ProcessGroup, default = world group
Process group for communication
"""
def __init__(
self,
process_group: Optional[torch.distributed.ProcessGroup] = None,
reduce_in_backward: bool = True,
) -> None:
super().__init__()
self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
self._reduce_in_backward: bool = reduce_in_backward
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Trivial case
if torch.distributed.get_world_size(self.process_group) == 1:
return input_
# Perform all-reduce
x = input_
if is_float8_tensor(x):
x = x.from_float8()
x = x.contiguous()
torch.distributed.all_reduce(x, group=self.process_group)
return x
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
return grad_output, ()
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for bias."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import (
canonicalize_device,
canonicalize_dtype,
)
class Bias(BasicOperation):
"""Apply additive bias
This is equivalent to the additive bias in `torch.nn.Linear`.
Parameters
----------
size: int
Inner dimension of input tensor
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
tensor_parallel: bool, default = `False`
Whether to distribute input tensor and bias tensors along
inner dimension
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
"""
def __init__(
self,
size: int,
*,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
tensor_parallel: bool = False,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> None:
super().__init__()
# Bias size
self._size = size
# Bias tensor device
defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
self.device: torch.device = device
# Bias tensor datatype
self.dtype: torch.dtype = canonicalize_dtype(dtype)
# Tensor parallel configuration
tensor_parallel_size = 1
local_size = size
if tensor_parallel:
tensor_parallel_size = torch.distributed.get_world_size(tensor_parallel_group)
tensor_parallel = tensor_parallel_size > 1
if size % tensor_parallel_size != 0:
raise ValueError(
"Invalid configuration for tensor parallelism "
f"({size=}, {tensor_parallel_size=})"
)
local_size //= tensor_parallel_size
else:
tensor_parallel_group = None
self.tensor_parallel: bool = tensor_parallel
self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = tensor_parallel_group
self.tensor_parallel_size: int = tensor_parallel_size
self.local_size: int = local_size
# Initialize parameters if needed
bias = torch.empty(
local_size,
device="meta",
dtype=dtype,
)
bias = torch.nn.Parameter(bias)
self.bias: torch.nn.Parameter
self.register_parameter("bias", bias)
if not defer_param_init:
self.reset_parameters()
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""
# Make sure parameter is initialized
bias = self.bias
if bias.device.type != "cuda":
bias = torch.empty_like(bias, device=self.device)
bias = bias.to(device=self.device, dtype=self.dtype)
# Initialize values
bias.zero_()
# Save updated parameter
if not isinstance(bias, torch.nn.Parameter):
bias = torch.nn.Parameter(bias)
self.bias = bias
def pre_forward(self) -> None:
super().pre_forward()
if self.bias.device.type == "meta":
self.reset_parameters()
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
x = input_
b = self.bias.reshape([1] * (x.dim() - 1) + [self.local_size])
return x + b
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
dy = grad_output
if dy.dim() > 1:
db = dy.sum(tuple(range(dy.dim() - 1)))
else:
db = dy
return dy, (db,)
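The reshape in op_forward broadcasts the bias over every leading dimension, and op_backward reduces the output gradient over those same dimensions. The equivalent plain-PyTorch computation, for reference:

import torch

x = torch.randn(2, 3, 16)        # arbitrary leading dims, inner dim 16
b = torch.randn(16)
y = x + b.reshape(1, 1, 16)      # what op_forward computes
dy = torch.ones_like(y)
db = dy.sum(dim=(0, 1))          # what op_backward returns as the bias grad
assert db.shape == b.shape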
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for identity."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
class Identity(BasicOperation):
"""Return input tensor"""
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
return input_
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
return grad_output, ()
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for reduce-scatter."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import convert_tensor, is_float8_tensor
class ReduceScatter(BasicOperation):
"""Reduce-scatter tensor along outer dimension
Equivalent to summing tensors from all processes and splitting
along the first dimension.
Parameters
----------
process_group: torch.distributed.ProcessGroup, default = world group
Process group for communication
"""
def __init__(
self,
process_group: Optional[torch.distributed.ProcessGroup] = None,
) -> None:
super().__init__()
self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
self.process_group_size: int = torch.distributed.get_world_size(process_group)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Trivial case
if self.process_group_size == 1:
return input_
# Tensor dimensions
input_dims = input_.size()
if not input_dims or input_dims[0] % self.process_group_size != 0:
raise RuntimeError(
"Attempted to reduce-scatter a tensor "
f"with shape={list(input_dims)} "
f"over {self.process_group_size} processes"
)
output_dims = list(input_dims)
output_dims[0] //= self.process_group_size
# Check input tensor
x = input_
if is_float8_tensor(x):
x = x.from_float8()
x = x.contiguous()
# Perform reduce-scatter
y = torch.empty(output_dims, dtype=x.dtype, device=x.device)
torch.distributed.reduce_scatter_tensor(y, x, group=self.process_group)
return y
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Trivial case
if self.process_group_size == 1:
return grad_output, ()
# Tensor dimensions
output_dims = grad_output.size()
if not output_dims:
raise RuntimeError(
"Attempted to all-gather a tensor "
f"with shape={list(output_dims)} "
f"over {self.process_group_size} processes"
)
input_dims = list(output_dims)
input_dims[0] *= self.process_group_size
# Perform all-gather
dy = convert_tensor(grad_output, memory_format=torch.contiguous_format)
dx = None
if is_float8_tensor(dy):
dx = Float8Tensor.make_like(
dy,
data=torch.empty(
input_dims,
dtype=torch.uint8,
device=dy.device,
),
)
torch.distributed.all_gather_into_tensor(
dx._data,
dy._data,
group=self.process_group,
)
else:
dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device)
torch.distributed.all_gather_into_tensor(
dx,
dy,
group=self.process_group,
)
return dx, ()
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for reshape."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import reshape
class Reshape(BasicOperation):
"""Reshape tensor
See `torch.reshape`.
Parameters
----------
shape: iterable of int
Output tensor dimensions. If one dimension is -1, it is
inferred based on input tensor dimensions.
"""
def __init__(self, shape: Iterable[int]) -> None:
super().__init__()
self._shape = tuple(shape)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
ctx.input_shape = input_.size()
return reshape(input_, self._shape)
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
return reshape(grad_output, ctx.input_shape), ()
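A small sketch of the -1 inference described in the docstring; the backward pass restores the input shape saved in op_forward:

import torch
import transformer_engine.pytorch.ops as te_ops

op = te_ops.Reshape((-1, 32))
x = torch.randn(4, 8, 32, device="cuda", requires_grad=True)
y = op(x)                        # shape [32, 32]; -1 inferred from input
y.backward(torch.ones_like(y))
assert x.grad.shape == x.shape   # gradient reshaped back to [4, 8, 32]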
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Compound tensor operation supported by the operation fuser."""
from .linear_bias_activation import (
ForwardLinearBiasActivation,
fuse_forward_linear_bias_activation,
)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused operation for GEMM, bias, activation in the forward pass."""
from __future__ import annotations
from typing import Any, Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusedOperation,
FusibleOperation,
OperationContext,
)
class ForwardLinearBiasActivation(FusedOperation):
"""Fused GEMM, bias, activation in the forward pass
Bias and activation are both optional. Row tensor parallelism is
not supported since that requires communication immediately after
the GEMM.
"""
def __init__(
self,
*,
linear: BasicLinear,
bias: Optional[Bias],
activation: None,
) -> None:
# Basic operations that comprise this fused operation
op_idxs = dict(
linear=0,
bias=None,
activation=None,
)
ops = [linear]
if bias is not None:
op_idxs["bias"] = len(ops)
ops.append(bias)
if activation is not None:
op_idxs["activation"] = len(ops)
ops.append(activation)
# Initialize base class
super().__init__(ops)
# Index of each basic operation
self._op_idxs: dict[str, Optional[int]] = op_idxs
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]],
) -> torch.Tensor:
# Get basic operations
idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[idx]
if self._op_idxs["bias"] is None:
bias_op = None
bias = None
else:
idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx]
bias = bias_op.bias
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
if self._op_idxs["activation"] is None:
activation_op = None # pylint: disable=unused-variable
else:
raise NotImplementedError("Activations are not yet supported")
# FP8 metadata
with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
input_fp8_meta = None
weight_fp8_meta = None
output_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
if with_fp8_compute:
input_fp8_meta = linear_op.get_fp8_meta("input")
weight_fp8_meta = linear_op.get_fp8_meta("param")
next_op = basic_op_next_ops[-1]
if next_op is not None and next_op.num_fp8_scales("input") > 0:
output_fp8_meta = next_op.get_fp8_meta("input")
grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output")
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")
# Linear forward
output, x_local, _ = BasicLinear._functional_forward(
input=input_,
weight=linear_op.weight,
bias=bias,
device=linear_op.device,
dtype=linear_op.dtype,
tensor_parallel_mode=linear_op.tensor_parallel_mode,
tensor_parallel_group=linear_op.tensor_parallel_group,
sequence_parallel=linear_op.sequence_parallel,
with_fp8_compute=with_fp8_compute,
input_fp8_meta=input_fp8_meta,
weight_fp8_meta=weight_fp8_meta,
output_fp8_meta=output_fp8_meta,
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.with_fp8_compute = with_fp8_compute
linear_op_ctx.weight_fp8_meta = weight_fp8_meta
linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output
def fuse_forward_linear_bias_activation(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse GEMM, bias, activation in the forward pass
Parameters
----------
ops: list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op1, _ = window[0]
if not isinstance(op1, BasicLinear):
continue
if op1.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
if op1.dtype not in (torch.float16, torch.bfloat16):
# cuBLAS only supports fused GEMM+bias+activation with
# FP16 and BF16 output
continue
# Check if second op is bias
op2, _ = ops[0]
if not isinstance(op2, Bias):
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearBiasActivation(
linear=window[0][0],
bias=window[1][0],
activation=None,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
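In effect, the pass slides a two-op window over the list and replaces an adjacent BasicLinear + Bias pair with a single ForwardLinearBiasActivation, but only when the dtype and tensor-parallel checks allow the cuBLAS epilogue. A sketch of the transformation (constructor arguments chosen to satisfy the BF16 check):

import torch
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused_forward import (
    fuse_forward_linear_bias_activation,
)

ops = [
    (te_ops.BasicLinear(16, 16, dtype=torch.bfloat16), [0]),
    (te_ops.Bias(16, dtype=torch.bfloat16), [1]),
]
fused = fuse_forward_linear_bias_activation(ops)
# fused is now [(ForwardLinearBiasActivation, [0, 1])]: one fused op
# covering basic ops 0 and 1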
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Manager class for a pipeline of fusible operations."""
from __future__ import annotations
from typing import Any, Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.graph import is_graph_capturing
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusibleOperation,
OperationContext,
)
from transformer_engine.pytorch.ops.fused_forward import (
fuse_forward_linear_bias_activation,
)
class _OperationFuserAutogradFunction(torch.autograd.Function):
"""Autograd function for a pipeline of operations
Autograd must be done at the pipeline level since we may apply
different fusions in the forward and backward passes.
"""
# pylint: disable=unused-argument
@staticmethod
def forward(
func_ctx: torch.autograd.function.FunctionCtx,
input_: torch.Tensor,
forward_ops: list[tuple[FusibleOperation, list[int]]],
backward_ops: list[tuple[FusibleOperation, list[int]]],
basic_ops: list[BasicOperation],
basic_op_kwargs: list[dict[str, Any]],
*params: torch.nn.Parameter,
) -> torch.Tensor:
"""Forward pass
Parameters
----------
func_ctx: torch.autograd.function.FunctionCtx
Context for PyTorch autograd function
input_: torch.Tensor
Input to first operation in pipeline
forward_ops: list of tuple
Forward pass operations and the indices of the
corresponding basic operations. The order should match
basic_ops.
backward_ops: list of tuple
Backward pass operations and the indices of the
corresponding basic operations. The order should be the
reverse of basic_ops.
basic_ops: list of BasicOperation
Basic operations
basic_op_kwargs: list of dict
Keyword arguments to BasicOperation
*params: torch.nn.Parameter
Parameters in operation pipeline
"""
# Operation autograd contexts
basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))]
# Apply forward ops
x = input_
requires_grad = x.requires_grad
for op, basic_op_idxs in forward_ops:
# Forward op
prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
next_ops = [
basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs
]
x = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
x,
prev_ops,
next_ops,
[basic_op_kwargs[idx] for idx in basic_op_idxs],
)
# Check if backward op is required
if not requires_grad:
requires_grad = any(param.requires_grad for param in op.parameters())
for idx in basic_op_idxs:
basic_op_ctxs[idx]._requires_grad = requires_grad
x.requires_grad_(requires_grad=requires_grad)
# Flatten list of saved tensors
to_save = []
for ctx in basic_op_ctxs:
range_start = len(to_save)
if ctx.to_save is not None:
to_save.extend(ctx.to_save)
range_end = len(to_save)
ctx.to_save = None
ctx._saved_tensors_range = (range_start, range_end)
func_ctx.save_for_backward(*to_save)
# Other context for backward pass
func_ctx.backward_ops = backward_ops
func_ctx.basic_ops = basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
return x
@staticmethod
@torch.autograd.function.once_differentiable
def backward(
func_ctx: Any,
grad_output: torch.Tensor,
) -> tuple[Optional[torch.Tensor], ...]:
"""Backward pass"""
# Operations and autograd state
backward_ops = func_ctx.backward_ops
basic_ops = func_ctx.basic_ops
basic_op_ctxs = func_ctx.basic_op_ctxs
# Unflatten list of saved tensors
saved_tensors = func_ctx.saved_tensors
for ctx in basic_op_ctxs:
ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
ctx._saved_tensors_range = None
del saved_tensors
# Apply backward ops
dx = grad_output
grad_params = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in backward_ops:
# Stop if no more gradients are required
if all(not basic_op_ctxs[idx]._requires_grad for idx in basic_op_idxs):
dx = None
break
# Backward op
dx, fused_op_dparams = op.fuser_backward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
dx,
)
for idx, basic_op_dparams in zip(basic_op_idxs, fused_op_dparams):
grad_params[idx] = basic_op_dparams
basic_op_ctxs[idx].saved_tensors = None
# Flatten list of parameter gradients
grad_params_flat = []
for idx, dparams in enumerate(grad_params):
params = list(basic_ops[idx].parameters())
if dparams is None:
dparams = [None for _ in range(len(params))]
else:
dparams = list(dparams)
if len(dparams) != len(params):
raise RuntimeError(
f"Expected op {idx} to generate {len(params)} param grads, "
f"but got {len(dparams)}"
)
grad_params_flat.extend(dparams)
# Update FP8 scaling factors
if func_ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return (
dx, # input_
None, # forward_ops
None, # backward_ops
None, # basic_ops
None, # basic_op_kwargs
*grad_params_flat, # params
)
class OperationFuser:
"""Manages forward and backward passes for a pipeline of operations
Parameters
----------
ops: list of FusibleOperation
Pipeline of operations
fuse_ops: bool, default = `True`
Whether to attempt fusing operations
"""
def __init__(
self,
ops: list[FusibleOperation],
fuse_ops: bool = True,
) -> None:
# Get list of basic operations
basic_ops = []
for op in ops:
if op.is_fused_op:
basic_ops.extend(op.basic_ops)
else:
basic_ops.append(op)
self._num_basic_ops: int = len(basic_ops)
self._basic_ops: list[BasicOperation] = basic_ops
# Ops for forward and backward pass
self._forward_ops: list[tuple[FusibleOperation, list[int]]]
self._backward_ops: list[tuple[FusibleOperation, list[int]]]
self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
self._backward_ops = list(reversed(self._forward_ops))
# Fuse ops if needed
if fuse_ops:
self.fuse_ops()
@classmethod
def _fuse_forward_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in forward pass"""
ops = fuse_forward_linear_bias_activation(ops)
return ops
@classmethod
def _fuse_backward_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in backward pass"""
return ops
def fuse_ops(self) -> None:
"""Attempt to fuse operations"""
self._forward_ops = self._fuse_forward_ops(self._forward_ops)
self._backward_ops = self._fuse_backward_ops(self._backward_ops)
def __call__(
self,
input: torch.Tensor, # pylint: disable=redefined-builtin
basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor:
# Initialization before forward pass
for op in self._basic_ops:
op.pre_forward()
# Canonicalize op kwargs
if basic_op_kwargs is None:
basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]
# Flatten list of parameters
params = []
for op in self._basic_ops:
params.extend(op.parameters())
# Fuser forward pass
return _OperationFuserAutogradFunction.apply(
input,
self._forward_ops,
self._backward_ops,
self._basic_ops,
basic_op_kwargs,
*params,
)
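Sequential is the intended user-facing entry point, but the fuser can also be driven directly. A minimal sketch (the module path for OperationFuser is assumed from the imports above, not confirmed by this diff):

import torch
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fuser import OperationFuser

fuser = OperationFuser([te_ops.BasicLinear(16, 16), te_ops.Bias(16)])
x = torch.randn(16, 16, device="cuda", requires_grad=True)
y = fuser(x)                     # runs the (possibly fused) forward pipeline
y.backward(torch.ones_like(y))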
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for linear layer."""
from __future__ import annotations
from collections.abc import Callable
from typing import Optional
import torch
from transformer_engine.pytorch.ops.basic import (
AllReduce,
BasicLinear,
Bias,
ReduceScatter,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.ops.op import FusedOperation
class Linear(FusedOperation):
"""Apply linear transformation: :math:`y = x A^T + b`
This is a drop-in replacement for `torch.nn.Linear`.
Parameters
----------
in_features: int
Inner dimension of input tensor
out_features: int
Inner dimension of output tensor
bias: bool, default = `True`
Apply additive bias
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors along
outer dimension (sequence or batch dim) when not distributing
along inner dimension (embedding dim)
rng_state_tracker_function: callable
Function that returns CudaRNGStatesTracker, which is used for
model-parallel weight initialization
accumulate_into_main_grad: bool, default = `False`
Whether to directly accumulate weight gradients into the
weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful.
"""
def __init__(
self,
in_features: int,
out_features: int,
*,
bias: bool = True,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
sequence_parallel: bool = False,
rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
accumulate_into_main_grad: bool = False,
) -> None:
# Tensor parallel configuration
(
tensor_parallel_mode,
tensor_parallel_group,
tensor_parallel_size,
sequence_parallel,
local_in_features,
local_out_features,
) = BasicLinear._canonicalize_tensor_parallelism(
mode=tensor_parallel_mode,
process_group=tensor_parallel_group,
sequence_parallel=sequence_parallel,
in_features=in_features,
out_features=out_features,
)
# Construct basic ops
ops = []
linear_kwargs = dict(
in_features=in_features,
out_features=out_features,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=tensor_parallel_group,
sequence_parallel=sequence_parallel,
rng_state_tracker_function=rng_state_tracker_function,
accumulate_into_main_grad=accumulate_into_main_grad,
)
bias_kwargs = dict(
size=out_features,
device=device,
dtype=dtype,
tensor_parallel=(tensor_parallel_mode is not None),
tensor_parallel_group=tensor_parallel_group,
)
if tensor_parallel_mode == "row":
# Row TP: GEMM + bias + reduction
linear_kwargs["in_features"] = local_in_features
linear_kwargs["out_features"] = local_out_features
linear_kwargs["tensor_parallel_mode"] = None
linear_kwargs["tensor_parallel_group"] = None
linear_kwargs["sequence_parallel"] = False
bias_kwargs["size"] *= tensor_parallel_size
ops.append(BasicLinear(**linear_kwargs))
if bias:
ops.append(Bias(**bias_kwargs))
if sequence_parallel:
ops.append(ReduceScatter(tensor_parallel_group))
else:
ops.append(AllReduce(tensor_parallel_group))
else:
# Column TP or no TP: (gather + GEMM) + bias
ops.append(BasicLinear(**linear_kwargs))
if bias:
ops.append(Bias(**bias_kwargs))
# Initialize base class
super().__init__(ops)
# Register parameters
self.register_parameter("weight", self.basic_ops[0].weight)
self.register_parameter("bias", self.basic_ops[1].bias if bias else None)