Commit ab3e5a92 authored by yuguo's avatar yuguo


Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
......@@ -109,6 +109,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
}
const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
opts = opts.dtype(torch::kFloat32);
// TODO: Replace with an empty tensor.
at::Tensor scale_inv = at::reciprocal(scale);
py::object ret;
if (internal) {
......@@ -250,6 +251,140 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
this->set_quantization_params(&tensor);
return {std::move(tensor), std::move(ret)};
}
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>();
this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast<int>();
this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
"Unsupported block scaling dim.");
}
void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
// Change the rowwise and columnwise_data to the configured dtype.
// May be a switch between E5M2 and E4M3.
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
using namespace pybind11::literals;
std::vector<int64_t> torch_shape;
size_t numel = 1;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
numel *= s;
}
TensorWrapper tensor(this->get_scaling_mode());
at::TensorOptions opts;
at::TensorOptions scale_opts;
at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data_rowwise = std::move(*rowwise_data);
} else {
data_rowwise = at::empty(torch_shape, opts);
}
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup(m_dim, 4);
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor rowwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_inv_rowwise =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32,
std::vector<size_t>{sinv0, sinv1});
}
if (columnwise_usage) {
std::vector<int64_t> torch_columnwise_shape;
std::vector<size_t> columnwise_shape;
NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
columnwise_shape, " torch shape: ", torch_columnwise_shape);
if (torch_shape.size() > 0) {
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]);
for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
}
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup(k_dim, 4);
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor columnwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
data_colwise = at::empty(torch_columnwise_shape, opts);
scale_inv_colwise =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape);
tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32,
std::vector<size_t>{sinv0, sinv1});
}
this->set_quantization_params(&tensor);
py::object ret;
if (internal) {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
ret = Float8BlockwiseQTensorClass(
"rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
"rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
"is_2D_scaled"_a = (block_scaling_dim == 2));
} else {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
ret = Float8BlockwiseQTensorClass(
"shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
"columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
"columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
}
return {std::move(tensor), std::move(ret)};
}
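The sinv0/sinv1 arithmetic above follows from the 128-element block length plus rounding one dimension up to a multiple of 4. A rough Python sketch of the rowwise case (illustrative only, assuming the conventions in the code above; not part of this commit):

def rowwise_scale_inv_shape(m_dim, k_dim, block_scaling_dim, block_len=128, align=4):
    ceil_div = lambda a, b: (a + b - 1) // b
    roundup = lambda a, b: ceil_div(a, b) * b
    if block_scaling_dim == 2:  # 128x128 blocks
        return ceil_div(m_dim, block_len), roundup(ceil_div(k_dim, block_len), align)
    if block_scaling_dim == 1:  # 1x128 blocks along the last dimension
        return ceil_div(k_dim, block_len), roundup(m_dim, align)
    raise ValueError("block_scaling_dim must be 1 or 2")

# e.g. a 256x512 tensor: 2D scaling -> (2, 4); 1D scaling -> (4, 256)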
......@@ -302,7 +437,8 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
tensor.set_rowwise_scale_inv(
rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
}
......@@ -313,7 +449,8 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
tensor.set_columnwise_scale_inv(
columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
}
this->set_quantization_params(&tensor);
......
......@@ -6,14 +6,16 @@
#include "extensions.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h"
void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) {
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
bool rowwise) {
using namespace transformer_engine::pytorch;
if (input.scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) {
return;
return std::nullopt;
}
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
......@@ -48,9 +50,9 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww
output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
} else {
input_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
input_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
output_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
output_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0,
scale_inv_shape);
}
......@@ -63,6 +65,8 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
}
return swizzled_scale_inv;
}
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
......
......@@ -6,27 +6,38 @@
#include <optional>
#include "ATen/core/TensorBody.h"
#include "extensions.h"
#include "pybind.h"
std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
std::optional<std::vector<py::handle>> output_list,
namespace transformer_engine::pytorch {
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list,
transformer_engine::DType otype) {
using namespace transformer_engine::pytorch;
init_extension();
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
std::vector<py::object> py_output_objects_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto none = py::none();
if (output_list.has_value()) {
py_output_objects_list = output_list.value();
}
// Choose implementation
// Note: Currently only have fused kernel for FP8 cast-transpose
bool with_fused_kernel = true;
// create TE tensors from input
for (int i = 0; i < input_list.size(); i++) {
auto input_tensor = makeTransformerEngineTensor(input_list[i], none);
for (size_t i = 0; i < input_list.size(); i++) {
auto input_tensor = makeTransformerEngineTensor(input_list[i]);
const NVTEShape input_shape = input_tensor.shape();
transformer_engine::TensorWrapper output_tensor;
if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
with_fused_kernel = false;
}
if (output_list == std::nullopt) {
std::unique_ptr<Quantizer> quantizer = convert_quantizer(quantizer_list[i]);
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
......@@ -48,16 +59,8 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(),
"Number of input and output tensors must match");
// Choose implementation
// Note: Currently only have fused kernel for FP8 cast-transpose
bool with_fused_kernel = true;
for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
const auto& tensor = nvte_tensor_output_list[i];
if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING) {
with_fused_kernel = false;
break;
}
if (nvte_tensor_columnwise_data(tensor) == nullptr) {
if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) {
with_fused_kernel = false;
break;
}
......@@ -68,9 +71,8 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
} else {
for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i],
at::cuda::getCurrentCUDAStream());
for (size_t i = 0; i < py_output_objects_list.size(); i++) {
quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
}
}
return py_output_objects_list;
......@@ -78,7 +80,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
std::optional<at::Tensor> output) {
using namespace transformer_engine::pytorch;
init_extension();
const auto dim = input.dim();
NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose.");
......@@ -105,3 +107,5 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
return out;
}
} // namespace transformer_engine::pytorch
......@@ -84,6 +84,38 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
return ret;
}
TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) {
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
if (rowwise_usage) {
const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr();
const auto &rowwise_shape = getTensorShape(data_rowwise);
ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape);
const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape);
}
if (columnwise_usage) {
const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr();
const auto &shape = getTensorShape(data_colwise);
ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape);
const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
}
quantizer->set_quantization_params(&ret);
return ret;
}
} // namespace detail
} // namespace transformer_engine::pytorch
......@@ -25,6 +25,9 @@ extern PyTypeObject *Float8CurrentScalingQuantizerClass;
extern PyTypeObject *MXFP8TensorPythonClass;
extern PyTypeObject *MXFP8TensorBasePythonClass;
extern PyTypeObject *MXFP8QuantizerClass;
extern PyTypeObject *Float8BlockwiseQTensorPythonClass;
extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass;
extern PyTypeObject *Float8BlockwiseQuantizerClass;
void init_extension();
......@@ -50,6 +53,15 @@ inline bool IsMXFP8Tensor(PyObject *obj) {
return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass;
}
inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) {
return Py_TYPE(obj) == Float8BlockwiseQuantizerClass;
}
inline bool IsFloat8BlockwiseQTensor(PyObject *obj) {
return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass ||
Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass;
}
TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer);
template <typename T>
......@@ -61,6 +73,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizati
std::unique_ptr<Quantizer> CreateMXFP8Params(const py::handle params);
TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor,
Quantizer *quantization_params);
inline bool IsFloatingPointType(at::ScalarType type) {
return type == at::kFloat || type == at::kHalf || type == at::kBFloat16;
}
......@@ -71,7 +86,9 @@ constexpr std::array custom_types_converters = {
std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor,
CreateQuantizer<Float8CurrentScalingQuantizer>),
std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
CreateQuantizer<MXFP8Quantizer>)};
CreateQuantizer<MXFP8Quantizer>),
std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers,
NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>)};
} // namespace detail
......
......@@ -7,6 +7,19 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#include <torch/extension.h>
#include <optional>
#include "transformer_engine/transformer_engine.h"
bool non_tn_fp8_gemm_supported();
/* Swizzle the scaling factor of the input tensor.
*
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
*/
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
bool trans);
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
......@@ -19,15 +19,24 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
from .utils import safely_set_viewless_tensor_data
from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
try:
import torch.distributed._symmetric_memory as symm_mem
HAS_TORCH_SYMMETRIC = True
except ImportError:
HAS_TORCH_SYMMETRIC = False
__all__ = ["checkpoint", "CudaRNGStatesTracker"]
......@@ -660,6 +669,9 @@ def checkpoint(
**kwargs,
)
from .module.base import TransformerEngineBaseModule
if isinstance(function, TransformerEngineBaseModule):
# If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
# to scatter/gather activations that we will recompute anyway.
setattr(function, "fsdp_wrapped", False)
......@@ -860,23 +872,29 @@ def _all_gather_fp8(
process_group: dist_group_type,
*,
async_op: bool = False,
quantizer: Optional[Float8Quantizer] = None,
quantizer: Optional[Quantizer] = None,
out_shape: Optional[list[int]] = None,
) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]:
"""All-gather FP8 tensor along first dimension."""
world_size = get_distributed_world_size(process_group)
# Check that quantizer is valid
if quantizer is not None and not isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
raise ValueError(f"Got non-FP8 quantizer ({quantizer.__class__.__name__})")
# Output tensor dims
if out_shape is None:
out_shape = list(inp.size())
out_shape[0] *= world_size
# Quantize input tensor if needed
# Cast input tensor to FP8 if needed
# Note: We cannot directly all-gather the transposed FP8 tensor,
# so temporarily modify quantizer to avoid creating FP8 transpose.
if not isinstance(inp, Float8TensorBase):
assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
# we cannot directly gather the transposed fp8 tensor
# so we need to disable columnwise usage for the quantizer
# and then set it back to the original value after quantizing
if quantizer is None:
raise ValueError("Input tensor is not FP8 and no quantizer was provided")
init_rowwise_usage = quantizer.rowwise_usage
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -888,7 +906,7 @@ def _all_gather_fp8(
# Construct output tensor
out: Float8TensorBase
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if quantizer is not None:
dtype = torch.float32
device = "cuda"
if isinstance(inp, Float8Tensor):
......@@ -906,9 +924,8 @@ def _all_gather_fp8(
out._transpose_invalid = True
else:
raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")
# For delayed scaling, scale_inv is from history, so we can pass it from inp to out
# For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same scale_inv,
# so we can just pass it from inp to out
# Assume scaling factors are identical across ranks
out._scale_inv = inp._scale_inv
# Perform communication
......@@ -920,17 +937,86 @@ def _all_gather_fp8(
)
# Make sure FP8 transpose is populated if needed
if out._transpose is not None:
needs_transpose = (
quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported()
)
if needs_transpose:
if handle is not None:
handle.wait()
handle = None
if not isinstance(out, Float8Tensor):
raise RuntimeError("FP8TensorBase does not support FP8 transpose yet")
out._create_transpose()
return out, handle
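The comments above describe a usage-flag dance that the diff only partially shows. A hypothetical Python sketch of the overall pattern (not the elided lines of this commit; names mirror the surrounding code for illustration only):

# Quantize row-wise only, gather, then rebuild the transpose afterwards if needed.
init_rowwise = quantizer.rowwise_usage
init_columnwise = quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=False)
inp_fp8 = quantizer(inp)  # FP8 tensor without a transpose
quantizer.set_usage(rowwise=init_rowwise, columnwise=init_columnwise)
# ... all-gather the row-wise FP8 data into `out` ...
if quantizer.columnwise_usage and not non_tn_fp8_gemm_supported():
    out._create_transpose()  # populate the transpose only after the gather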
def _all_gather_fp8_blockwise(
inp: torch.Tensor,
process_group: dist_group_type,
*,
async_op: bool = False, # pylint: disable=unused-argument
quantizer: Optional[Quantizer] = None,
out_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
"""
All-gather FP8 tensor along first dimension for blockwise quantization.
Returns: quantizer(gather(inp))
NOTE: The implementation is not sophisticated enough to honor async_op=True.
In some cases it falls back to synchronous gather and invokes the quantizer.
"""
# Input tensor attributes
device: torch.device
dtype: torch.dtype
if isinstance(inp, torch.Tensor):
device = inp.device
dtype = inp.dtype
elif isinstance(inp, Float8BlockwiseQTensorBase):
if inp._rowwise_data is not None:
device = inp._rowwise_data.device
elif inp._columnwise_data is not None:
device = inp._columnwise_data.device
else:
raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data")
dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant.
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, "
f"found {inp.__class__.__name__})"
)
world_size = get_distributed_world_size(process_group)
# Check that quantizer is valid
if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128):
raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")
# Output tensor dims
if out_shape is None:
out_shape = list(inp.size())
out_shape[0] *= world_size
# Doing BF16 gather for now as baseline because it's simpler
if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None:
out = torch.empty(
out_shape,
dtype=dtype,
device=device,
memory_format=torch.contiguous_format,
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
out = quantizer(out)
return out, None
# Implementation of fp8 gather needs to account for:
# * Getting columnwise data as a transpose of how it is stored for GEMMS.
# * Gathering non GEMM swizzled scales.
# * Refer to scaffold code when implementing at:
# https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
raise NotImplementedError("fp8 blockwise allgather not yet implemented")
def _all_gather_mxfp8(
inp: torch.Tensor,
process_group: dist_group_type,
......@@ -1069,7 +1155,9 @@ def gather_along_first_dim(
async_op: bool = False,
quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
"""All-gather tensors and concatenate along first dimension."""
"""
All-gather tensors and concatenate along first dimension.
"""
# Return immediately if no communication is required
world_size = get_distributed_world_size(process_group)
......@@ -1094,6 +1182,16 @@ def gather_along_first_dim(
out_shape=out_shape,
)
# FP8 block scaling case, block length = 128
if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer):
return _all_gather_fp8_blockwise(
inp,
process_group,
async_op=async_op,
quantizer=quantizer,
out_shape=out_shape,
)
# MXFP8 case
if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
assert isinstance(quantizer, MXFP8Quantizer)
......@@ -1105,6 +1203,28 @@ def gather_along_first_dim(
out_shape=out_shape,
)
# Debug case - call gather_along_first_dim on each tensor
if isinstance(inp, DebugQuantizedTensor):
out_obj = inp
rowwise = inp.get_tensor(False)
columnwise = inp.get_tensor(True)
final_quantizer = (
None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
)
rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
out_obj.rowwise_gemm_tensor = rowwise_total
if rowwise is not columnwise:
final_quantizer_columnwise = (
None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
)
columnwise_total, _ = gather_along_first_dim(
columnwise, process_group, False, final_quantizer_columnwise
)
out_obj.columnwise_gemm_tensor = columnwise_total
else:
out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor
return out_obj, None
# High-precision communication for quantized tensors
if quantizer is not None:
warnings.warn(
......@@ -1147,6 +1267,152 @@ def gather_along_first_dim(
return out, handle
# Global cache to store symmetric memory tensors
symmetric_mem_cache = {}
def get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group, tag=None):
"""
Gets or creates a symmetric memory tensor with specified properties.
Reuses cached tensors when available to avoid redundant creation and rendezvous operations.
Note: This function always returns a 1D tensor.
Parameters
----------
tensor_numel : int
Number of elements in the tensor.
tensor_dtype : torch.dtype
Data type of the tensor.
tensor_device : torch.device
Device on which to allocate the tensor.
tp_group : dist_group_type
Process group for rendezvous operation.
tag : Any, optional
Optional identifier to further distinguish tensors.
Returns
-------
torch.Tensor
A symmetric memory tensor with the specified properties.
"""
# Create a cache key based on tensor properties and group
cache_key = (tensor_numel, tensor_dtype, tensor_device, tp_group.group_name, tag)
# Check if we already have a symmetric memory tensor for this configuration
if cache_key not in symmetric_mem_cache:
# Create a new symmetric memory tensor if not in cache
msg = symm_mem.empty(
tensor_numel,
dtype=tensor_dtype,
device=tensor_device,
)
# Perform the rendezvous once for this tensor
symm_mem.rendezvous(msg, group=tp_group)
# Store in cache
symmetric_mem_cache[cache_key] = msg
else:
# Reuse the existing symmetric memory tensor
msg = symmetric_mem_cache[cache_key]
return msg
def symmetric_all_reduce(
inp: torch.Tensor,
tp_group: Optional[dist_group_type] = None,
async_op: bool = False,
all_reduce_type: str = "multimem_all_reduce",
):
"""
Performs an all-reduce operation across multiple processes using symmetric memory.
If the input tensor is already in the symmetric memory cache, we can avoid copy
overheads by using it directly for the all-reduce. Externally created symmetric
memory tensors that are not in the cache currently cannot avoid the extra copies.
Parameters
----------
inp : torch.Tensor
The input tensor to be reduced. The operation is performed in-place.
tp_group : Optional[dist_group_type], default=None
The process group over which to perform the all-reduce operation.
If None, the default process group is used.
async_op : bool, default=False
Whether to perform the operation asynchronously.
Note: Currently only synchronous operations are supported for symmetric memory variants.
all_reduce_type : str, default="multimem_all_reduce"
The type of all-reduce implementation to use. Options include:
- "nccl": Standard PyTorch distributed all-reduce
- "multimem_all_reduce": multimem symmetric all-reduce
- "two_shot": Two-shot symmetric all-reduce
- "one_shot": One-shot symmetric all-reduce
Returns
-------
Tuple[torch.Tensor, Optional[torch.distributed.Work]]
- The first element is the input tensor with the all-reduce result.
- The second element is the async work handle if async_op=True,
otherwise None.
"""
assert async_op is False, "Async symmetric ops not supported yet"
assert HAS_TORCH_SYMMETRIC, "Could not import symmetric memory from torch"
if get_distributed_world_size(tp_group) == 1:
return inp, None
if all_reduce_type == "nccl":
# Standard all-reduce implementation
handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op)
return inp, handle
all_reduce_impl = None
if all_reduce_type == "multimem_all_reduce":
all_reduce_impl = torch.ops.symm_mem.multimem_all_reduce_
elif all_reduce_type == "two_shot":
all_reduce_impl = torch.ops.symm_mem.two_shot_all_reduce_
elif all_reduce_type == "one_shot":
all_reduce_impl = torch.ops.symm_mem.one_shot_all_reduce
else:
raise TypeError(f"All reduce type {all_reduce_type} is not supported.")
group_name = tp_group.group_name
tensor_shape = inp.shape
tensor_numel = inp.numel()
tensor_dtype = inp.dtype
tensor_device = inp.device
input_id = id(inp)
is_cached = any(id(cached_tensor) == input_id for cached_tensor in symmetric_mem_cache.values())
# Check if the input tensor is already in the symmetric memory cache. If it is we can avoid copy overheads.
if is_cached:
all_reduce_impl(
inp,
"sum",
group_name,
)
else:
# Get symmetric memory tensor. Build or retrieve from cache.
msg = get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group)
msg.copy_(inp.reshape(-1))
all_reduce_impl(
msg,
"sum",
group_name,
)
# Copy the result back to the input tensor
inp.copy_(msg.reshape(tensor_shape))
return inp, None
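A hedged usage sketch (not part of this commit; assumes torch.distributed is already initialized and tp_group is a valid process group):

x = torch.ones(1024, device="cuda")
# In-place sum across the TP group using the multimem symmetric kernel.
out, handle = symmetric_all_reduce(x, tp_group=tp_group, all_reduce_type="multimem_all_reduce")
# `out` aliases `x`; `handle` is None for the synchronous symmetric paths.
# Repeated calls with the same numel/dtype/device reuse the cached symmetric
# buffer, so the rendezvous cost is paid only once.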
def allreduce(
inp: torch.Tensor,
tp_group: Optional[dist_group_type] = None,
......
......@@ -128,9 +128,9 @@ class InferenceParams:
self,
max_batch_size: int,
max_sequence_length: int,
num_heads_kv: int = 16,
head_dim_k: int = 64,
dtype: torch.dtype = torch.bfloat16,
num_heads_kv: int = None,
head_dim_k: int = None,
dtype: torch.dtype = None,
head_dim_v: int = None,
is_paged: bool = False,
total_num_pages: int = None,
......@@ -141,6 +141,10 @@ class InferenceParams:
):
self.max_batch_size = max_batch_size
self.max_sequence_length = max_sequence_length
assert all(x is not None for x in [num_heads_kv, head_dim_k, dtype]), (
"num_heads_kv, head_dim_k, and dtype are required for InferenceParams since Transformer"
" Engine 2.2."
)
self.num_heads_kv = num_heads_kv
self.head_dim_k = head_dim_k
self.dtype = dtype
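Because num_heads_kv, head_dim_k, and dtype no longer have defaults, callers must pass them explicitly. A minimal illustrative construction (values invented, not from this commit):

inference_params = InferenceParams(
    max_batch_size=8,
    max_sequence_length=2048,
    num_heads_kv=16,
    head_dim_k=64,
    dtype=torch.bfloat16,
)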
......
......@@ -7,7 +7,12 @@ Rotary Position Embedding implementation of different types along with helper fu
"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"]
class RotaryPositionEmbedding(torch.nn.Module):
......@@ -22,19 +27,24 @@ class RotaryPositionEmbedding(torch.nn.Module):
seq_len_interpolation_factor: Optional[int] = None,
pretrained_max_position_embeddings: Optional[int] = None,
rotary_base: float = 10000.0,
interleaved: bool = False,
):
"""
Parameters
----------
dim: int
rotary embedding dimension
rotary_percent: float
Rotary embedding dimension.
rotary_percent: float, default = 1.0
Percent of rotary dimension to use for rotary position embeddings.
seq_len_interpolation_factor: int
if not None, discrete positions will be interpolated by this factor via the trick in
seq_len_interpolation_factor: int, default = None
If not None, discrete positions will be interpolated by this factor via the trick in
https://arxiv.org/abs/2306.15595
pretrained_max_position_embeddings: int
pre-trained max_position_embeddings before position interpolation
pretrained_max_position_embeddings: int, default = None
Pre-trained max_position_embeddings before position interpolation.
rotary_base: float, default = 10000.0
Base of the rotary position embedding.
interleaved: bool, default = False
Whether to use interleaved rotary position embedding.
"""
super().__init__()
if rotary_percent < 1.0:
......@@ -50,17 +60,18 @@ class RotaryPositionEmbedding(torch.nn.Module):
)
self.register_buffer("inv_freq", inv_freq)
self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
self.interleaved = interleaved
def forward(self, max_seq_len: int, offset: int = 0):
"""
Create rotary position embedding frequencies
Create rotary position embedding frequencies.
Parameters
----------
max_seq_len: int
sequence length of a sample
Sequence length of a sample.
offset: int, default = 0
fixed offset for freqencies
Fixed offset for frequencies.
"""
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
......@@ -84,7 +95,12 @@ class RotaryPositionEmbedding(torch.nn.Module):
freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
if not self.interleaved:
emb = torch.cat((freqs, freqs), dim=-1)
else:
emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
freqs.shape[0], -1
)
# emb [seq_length, .., dim]
return emb.reshape(emb.size(0), 1, 1, emb.size(1))
......@@ -104,61 +120,146 @@ class FusedRoPEFunc(torch.autograd.Function):
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
interleaved: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
"""Fused RoPE forward."""
if freqs.dtype != torch.float32:
freqs = freqs.float()
if tensor_format == "sbhd":
output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd":
output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
elif tensor_format == "thd":
output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
assert tensor_format in (
"sbhd",
"bshd",
"thd",
), f"Unsupported tensor_format: {tensor_format}."
output = tex.fused_rope_forward(
t, freqs, QKVFormat[tensor_format], interleaved, cu_seqlens, cp_size, cp_rank
)
ctx.save_for_backward(freqs, cu_seqlens)
ctx.tensor_format = tensor_format
ctx.cp_size = cp_size
ctx.cp_rank = cp_rank
ctx.interleaved = interleaved
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
"""Fused RoPE backward."""
freqs, cu_seqlens = ctx.saved_tensors
if ctx.tensor_format == "sbhd":
grad_input = tex.fused_rope_backward(grad_output, freqs, False)
elif ctx.tensor_format == "bshd":
grad_input = tex.fused_rope_backward(
grad_output.transpose(0, 1), freqs, True
).transpose(0, 1)
elif ctx.tensor_format == "thd":
grad_input = tex.fused_rope_thd_backward(
grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
grad_output,
freqs,
QKVFormat[ctx.tensor_format],
ctx.interleaved,
cu_seqlens,
ctx.cp_size,
ctx.cp_rank,
)
else:
raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
return grad_input, None, None, None, None, None
return grad_input, None, None, None, None, None, None
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
change sign so the last dimension becomes [-odd, +even]
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
"""Change sign so the last dimension becomes [-odd, +even]
Args:
x: torch.Tensor. Input tensor.
interleaved: bool. Whether to use interleaved rotary position embedding.
Returns:
Tensor: Tensor rotated half.
"""
x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
x1, x2 = x.unbind(dim=-2)
if not interleaved:
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
# interleaved
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x_new = torch.stack((-x2, x1), dim=-1)
return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
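A small worked example (illustrative, not part of the commit) of the two layouts handled by _rotate_half on a last dimension of 4:

t = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])  # shape [1, 1, 1, 4]
print(_rotate_half(t, interleaved=False))  # [[[[-3., -4., 1., 2.]]]]  halves swapped, first half negated
print(_rotate_half(t, interleaved=True))   # [[[[-2., 1., -4., 3.]]]]  pairs rotated in place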
def _apply_rotary_pos_emb_base(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
interleaved: bool = False,
) -> torch.Tensor:
"""
Base implementation of applying rotary positional embedding tensor to the input tensor.
Parameters
----------
t: torch.Tensor
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
embedding will be applied.
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
`[seq, bs, ...]`.
interleaved: bool, default = False
Whether to use interleaved rotary position embedding.
"""
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
# cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * cos_) + (_rotate_half(t, interleaved) * sin_)
return torch.cat((t, t_pass), dim=-1)
def _get_freqs_on_this_cp_rank(
freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int
) -> torch.Tensor:
"""Get the position embedding on the current context parallel rank.
Args:
freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`.
seqlen: int. Length of the current sequence.
cp_size: int. Context parallel world size.
cp_rank: int. Context parallel rank.
"""
if cp_size > 1:
cp_seg = seqlen // 2
full_seqlen = cp_size * seqlen
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
# cp_size == 1
return freqs[:seqlen]
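An illustrative example (not from the commit) of the slicing above: with cp_size > 1 each rank keeps one chunk from the front of the full sequence and the mirrored chunk from the back, which balances causal-attention work across ranks.

freqs = torch.arange(8, dtype=torch.float32).view(8, 1, 1, 1)  # stand-in for a [full_seqlen, 1, 1, d2] tensor
print(_get_freqs_on_this_cp_rank(freqs, seqlen=4, cp_size=2, cp_rank=0).flatten())  # tensor([0., 1., 6., 7.])
print(_get_freqs_on_this_cp_rank(freqs, seqlen=4, cp_size=2, cp_rank=1).flatten())  # tensor([2., 3., 4., 5.])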
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
interleaved: bool = False,
fused: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
......@@ -175,11 +276,13 @@ def apply_rotary_pos_emb(
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
fused: bool, default = False
Whether to use a fused applying RoPE implementation.
tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
interleaved: bool, default = False
Whether to use interleaved rotary position embedding.
fused: bool, default = False
Whether to use a fused applying RoPE implementation.
cu_seqlens: torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'.
......@@ -189,37 +292,40 @@ def apply_rotary_pos_emb(
cp_rank: int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
"""
if fused:
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)
assert tensor_format in ("sbhd", "bshd"), (
"Only formats `sbhd` or `bshd` are supported for input tensor `t` "
f"when fused is False, got {tensor_format}."
if fused:
return FusedRoPEFunc.apply(
t, freqs, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
)
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
# cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# Unfused THD format
if tensor_format == "thd":
cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
_apply_rotary_pos_emb_base(
x.unsqueeze(1),
_get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank),
interleaved=interleaved,
)
for x in torch.split(t, seqlens)
]
).squeeze(1)
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * cos_) + (_rotate_half(t) * sin_)
return torch.cat((t, t_pass), dim=-1)
# Unfused SBHD/BSHD format
if tensor_format == "sbhd":
seqlen = t.size(0)
elif tensor_format == "bshd":
seqlen = t.size(1)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
return _apply_rotary_pos_emb_base(
t,
_get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
tensor_format,
interleaved=interleaved,
)
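A hedged end-to-end sketch of the unfused sbhd path (shapes invented for illustration; not part of the commit):

rope = RotaryPositionEmbedding(dim=64, interleaved=False)
freqs = rope(max_seq_len=128)                   # [128, 1, 1, 64]
q = torch.randn(128, 2, 16, 64)                 # [seq, batch, heads, head_dim]
q_rot = apply_rotary_pos_emb(q, freqs, tensor_format="sbhd", fused=False)
assert q_rot.shape == q.shape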
......@@ -6,6 +6,7 @@
from __future__ import annotations
import abc
import itertools
import os
from contextlib import contextmanager
from collections import deque
......@@ -19,6 +20,7 @@ from transformer_engine.common.recipe import (
Format,
MXFP8BlockScaling,
Float8CurrentScaling,
Float8BlockScaling,
)
from .constants import dist_group_type
......@@ -56,6 +58,17 @@ def check_mxfp8_support() -> Tuple[bool, str]:
return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available"""
if (
get_device_compute_capability() >= (9, 0)
and get_device_compute_capability() < (10, 0)
and float(torch.version.cuda) >= 12.9
):
return True, ""
return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args."""
if get_device_compute_capability() >= (10, 0): # blackwell and above
......@@ -116,6 +129,8 @@ class FP8GlobalStateManager:
skip_fp8_weight_update_tensor = None
mxfp8_available = None
reason_for_no_mxfp8 = ""
fp8_block_scaling_available = None
reason_for_no_fp8_block_scaling = None
@classmethod
def reset(cls) -> None:
......@@ -141,6 +156,8 @@ class FP8GlobalStateManager:
cls.skip_fp8_weight_update_tensor = None
cls.mxfp8_available = None
cls.reason_for_no_mxfp8 = ""
cls.fp8_block_scaling_available = None
cls.reason_for_no_fp8_block_scaling = ""
@classmethod
def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
......@@ -168,6 +185,15 @@ class FP8GlobalStateManager:
cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support()
return cls.mxfp8_available, cls.reason_for_no_mxfp8
@classmethod
def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]:
"""Return if Float8 block scaling support is available."""
if cls.fp8_block_scaling_available is None:
cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = (
check_fp8_block_scaling_support()
)
return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling
@staticmethod
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
......@@ -441,6 +467,9 @@ class FP8GlobalStateManager:
if isinstance(fp8_recipe, MXFP8BlockScaling):
mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
assert mxfp8_available, reason_for_no_mxfp8
if isinstance(fp8_recipe, Float8BlockScaling):
fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available()
assert fp8_block_available, reason_for_no_fp8_block
@classmethod
def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
......@@ -793,8 +822,10 @@ class RecipeState(abc.ABC):
cls = MXFP8BlockScalingRecipeState
elif recipe.float8_current_scaling():
cls = Float8CurrentScalingRecipeState
elif recipe.float8_block_scaling():
cls = Float8BlockScalingRecipeState
else:
raise ValueError("{recipe.__class__.__name__} is not supported")
raise ValueError(f"{recipe.__class__.__name__} is not supported")
return cls(
recipe,
mode=mode,
......@@ -935,3 +966,108 @@ class MXFP8BlockScalingRecipeState(RecipeState):
from .tensor.mxfp8_tensor import MXFP8Quantizer
return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)]
class Float8BlockScalingRecipeState(RecipeState):
"""Configuration for Float8BlockScaling quantization.
Float8BlockScaling quantization does not require state,
but different quantizers use different modes.
"""
recipe: Float8BlockScaling
mode: str
qx_dtype: tex.DType
qw_dtype: tex.DType
qgrad_dtype: tex.DType
def __init__(
self,
recipe: Float8BlockScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.qx_dtype = get_fp8_te_dtype(recipe, True)
self.qw_dtype = get_fp8_te_dtype(recipe, True)
self.qgrad_dtype = get_fp8_te_dtype(recipe, False)
# Allocate buffers
if device is None:
device = torch.device("cuda")
self.device = device
def make_quantizers(self) -> list:
# TODO(ksivamani): Find a better design for this; adding here to avoid a circular import.
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
if self.mode == "forward":
# The index convention (coming from base.py set_meta_tensor)
# is somewhat awkward, and doesn't play nicely with QuantizeOp,
# which is not associated with a GEMM.
assert self.num_quantizers % 3 == 0 # x, w, output per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qw_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale,
block_scaling_dim=self.recipe.w_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 3)
]
)
)
assert self.mode == "backward", f"Unexpected mode {self.mode}"
assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 2)
]
)
)
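For reference, the quantizer ordering implied by the loops above (an interpretation of this code, not an authoritative contract):

# Forward mode, GEMM i:
#   quantizers[3*i + 0] -> input (x)
#   quantizers[3*i + 1] -> weight (w)
#   quantizers[3*i + 2] -> output (reuses the input settings)
# Backward mode, GEMM i:
#   quantizers[2*i + 0] -> grad_output
#   quantizers[2*i + 1] -> grad_input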
......@@ -9,6 +9,7 @@ from dataclasses import dataclass
from functools import reduce
from operator import mul as multiply_op
import queue
import torch
from .. import cpp_extensions as tex
......@@ -226,3 +227,79 @@ class _ParameterInitMeta:
"""Safeguard reference to the parameter's parent module and initialization function."""
if self.init_fn is None:
self.init_fn = get_default_init_method()
class WeightGradStore:
"""
A class to manage weight gradient storage and computation in Transformer modules.
This class enables split backward propagation for better memory efficiency.
"""
def __init__(self, delay_wgrad_compute=False, ub_bulk_wgrad=False):
"""
Initialize the WeightGradStore.
Args:
delay_wgrad_compute (bool): Whether to delay weight gradient computation
ub_bulk_wgrad (bool): Whether to enable bulk weight gradient computation
"""
if delay_wgrad_compute:
self.context = queue.Queue()
assert (
ub_bulk_wgrad is False
), "ub_bulk_wgrad is not supported when enabling delay_wgrad_compute"
self.enabled = delay_wgrad_compute
else:
self.context = None
self.enabled = False
def delay_wgrad_compute(self):
"""
Get the current split backward propagation status.
Returns:
bool: True if split backward is enabled, False otherwise
"""
return self.enabled
def enable_delay_wgrad_compute(self):
"""Enable split backward propagation."""
self.enabled = True
def disable_delay_wgrad_compute(self):
"""Disable split backward propagation."""
self.enabled = False
def put(self, tensor_list, func):
"""
Store tensors and computation function for later execution.
Args:
tensor_list (list): List of tensors needed for computation
func (callable): Function to be executed with the tensors
"""
assert self.enabled is True, "delay_wgrad_compute is not enabled"
self.context.put([tensor_list, func])
def pop(self):
"""
Execute the stored computation with the stored tensors.
Raises an exception if the queue is empty.
"""
assert self.enabled is True, "delay_wgrad_compute is not enabled"
if self.context.qsize() > 0:
tensor_list, func = self.context.get()
return func(*tensor_list), tensor_list
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
raise RuntimeError(f"Pop empty queue. rank {rank}")
raise RuntimeError("Pop empty queue. No distributed environment detected.")
def assert_empty(self):
"""
Assert that the queue is empty.
Used for debugging and ensuring proper cleanup.
"""
assert self.enabled is True, "delay_wgrad_compute is not enabled"
rank = torch.distributed.get_rank()
assert self.context.empty(), f"Queue is not empty. rank {rank}"
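A minimal usage sketch of the delayed-wgrad queue (hypothetical tensors and wgrad function; not from this commit):

store = WeightGradStore(delay_wgrad_compute=True)

def wgrad_fn(inp, grad_out):
    return grad_out.t() @ inp  # placeholder for the real wgrad GEMM

store.put([saved_input, saved_grad_output], wgrad_fn)  # during backward: defer the compute
result, tensors = store.pop()                           # later: run the deferred wgrad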
......@@ -10,6 +10,7 @@ import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager
import logging
from types import MethodType
import torch
......@@ -18,11 +19,12 @@ import torch.nn.functional as F
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from ._common import _ParameterInitMeta
from ._common import _ParameterInitMeta, noop_cat
from ..fp8 import (
MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
)
......@@ -34,8 +36,13 @@ from ..distributed import (
)
from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION
__all__ = ["initialize_ub", "destroy_ub"]
......@@ -44,7 +51,8 @@ _2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_multi_stream_cublas_batchgemm_workspace = []
_dummy_wgrads = {}
multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
......@@ -82,6 +90,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
)
return _multi_stream_cublas_workspace
def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_batchgemm_workspace
......@@ -92,11 +101,29 @@ def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
)
return _multi_stream_cublas_batchgemm_workspace
if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))):
remove_ag_gemm_dgrad = ["fc2_dgrad"]
else:
remove_ag_gemm_dgrad = []
def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
"""Returns a dummy tensor of given shape."""
assert len(shape) == 2
global _dummy_wgrads
if (shape[0], shape[1], dtype) not in _dummy_wgrads:
_dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty(
shape,
dtype=dtype,
device="cuda",
requires_grad=False,
)
if zero:
_dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0)
return _dummy_wgrads[(shape[0], shape[1], dtype)].detach()
def initialize_ub(
shape: list,
tp_size: int,
......@@ -429,6 +456,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def __init__(self) -> None:
super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.name = None
self.fp8_initialized = False
self.fp8 = False
self.fp8_calibration = False
......@@ -448,6 +476,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
self.activation_dtype: Optional[torch.dtype] = None
if not TEDebugState.debug_enabled:
TEDebugState.initialize()
# Names of attributes that can be set quickly (see __setattr__
# method)
_fast_setattr_names: Set[str] = {
......@@ -535,6 +566,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
recipe_state, Float8CurrentScalingRecipeState
):
return
if recipe.float8_block_scaling() and isinstance(
recipe_state, Float8BlockScalingRecipeState
):
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
......@@ -860,7 +895,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
# Non-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8:
if not ctx.fp8 and not ctx.debug:
if gather_grad_output:
if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:
grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
......@@ -870,6 +905,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return grad_output, None
# FP8 with all-gather: unfused bgrad, fused cast + transpose
# Also supports debug quantization, which is handled inside gather_along_first_dim.
if gather_grad_output:
grad_bias = None
if ctx.use_bias:
......@@ -877,7 +913,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if ctx.ub_overlap_ag and ctx.ub_obj_gradout is not None:
# Quantize the gradient if needed
if not isinstance(
grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)
grad_output,
(
QuantizedTensor,
Float8TensorBase,
MXFP8TensorBase,
Float8BlockwiseQTensorBase,
),
):
grad_output = quantizer(grad_output)
......@@ -892,14 +934,41 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
)
return grad_output, grad_bias
# Debug without all-gather: unfused cast and bgrad
# bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
if ctx.debug:
grad_output_ = quantizer(grad_output)
if (
isinstance(
grad_output_.get_tensor(True),
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase),
)
and ctx.use_bias
):
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias = None
grad_output = grad_output_
return grad_output, grad_bias
# FP8 without all-gather: fused bgrad + cast + transpose
grad_bias = None
if ctx.use_bias:
if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
if isinstance(
grad_output,
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
else:
if isinstance(quantizer, Float8BlockQuantizer):
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
if not isinstance(
grad_output,
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
):
grad_output = quantizer(grad_output)
return grad_output, grad_bias
......@@ -998,6 +1067,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
update_workspace: bool = True,
skip_update_flag: Optional[torch.Tensor] = None,
fsdp_group: Optional[dist_group_type] = None,
workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor:
"""Get FP8 workspace buffer and maybe update its values
......@@ -1020,6 +1090,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
over `update_workspace` if provided.
fsdp_group: bool, default = None
FSDP process group that the weights are distributed over.
workspace_dtype: torch.dtype, default = None
If the weight workspace contains a high-precision tensor (for example,
for debug quantization), this is the dtype of that tensor.
"""
# FP8 primary weights
......@@ -1033,6 +1106,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Try getting workspace from cache
out = None
if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None)
if quantizer is not None and isinstance(out, MXFP8TensorBase):
......@@ -1043,6 +1117,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
out = None
del self._fp8_workspaces[cache_name]
is_debug = isinstance(quantizer, DebugQuantizer)
is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
if is_debug != is_out_debug_tensor:
out = None
# Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# for models initialized with Fp8 primary weights.
......@@ -1060,7 +1139,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
raise ValueError(
"tensor and quantizer kwargs must be provided to construct FP8 workspace"
)
out = quantizer(tensor)
out = quantizer.quantize(tensor, dtype=workspace_dtype)
# Update cache
if cache_name is not None:
......@@ -1077,7 +1156,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
out.quantize_(tensor, noop_flag=skip_update_flag)
else:
tex.quantize(tensor, quantizer, out, skip_update_flag)
return out
def _load_from_state_dict(
......@@ -1100,3 +1178,68 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
(wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation:
unfused_weights = [getattr(self, name) for name in self.weight_names]
weight_tensor = noop_cat(unfused_weights)
if weight_tensor.grad is None:
weight_tensor.grad = wgrad.to(weight_tensor.dtype)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
if bias_tensor.grad is None:
bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
del grad_bias_
del wgrad
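# A rough sketch of the WeightGradStore contract assumed by backward_dw above
# (the class lives in _common.py and its full definition is not shown in this diff):
#
#   store = WeightGradStore(True)                      # delay wgrad computation
#   store.put([inp, grad_out, wgrad_buf], wgrad_fn)    # queue tensors + wgrad GEMM callable
#   ...
#   (wgrad, grad_bias, *_), tensors = store.pop()      # run the callable, get its results
#
# backward_dw() is then a thin wrapper around pop() that copies the results into the
# parameters' .grad fields when wgrad accumulation is not fused.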
def _validate_name(self):
"""
Validate the name passed to the module.
This is invoked in the forward() method, because module names are assigned only after the model is initialized in Megatron-LM.
If no name is assigned, a default name based on the layer count is created.
"""
assert TEDebugState.debug_enabled
import nvdlfw_inspect.api as debug_api
if self.name is None:
debug_api.log_message(
"Names are not provided to debug modules. ",
"Creating and using generic names. Pass names to debug modules for better"
" insight. ",
level=logging.WARNING,
)
self.name = f"Layer_{TEDebugState.get_layer_count()}"
def _turn_off_unsupported_features_in_debug(self):
if (
getattr(self, "ub_bulk_wgrad", False)
or getattr(self, "ub_bulk_dgrad", False)
or getattr(self, "ub_overlap_ag", False)
or getattr(self, "ub_overlap_rs_dgrad", False)
or getattr(self, "ub_overlap_rs", False)
):
import nvdlfw_inspect.api as debug_api
debug_api.log_message(
"UserBuffers are not supported in debug module. "
"Using UB optimization will not affect the debug module. ",
level=logging.WARNING,
)
if hasattr(self, "ub_bulk_wgrad"):
self.ub_bulk_wgrad = None
if hasattr(self, "ub_bulk_dgrad"):
self.ub_bulk_dgrad = None
if hasattr(self, "ub_overlap_ag"):
self.ub_overlap_ag = None
if hasattr(self, "ub_overlap_rs_dgrad"):
self.ub_overlap_rs_dgrad = None
if hasattr(self, "ub_overlap_rs"):
self.ub_overlap_rs = None
......@@ -4,12 +4,13 @@
"""FP8 Padding API"""
from typing import Union, List
from typing import List, Optional, Tuple
import torch
import transformer_engine_torch as tex
from ..fp8 import FP8GlobalStateManager
from ..jit import no_torch_dynamo
......@@ -74,22 +75,30 @@ class Fp8Padding(torch.nn.Module):
----------
num_gemms: int
number of GEMMs to be performed simultaneously.
align_size: int, optional
the alignment size for the input tensor. If not provided, the alignment size is
determined by the FP8 recipe: 32 for MXFP8 and 16 for other recipes.
"""
def __init__(
self,
num_gemms,
num_gemms: int,
align_size: Optional[int] = None,
) -> None:
super().__init__()
self.num_gemms = num_gemms
if align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
else:
self.align_size = align_size
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
m_splits: List[int],
) -> Union[torch.Tensor, List[int]]:
) -> Tuple[torch.Tensor, List[int]]:
"""
Apply the padding to the input.
......@@ -104,7 +113,12 @@ class Fp8Padding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
# Calculate FP8 padding
padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
padded_m_splits = [
(m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits
]
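# Illustrative arithmetic (not part of the original code): with align_size=16,
# m_splits=[18, 32, 1] rounds up to padded_m_splits=[32, 32, 16]; with the MXFP8
# default align_size=32, the same splits round up to [32, 32, 32].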
# no padding needed
if m_splits == padded_m_splits:
return inp, m_splits
if torch.is_grad_enabled():
fn = _Fp8Padding.apply
......
......@@ -4,12 +4,13 @@
"""FP8 Padding API"""
from typing import List
from typing import List, Optional
import torch
import transformer_engine_torch as tex
from ..fp8 import FP8GlobalStateManager
from ..jit import no_torch_dynamo
......@@ -70,15 +71,23 @@ class Fp8Unpadding(torch.nn.Module):
----------
num_gemms: int
number of GEMMs to be performed simultaneously.
align_size: int, optional
the alignment size for the input tensor. If not provided, the alignment size is
determined by the FP8 recipe: 32 for MXFP8 and 16 for other recipes.
"""
def __init__(
self,
num_gemms,
num_gemms: int,
align_size: Optional[int] = None,
) -> None:
super().__init__()
self.num_gemms = num_gemms
if align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
else:
self.align_size = align_size
@no_torch_dynamo()
def forward(
......@@ -100,7 +109,12 @@ class Fp8Unpadding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
# Calculate FP8 padding
padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
padded_m_splits = [
(m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits
]
# no padding needed
if m_splits == padded_m_splits:
return inp
if torch.is_grad_enabled():
fn = _Fp8Unpadding.apply
......
......@@ -5,10 +5,12 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
import functools
import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from .base import (
get_multi_stream_cublas_workspace,
TransformerEngineBaseModule,
......@@ -16,6 +18,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..utils import (
divide,
......@@ -37,7 +40,6 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.float8_tensor import Float8Tensor
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.quantized_tensor import (
......@@ -47,7 +49,6 @@ from ..tensor.quantized_tensor import (
restore_from_saved,
)
__all__ = ["GroupedLinear"]
......@@ -65,6 +66,7 @@ class _GroupedLinear(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizers: List[Quantizer],
weight_quantizers: List[Quantizer],
output_quantizers: List[Quantizer],
......@@ -85,13 +87,6 @@ class _GroupedLinear(torch.autograd.Function):
biases = weights_and_biases[num_gemms:]
device = inp.device
# TODO Support MXFP8 # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
raise NotImplementedError("GroupedLinear does not yet support MXFP8")
# TODO Support Float8 Current Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling")
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
......@@ -124,7 +119,11 @@ class _GroupedLinear(torch.autograd.Function):
for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False)
fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
inputmats = tex.fused_multi_quantize(
inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
)
......@@ -165,7 +164,7 @@ class _GroupedLinear(torch.autograd.Function):
m_splits=m_splits,
bias=biases,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
use_split_accumulator=fprop_gemm_use_split_accumulator,
)
if fp8_calibration:
......@@ -177,9 +176,19 @@ class _GroupedLinear(torch.autograd.Function):
weight_quantizers[i].calibrate(weights[i])
if is_grad_enabled:
ctx.weight_quantizers = weight_quantizers
ctx.weights_shape_1 = weights[0].shape[1]
# TODO: update after #1638 is merged. # pylint: disable=fixme
if weight_requires_grad:
for inputmat in inputmats:
if isinstance(inputmat, QuantizedTensor):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if inp.requires_grad:
for weight in weights_fp8:
if isinstance(weight, QuantizedTensor):
weight.update_usage(columnwise_usage=True)
tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats,
*weights_fp8,
......@@ -200,6 +209,7 @@ class _GroupedLinear(torch.autograd.Function):
ctx.num_gemms = num_gemms
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
......@@ -213,6 +223,7 @@ class _GroupedLinear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors
or FP8GlobalStateManager.is_first_fp8_module()
)
ctx.wgrad_store = wgrad_store
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
......@@ -245,6 +256,13 @@ class _GroupedLinear(torch.autograd.Function):
grad_biases = [None] * ctx.num_gemms
if ctx.fp8:
if ctx.use_bias:
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready
# for Float8BlockQuantizer.
if ctx.fp8_recipe.float8_block_scaling():
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i])
else:
for i in range(ctx.num_gemms):
grad_biases[i], grad_output[i] = tex.bgrad_quantize(
grad_output_mats[i], ctx.grad_output_quantizers[i]
......@@ -267,12 +285,25 @@ class _GroupedLinear(torch.autograd.Function):
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.requires_dgrad:
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_gemm_use_split_accumulator = (
recipe.fp8_gemm_dgrad.use_split_accumulator
)
dgrad = torch.empty(
(sum(ctx.m_splits), ctx.weights_shape_1),
dtype=ctx.activation_dtype,
device=ctx.device,
)
for weight, quantizer in zip(weights, ctx.weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensor):
weight.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
general_grouped_gemm(
weights,
grad_output,
......@@ -283,10 +314,17 @@ class _GroupedLinear(torch.autograd.Function):
layout="NN",
m_splits=ctx.m_splits,
grad=True,
use_split_accumulator=_2X_ACC_DGRAD,
use_split_accumulator=dgrad_gemm_use_split_accumulator,
)
if ctx.weights_requires_grad:
wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_wgrad"):
wgrad_gemm_use_split_accumulator = (
recipe.fp8_gemm_wgrad.use_split_accumulator
)
if ctx.fuse_wgrad_accumulation:
wgrad_list = main_grads
else:
......@@ -294,21 +332,24 @@ class _GroupedLinear(torch.autograd.Function):
torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
for w in weights
]
# WGRAD
_, grad_biases_, _ = general_grouped_gemm(
inputmats,
grad_output,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
grouped_gemm_wgrad = functools.partial(
general_grouped_gemm,
out_dtype=ctx.activation_dtype,
workspaces=get_multi_stream_cublas_workspace(),
layout="NT",
grad=True,
m_splits=ctx.m_splits,
use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases,
use_split_accumulator=_2X_ACC_WGRAD,
use_split_accumulator=wgrad_gemm_use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
)
# WGRAD
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad)
else:
_, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list)
for i in range(ctx.num_gemms):
if grad_biases[i] is None:
grad_biases[i] = grad_biases_[i]
......@@ -351,7 +392,14 @@ class _GroupedLinear(torch.autograd.Function):
else:
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias:
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias or (
ctx.wgrad_store is not None
and ctx.wgrad_store.delay_wgrad_compute()
and not ctx.fp8
):
grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
......@@ -372,8 +420,9 @@ class _GroupedLinear(torch.autograd.Function):
None,
None,
None,
None, # is_grad_enabled
None, # is_grad_enabled
None,
None,
None,
*wgrad_list,
*grad_biases,
)
......@@ -422,7 +471,12 @@ class GroupedLinear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
Note: GroupedLinear does not handle TP communication internally. The `tp_size` and
`parallel_mode` arguments are only used to determine the shapes of weights and biases;
TP communication should be handled in the dispatch and combine stages of MoE models.
"""
def __init__(
......@@ -445,6 +499,7 @@ class GroupedLinear(TransformerEngineBaseModule):
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
) -> None:
super().__init__()
......@@ -465,7 +520,13 @@ class GroupedLinear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}
self.wgrad_store = WeightGradStore(delay_wgrad_compute)
self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
self._num_fp8_tensors_per_gemm = {
"fwd": 3,
"bwd": 2,
}
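# Illustrative index layout implied by the offsets above (a sketch, not part of
# the original code): forward FP8 meta tensors are packed per GEMM as
# [input_0, weight_0, output_0, input_1, weight_1, output_1, ...], so the weight
# quantizer of GEMM i sits at index 1 + 3 * i; the backward pass packs
# [grad_output_0, grad_input_0, grad_output_1, ...], so grad_output_i sits at 2 * i.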
if tp_group is None:
self.tp_size = tp_size
......@@ -476,6 +537,12 @@ class GroupedLinear(TransformerEngineBaseModule):
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
if self.tp_size > 1 and bias:
raise ValueError(
"GroupedLinear doesn't support bias when TP > 1. "
"Because the TP communication is handled outside of this module."
)
self.parallel_mode = parallel_mode
assert (
self.parallel_mode in GemmParallelModes
......@@ -502,7 +569,7 @@ class GroupedLinear(TransformerEngineBaseModule):
),
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=self._offsets["weight"] + i,
fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
)
# Construct bias parameters if needed
......@@ -527,12 +594,18 @@ class GroupedLinear(TransformerEngineBaseModule):
self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias:
self.gemm_bias_unfused_add = True
else:
self.gemm_bias_unfused_add = False
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
assert not self.tp_size > 1, (
"GroupedLinear doesn't support TP > 1 with Float8 current scaling. "
"Because the TP communication is handled outside of this module."
)
self._customize_quantizers_float8_current_scaling(fwd, recipe)
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
......@@ -590,7 +663,7 @@ class GroupedLinear(TransformerEngineBaseModule):
produced)
"""
assert not isinstance(
inp, Float8Tensor
inp, QuantizedTensor
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
......@@ -615,20 +688,27 @@ class GroupedLinear(TransformerEngineBaseModule):
grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
if self.fp8:
input_quantizers = [
self.quantizers["scaling_fwd"][self._offsets["input"] + i]
self.quantizers["scaling_fwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
]
for i in range(self.num_gemms)
]
# TODO: use internal after #1638 is merged. # pylint: disable=fixme
for i in range(self.num_gemms):
input_quantizers[i].internal = True
input_quantizers[i].internal = False
weight_quantizers = [
self.quantizers["scaling_fwd"][self._offsets["weight"] + i]
self.quantizers["scaling_fwd"][
self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
weight_quantizers[i].internal = True
if torch.is_grad_enabled():
grad_output_quantizers = [
self.quantizers["scaling_bwd"][self._offsets["input"] + i]
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
......@@ -643,10 +723,11 @@ class GroupedLinear(TransformerEngineBaseModule):
args += (
inp,
m_splits,
self.apply_bias and not self.gemm_bias_unfused_add,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
......@@ -663,17 +744,61 @@ class GroupedLinear(TransformerEngineBaseModule):
)
out = linear_fn(*args)
if self.gemm_bias_unfused_add:
out_shape = out.shape
out = torch.cat(
[
o + cast_if_needed(b, self.activation_dtype)
for o, b in zip(
torch.split(out.view(-1, self.out_features), m_splits), bias_tensors
)
]
).view(out_shape)
if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
wgrad_list = tensor_list[2]
if not self.fuse_wgrad_accumulation:
for i in range(self.num_gemms):
weight_param = getattr(self, f"weight{i}")
if weight_param.grad is None:
weight_param.grad = wgrad_list[i].to(weight_param.dtype)
if self.use_bias:
for i in range(self.num_gemms):
bias_param = getattr(self, f"bias{i}")
if bias_param.grad is None:
bias_param.grad = grad_biases_[i].to(bias_param.dtype)
del grad_biases_
del wgrad_list
del tensor_list
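# Illustrative usage of the delayed wgrad path (a sketch, not part of this diff):
#
#   layer = GroupedLinear(num_gemms, in_features, out_features,
#                         delay_wgrad_compute=True)
#   out = layer(inp, m_splits)
#   out.sum().backward()   # dgrad runs now; the wgrad GEMM is queued in wgrad_store
#   layer.backward_dw()    # execute the queued wgrad (and bias grad) computation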
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert (
recipe.float8_current_scaling()
), "current scaling recipe quantizer customization here"
if fwd:
for i in range(self.num_gemms):
# set configs about amax epsilon and power_2_scale
self.quantizers["scaling_fwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
self.quantizers["scaling_fwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
# also set weight quantizer with same amax_epsilon & power_2_scale
self.quantizers["scaling_fwd"][
self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
self.quantizers["scaling_fwd"][
self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
else:
for i in range(self.num_gemms):
# set grad_output_quantizer with amax epsilon and power_2_scale
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
......@@ -9,16 +9,19 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools
import torch
from torch.nn import init
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
get_workspace,
get_ub,
TransformerEngineBaseModule,
get_dummy_wgrad,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
......@@ -34,11 +37,13 @@ from ..utils import (
nvtx_range_pop,
nvtx_range_push,
requires_grad,
needs_quantized_gemm,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
allreduce,
symmetric_all_reduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
in_fp8_activation_recompute_phase,
......@@ -48,16 +53,21 @@ from ..distributed import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose
from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ..tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
general_gemm,
)
......@@ -89,12 +99,14 @@ class _LayerNormLinear(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
......@@ -119,6 +131,8 @@ class _LayerNormLinear(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
......@@ -143,11 +157,6 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
# Avoid quantized norm kernel if norm output will be returned
with_quantized_norm = (
fp8 and not return_layernorm_output and not return_layernorm_output_gathered
)
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = (
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
......@@ -180,6 +189,18 @@ class _LayerNormLinear(torch.autograd.Function):
columnwise_usage = False
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
force_hp_blockwise_ln_out_gather = (
fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
with_quantized_norm = (
fp8
and not return_layernorm_output
and not return_layernorm_output_gathered
and not force_hp_blockwise_ln_out_gather
)
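# Illustrative case (not from the original code): with fp8=True, an input
# quantizer of type Float8BlockQuantizer, and a sequence-parallel all-gather of
# ln_out, force_hp_blockwise_ln_out_gather evaluates to True, so with_quantized_norm
# is False and the norm kernel itself runs in high precision.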
# Apply normalization
nvtx_range_push(f"{nvtx_label}.norm")
ln_out, mu, rsigma = apply_normalization(
......@@ -210,13 +231,13 @@ class _LayerNormLinear(torch.autograd.Function):
# norm output will be returned
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8:
if fp8 or debug:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = input_quantizer(ln_out_total)
else:
if fp8:
if not with_quantized_norm:
if fp8 or debug:
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop:
......@@ -229,18 +250,19 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(input_quantizer if fp8 else None),
quantizer=(input_quantizer if fp8 or debug else None),
)
else:
if fp8 and not with_quantized_norm:
if (fp8 or debug) and not with_quantized_norm:
ln_out = input_quantizer(ln_out)
ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# Cast weight to expected dtype
if not fp8:
weightmat = weight
quantized_weight = False
weightmat = cast_if_needed(weight, activation_dtype)
if not fp8 and not debug:
weightmat = cast_if_needed(weightmat, activation_dtype)
else:
quantized_weight = not isinstance(weight, QuantizedTensor)
......@@ -250,6 +272,7 @@ class _LayerNormLinear(torch.autograd.Function):
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace(
tensor=weight,
quantizer=weight_quantizer,
......@@ -257,11 +280,12 @@ class _LayerNormLinear(torch.autograd.Function):
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
# Cast bias to expected dtype
bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32:
if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
......@@ -319,9 +343,11 @@ class _LayerNormLinear(torch.autograd.Function):
clear_tensor_data(ln_out, ln_out_total)
if is_grad_enabled:
ctx.weight_quantizer = weight_quantizer
ctx.ln_out_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)
ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather
# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
......@@ -332,21 +358,16 @@ class _LayerNormLinear(torch.autograd.Function):
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
ln_out.update_usage(rowwise_usage=False)
# For force_hp_blockwise_ln_out_gather, we should
# be saving the unquantized ln_out to ctx.
assert not force_hp_blockwise_ln_out_gather
# Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad:
if isinstance(weightmat, QuantizedTensor):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading:
if fp8 and weightmat is not None:
set_offloading_param(weightmat, "weight_offloading", True)
set_offloading_param(ln_weight, "weight_offloading", True)
set_offloading_param(weight, "weight_offloading", True)
set_offloading_param(inputmat, "activation_offloading", True)
set_offloading_param(mu, "activation_offloading", True)
set_offloading_param(rsigma, "activation_offloading", True)
set_offloading_param(ln_out, "activation_offloading", True)
mark_activation_offload(inputmat, mu, rsigma, ln_out)
# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
......@@ -391,6 +412,7 @@ class _LayerNormLinear(torch.autograd.Function):
if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad
ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.input_quantizer = input_quantizer
ctx.owns_input = inputmat is not inp
......@@ -425,6 +447,8 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
ctx.debug = debug
# Row Parallel Linear
if ub_overlap_rs_fprop:
......@@ -434,6 +458,9 @@ class _LayerNormLinear(torch.autograd.Function):
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
if symmetric_ar_type is not None:
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
else:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
......@@ -564,12 +591,27 @@ class _LayerNormLinear(torch.autograd.Function):
ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
# Configure quantizer for grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
# Reduce duplicated transpose, which is performed in grad_output.update_usage
if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
else:
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
rowwise_usage = True
columnwise_usage = True
if ctx.ub_overlap_ag and isinstance(
ctx.grad_output_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer),
):
# If data is in FP8 and communication is handled
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage = False
ctx.grad_output_quantizer.set_usage(
rowwise=rowwise_usage,
columnwise=columnwise_usage,
)
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
(
grad_output,
......@@ -582,21 +624,28 @@ class _LayerNormLinear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
# Prepare GEMM input
# Note: Perform tensor-parallel communication if needed
# Launch tensor-parallel communication for LayerNorm out tensor
ln_out_total = None
ln_out_total_work = None
if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
quantizer = None
if ctx.fp8:
if ctx.input_quantizer is not None:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
quantizer.set_usage(rowwise=True, columnwise=False)
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
# async_op is not compatible with high precision gather since
# gather_along_first_dim does not offer callback chaining.
gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
quantizer=gather_quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
......@@ -621,6 +670,11 @@ class _LayerNormLinear(torch.autograd.Function):
if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor):
weight.update_usage(
rowwise_usage=ctx.weight_quantizer.rowwise_usage,
columnwise_usage=ctx.weight_quantizer.columnwise_usage,
)
dgrad, *_ = general_gemm(
weight,
grad_output,
......@@ -659,6 +713,8 @@ class _LayerNormLinear(torch.autograd.Function):
# Compute grad weight tensor
wgrad = None
if ctx.requires_wgrad:
# Synchronize tensor-parallel communication for input tensor
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
if ctx.fp8:
......@@ -672,18 +728,32 @@ class _LayerNormLinear(torch.autograd.Function):
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
ln_out_total._create_transpose()
else:
if ln_out_total_work is not None:
# Synchronize tensor-parallel communication
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.input_quantizer is not None and not isinstance(
ln_out_total, QuantizedTensor
):
# The async gather may have been done in BF16,
# so apply the quantizer after the gather.
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)
# Make sure GEMM inputs have required data
if isinstance(ln_out_total, QuantizedTensor):
ln_out_total.update_usage(columnwise_usage=True)
if isinstance(grad_output, QuantizedTensor):
# This is a no-op if platform supports non-TN FP8 GEMM or the transpose
# already exists.
grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
grad_output.update_usage(columnwise_usage=True)
# Figure out whether to use split accumulator
use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_wgrad"):
use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
# Output buffer for overlapping grad input
# reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
......@@ -692,39 +762,29 @@ class _LayerNormLinear(torch.autograd.Function):
# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_wgrad"):
wgrad_gemm_use_split_accumulator = (
recipe.fp8_gemm_wgrad.use_split_accumulator
)
wgrad, grad_bias_, *_, rs_out = general_gemm(
ln_out_total,
grad_output,
get_workspace(),
layout="NT",
grad=True,
general_gemm_wgrad = functools.partial(
general_gemm,
out_dtype=(
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
grad=True,
bias=(bias if (grad_bias is None and not ctx.fp8) else None),
out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=wgrad_gemm_use_split_accumulator,
use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad,
ub_type=ub_type_wgrad,
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
dgrad = rs_out
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
else:
dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True)
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
if grad_bias is None:
grad_bias = grad_bias_
......@@ -734,6 +794,17 @@ class _LayerNormLinear(torch.autograd.Function):
if not ctx.return_layernorm_output:
# TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme
clear_tensor_data(ln_out_total)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
dgrad = rs_out
else:
dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True)
# Don't return grad bias if not needed
if not ctx.use_bias:
grad_bias = None
# Synchronize tensor parallel communication
if ln_out_total_work is not None:
......@@ -787,18 +858,15 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"):
origin_weight.grad_added_to_main_grad = True
if getattr(origin_weight, "zero_out_wgrad", False):
wgrad = torch.zeros(
origin_weight.main_grad.shape,
dtype=origin_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
wgrad = get_dummy_wgrad(
list(origin_weight.main_grad.shape),
origin_weight.dtype,
zero=True,
)
else:
wgrad = torch.empty(
origin_weight.main_grad.shape,
dtype=origin_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
wgrad = get_dummy_wgrad(
list(origin_weight.main_grad.shape),
origin_weight.dtype,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
......@@ -824,12 +892,14 @@ class _LayerNormLinear(torch.autograd.Function):
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
None, # grad_output_quantizer
None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # cpu_offloading
None, # tp_group
None, # tp_size
......@@ -852,8 +922,10 @@ class _LayerNormLinear(torch.autograd.Function):
None, # ub_bulk_wgrad
None, # ub_name
None, # fsdp_group
None, # debug
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
)
......@@ -906,6 +978,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
......@@ -941,6 +1015,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
is used.
"""
def __init__(
......@@ -970,6 +1053,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
name: str = None,
) -> None:
super().__init__()
......@@ -985,6 +1071,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output = return_layernorm_output
self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if tp_group is None:
self.tp_size = tp_size
......@@ -1050,6 +1142,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
assert ub_name is not None, "Userbuffer name [string] is not set."
self.ub_name = ub_name
if self.symmetric_ar_type is not None:
assert torch_version() >= (
2,
7,
0,
), "Torch version must be at least 2.7 to use symmetric memory"
self.eps = eps
layer_norm_weight = torch.nn.Parameter(
torch.empty(self.in_features, device=device, dtype=params_dtype)
......@@ -1252,6 +1351,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None,
fp8_output: Optional[bool] = False,
fp8_grad: Optional[bool] = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
......@@ -1274,6 +1374,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
......@@ -1282,6 +1385,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
if skip_fp8_weight_update is not None:
is_first_microbatch = False
if self.ub_overlap_rs_fprop:
if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
fp8_grad = True
with self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
......@@ -1303,13 +1413,28 @@ class LayerNormLinear(TransformerEngineBaseModule):
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if not any_feature_enabled(quantizers):
# If no debug feature is enabled, run the faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
debug = False
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
) = self._get_quantizers(fp8_output)
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply
......@@ -1327,12 +1452,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
......@@ -1357,6 +1484,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(*args)
......@@ -1374,10 +1503,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
return out, ln_out
return out
def _get_quantizers(self, fp8_output):
def _get_quantizers(self, fp8_output, fp8_grad):
if not self.fp8:
return [None] * 5
return [None] * 6
grad_input_quantizer = None
grad_weight_quantizer = None
grad_output_quantizer = None
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
......@@ -1389,13 +1519,27 @@ class LayerNormLinear(TransformerEngineBaseModule):
if torch.is_grad_enabled():
grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
grad_output_quantizer.internal = True
if fp8_grad:
grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
return (
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output, fp8_grad):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer
names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
DebugQuantizer(self.name, name, q, self.tp_group)
for name, q in zip(names, original_quantizers)
)
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
......
......@@ -8,6 +8,7 @@ import warnings
from typing import Callable, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools
import torch
from torch.nn.parameter import Parameter
......@@ -17,6 +18,7 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
get_workspace,
_ub_communicators,
......@@ -42,25 +44,31 @@ from ..utils import (
clear_tensor_data,
requires_grad,
non_tn_fp8_gemm_supported,
needs_quantized_gemm,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
allreduce,
symmetric_all_reduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
use_reentrant_activation_recompute,
in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors,
)
from ..constants import dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.float8_tensor import Float8Tensor
from ..tensor.float8_tensor import (
Float8CurrentScalingQuantizer,
Float8Quantizer,
Float8Tensor,
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
......@@ -70,6 +78,8 @@ from ..tensor.quantized_tensor import (
from ..cpp_extensions import (
general_gemm,
)
from ...debug.pytorch.utils import any_feature_enabled
from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["LayerNormMLP"]
......@@ -101,7 +111,8 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
}
# no activation fusion written yet
# Per-tensor current scaling: []
# Per-tensor current scaling or fp8 blockwise scaling: []
if recipe.float8_current_scaling() or recipe.float8_block_scaling():
return {
"gelu": (tex.gelu, tex.dgelu, None),
"relu": (tex.relu, tex.drelu, None),
......@@ -112,6 +123,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"srelu": (tex.srelu, tex.dsrelu, None),
}
raise NotImplementedError(f"Unhandled recipe type {recipe}")
def _act_func(activation: str, recipe: Optional[Recipe] = None):
......@@ -119,7 +131,7 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None):
# bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
# Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
# MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
# Per-tensor current scaling: []
# Per-tensor current scaling or fp8 blockwise scaling: []
funcs = _get_act_func_supported_list(recipe)
if activation not in funcs:
raise NotImplementedError("Activation type " + activation + " is not supported!")
......@@ -145,15 +157,20 @@ class _LayerNormMLP(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
fc1_input_quantizer: Optional[Quantizer],
fc1_weight_quantizer: Optional[Quantizer],
fc1_output_quantizer: Optional[Quantizer],
fc1_grad_input_quantizer: Optional[Quantizer],
fc1_grad_weight_quantizer: Optional[Quantizer],
fc1_grad_output_quantizer: Optional[Quantizer],
fc2_input_quantizer: Optional[Quantizer],
fc2_weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_fc2_output_quantizer: Optional[Quantizer],
grad_fc1_output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
fc2_output_quantizer: Optional[Quantizer],
fc2_grad_input_quantizer: Optional[Quantizer],
fc2_grad_weight_quantizer: Optional[Quantizer],
fc2_grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
......@@ -179,6 +196,8 @@ class _LayerNormMLP(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
......@@ -207,16 +226,31 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
# Avoid the quantized norm kernel if the norm output will be returned:
# for FP8 DelayedScaling: layernorm output is FP8;
#   only the output of the linear is returned
# for return_layernorm_output: layernorm output is high precision, then cast to FP8;
#   both the high-precision layernorm output and the linear output are returned
# for debug: layernorm output is high precision to enable processing of this norm
with_quantized_norm = (
fp8 and not return_layernorm_output and not return_layernorm_output_gathered
fp8
and not return_layernorm_output
and not return_layernorm_output_gathered
and not debug
)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
# Kernels not available for norm fusion.
with_quantized_norm = False
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
# TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe
force_hp_fc1_input_gather = (
fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
# Configure quantizer for norm output
if fp8:
if fc1_input_quantizer is None:
......@@ -257,13 +291,14 @@ class _LayerNormMLP(torch.autograd.Function):
# norm output will be returned
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8:
if fp8 or debug:
if not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = fc1_input_quantizer(ln_out_total)
else:
if fp8:
if not with_quantized_norm:
if fp8 or debug:
if not with_quantized_norm and not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
......@@ -276,18 +311,21 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(fc1_input_quantizer if fp8 else None),
quantizer=(fc1_input_quantizer if fp8 or debug else None),
)
else:
if fp8 and not with_quantized_norm:
# NOTE: checking force_hp_fc1_input_gather here is redundant with the else branch,
# but is kept for clarity. ln_out should not be quantized if the backward pass
# needs to gather it in high precision.
if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
ln_out_total = ln_out
# Cast weights to expected dtype
if not fp8:
fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype)
fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype)
else:
fc1_weight_final = fc1_weight
fc2_weight_final = fc2_weight
if fp8 or debug:
# If weights are not quantized, we call get_weight_workspace,
# which handles weight caching etc.
# FP8 cast to workspace buffer
......@@ -299,6 +337,7 @@ class _LayerNormMLP(torch.autograd.Function):
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc2_weight_final = module.get_weight_workspace(
......@@ -308,11 +347,15 @@ class _LayerNormMLP(torch.autograd.Function):
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
else:
fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
# Cast biases to expected dtype
bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32:
if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16
if fc1_bias is not None:
fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
......@@ -333,6 +376,7 @@ class _LayerNormMLP(torch.autograd.Function):
# - bias_gelu_fusion - only for full precision.
# If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performed
if activation != "gelu":
# blockwise scaled gemms don't support gemm_gelu_fusion in fwd.
gemm_gelu_fusion = bias_gelu_fusion = False
else:
if fp8:
......@@ -341,13 +385,16 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_gelu_fusion = True
if gemm_gelu_fusion and bias_gelu_fusion:
gemm_gelu_fusion = False
if debug:
gemm_gelu_fusion = False
fc1_outputs = general_gemm(
fc1_weight_final,
ln_out_total,
get_workspace(),
quantization_params=(
fc2_input_quantizer if gemm_gelu_fusion else None # fused gelu output is in fp8
fc2_input_quantizer
if gemm_gelu_fusion
else fc1_output_quantizer # fused gelu output is in fp8
),
out_dtype=activation_dtype,
bias=(
......@@ -358,6 +405,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_lnout,
ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None,
)
if not is_grad_enabled and (ln_out_total is not ln_out_return):
clear_tensor_data(ln_out_total)
......@@ -371,8 +419,17 @@ class _LayerNormMLP(torch.autograd.Function):
act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias)
elif gemm_gelu_fusion:
act_out, _, fc1_out, _ = fc1_outputs
elif debug:
fc1_out, *_ = fc1_outputs
act_out = activation_func(fc1_out, None)
act_out = fc2_input_quantizer(act_out)
else:
fc1_out, *_ = fc1_outputs
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
# tex.quantize does not support GELU fusion for blockwise.
act_out = activation_func(fc1_out, None)
act_out = tex.quantize(act_out, fc2_input_quantizer)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
if not is_grad_enabled:
......@@ -403,7 +460,7 @@ class _LayerNormMLP(torch.autograd.Function):
get_workspace(),
out_dtype=activation_dtype,
bias=fc2_bias,
quantization_params=output_quantizer,
quantization_params=fc2_output_quantizer,
out=fc2_out,
use_split_accumulator=_2X_ACC_FPROP,
ub=ub_obj_fc2out,
......@@ -412,7 +469,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
# Weight with column-wise usage is needed for dgrad GEMM.
if is_grad_enabled and inp.requires_grad:
if is_grad_enabled:
if isinstance(fc1_weight_final, QuantizedTensor):
fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensor):
......@@ -422,23 +479,9 @@ class _LayerNormMLP(torch.autograd.Function):
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
else:
if cpu_offloading:
if fp8 and fc1_weight_final is not None:
set_offloading_param(fc1_weight_final, "weight_offloading", True)
if fp8 and fc2_weight_final is not None:
set_offloading_param(fc2_weight_final, "weight_offloading", True)
set_offloading_param(ln_weight, "weight_offloading", True)
set_offloading_param(fc1_weight, "weight_offloading", True)
set_offloading_param(fc2_weight, "weight_offloading", True)
set_offloading_param(fc1_bias, "weight_offloading", True)
set_offloading_param(inputmat, "activation_offloading", True)
set_offloading_param(mu, "activation_offloading", True)
set_offloading_param(rsigma, "activation_offloading", True)
set_offloading_param(mu, "activation_offloading", True)
set_offloading_param(ln_out, "activation_offloading", True)
set_offloading_param(fc1_out, "activation_offloading", True)
set_offloading_param(fc1_out_without_bias, "activation_offloading", True)
set_offloading_param(act_out, "activation_offloading", True)
mark_activation_offload(
inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
)
# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
......@@ -455,10 +498,14 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
)
ctx.fc1_weight_quantizer = fc1_weight_quantizer
ctx.fc2_weight_quantizer = fc2_weight_quantizer
if not fc1_weight.requires_grad:
if not return_layernorm_output:
clear_tensor_data(ln_out)
ln_out = None
elif force_hp_fc1_input_gather:
assert not isinstance(ln_out, QuantizedTensor)
if not fc2_weight.requires_grad:
clear_tensor_data(act_out)
act_out = None
......@@ -487,11 +534,15 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer
ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
ctx.fc2_input_quantizer = fc2_input_quantizer
ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather
ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer
ctx.fc2_grad_input_quantizer = fc2_grad_input_quantizer
ctx.fc2_grad_weight_quantizer = fc2_grad_weight_quantizer
ctx.fc2_grad_output_quantizer = fc2_grad_output_quantizer
ctx.fc1_input_quantizer = fc1_input_quantizer
ctx.fc2_input_quantizer = fc2_input_quantizer
ctx.fc1_weight_requires_grad = fc1_weight.requires_grad
ctx.fc2_weight_requires_grad = fc2_weight.requires_grad
......@@ -502,6 +553,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation_dtype = activation_dtype
ctx.activation = activation
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
......@@ -523,6 +575,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
ctx.ub_overlap_ag = ub_overlap_ag
ctx.debug = debug
ctx.requires_dgrad = (
inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad
......@@ -537,12 +590,19 @@ class _LayerNormMLP(torch.autograd.Function):
if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
# Row Parallel Linear
if ub_overlap_rs:
fc2_out = rs_out
elif set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
elif set_parallel_mode and tensor_parallel:
if symmetric_ar_type is not None:
fc2_out, _ = symmetric_all_reduce(
fc2_out, tp_group, all_reduce_type=symmetric_ar_type
)
else:
fc2_out, _ = allreduce(fc2_out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
......@@ -643,15 +703,27 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
if ctx.grad_fc2_output_quantizer is not None:
# Reduce duplicated transpose, which is performed in grad_output.update_usage
if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=False)
else:
ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=True)
# Configure quantizer for FC2 grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.fc2_grad_output_quantizer is not None:
rowwise_usage = True
columnwise_usage = True
if ctx.ub_overlap_ag and isinstance(
ctx.fc2_grad_output_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer),
):
# If data is in FP8 and communication is handled
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage = False
ctx.fc2_grad_output_quantizer.set_usage(
rowwise=rowwise_usage,
columnwise=columnwise_usage,
)
# Prepare FC2 grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
ub_obj_fc2_dgrad = None
if ctx.ub_overlap_ag:
ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
......@@ -660,11 +732,10 @@ class _LayerNormMLP(torch.autograd.Function):
grad_output,
fc2_bias_grad,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer
ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer
)
# Prepare FC1 GEMM input
# Note: Perform tensor-parallel communication if needed
# Launch tensor-parallel communication for FC1 GEMM input
ln_out_total = None
ln_out_total_work = None
if (
......@@ -674,14 +745,20 @@ class _LayerNormMLP(torch.autograd.Function):
and not ctx.ub_bulk_dgrad
):
quantizer = None
if ctx.fp8:
if ctx.fp8 or ctx.debug:
quantizer = ctx.fc1_input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
quantizer.set_usage(rowwise=True, columnwise=False)
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
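# Note (editor's addition): blockwise-scaled recipes do not yet support FP8
# all-gather, so with force_hp_fc1_input_gather the gather below runs in high
# precision and the result is quantized only after the communication completes.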
gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
quantizer=gather_quantizer,
)
else:
ln_out_total = ln_out
......@@ -693,17 +770,26 @@ class _LayerNormMLP(torch.autograd.Function):
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# There are 5 possible fusion paths
# There are 6 possible fusion paths
# 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu,
# 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize
# 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize
# 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize
# 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm
# 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm
fc2_dgrad_gemm_gelu_fusion = (
not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion)
not ctx.fp8
and (ctx.activation == "gelu")
and (not ctx.bias_gelu_fusion)
and (not ctx.debug)
)
# FC2 DGRAD; Unconditional
if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor):
ctx.fc2_weight.update_usage(
rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage,
columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage,
)
gemm_output, *_ = general_gemm(
fc2_weight,
grad_output,
......@@ -711,7 +797,9 @@ class _LayerNormMLP(torch.autograd.Function):
layout="NN",
grad=True,
quantization_params=(
ctx.grad_fc1_output_quantizer if fc2_dgrad_gemm_gelu_fusion else None
ctx.fc1_grad_input_quantizer
if fc2_dgrad_gemm_gelu_fusion or ctx.debug
else None
), # high precision to activation
out_dtype=ctx.activation_dtype,
gelu=fc2_dgrad_gemm_gelu_fusion,
......@@ -734,39 +822,65 @@ class _LayerNormMLP(torch.autograd.Function):
if isinstance(grad_output, QuantizedTensor):
grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
fc2_wgrad, fc2_bias_grad_, *_ = general_gemm(
act_out,
grad_output,
get_workspace(),
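# Note (editor's addition): grad=True would ask the GEMM to also emit the bias
# gradient; blockwise-scaled FP8 GEMMs do not support that fusion, so bgrad is
# computed separately below.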
grad_arg = True
if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
grad_arg = False
general_gemm_fc2_wgrad = functools.partial(
general_gemm,
out_dtype=(
origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
quantization_params=None, # wgrad in high precision
workspace=get_workspace(),
quantization_params=ctx.fc2_grad_weight_quantizer, # wgrad in high precision
layout="NT",
grad=True,
bias=fc2_bias if fc2_bias_grad is None else None,
grad=grad_arg,
bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
accumulate=accumulate_wgrad_into_param_main_grad,
use_split_accumulator=_2X_ACC_WGRAD,
out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
fc2_wgrad = None
else:
fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
act_out,
grad_output,
)
if fc2_bias_grad is None:
if (
ctx.fp8
and ctx.fp8_recipe.float8_block_scaling()
and fc2_bias is not None
):
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_ = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
fc2_bias_grad = fc2_bias_grad_
del fc2_bias_grad_
if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute():
clear_tensor_data(act_out)
# bias computation
fc1_bias_grad = None
fuse_gemm_and_bias_fc1_wgrad = False
if ctx.grad_fc1_output_quantizer is not None:
ctx.grad_fc1_output_quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.fc1_grad_output_quantizer is not None:
ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.bias_gelu_fusion:
# Fusion: gemm, bias + gelu
assert ctx.activation == "gelu"
assert not ctx.fp8
fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias)
if ctx.grad_fc1_output_quantizer is not None:
dact = ctx.grad_fc1_output_quantizer(dact)
if ctx.fc1_grad_output_quantizer is not None:
dact = ctx.fc1_grad_output_quantizer(dact)
elif ctx.debug:
dact_func = _act_func(ctx.activation)[1]
dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None)
fc1_bias_grad = dact.sum(dim=0)
dact = ctx.fc1_grad_output_quantizer(dact)
elif (
_act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None
and ctx.fp8
......@@ -776,7 +890,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
)[2]
fc1_bias_grad, dact = dbias_dact_quantize_func(
fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer
fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer
) # quantize bgrad gelu fused
else:
# Fusion: gemm + gelu,
......@@ -789,7 +903,14 @@ class _LayerNormMLP(torch.autograd.Function):
) # activation in high precision
if ctx.fp8:
fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer)
# TODO float8 blockwise current scaling has no bgrad fusion for now
if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer):
fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
dact = ctx.fc1_grad_output_quantizer(dact)
else:
fc1_bias_grad, dact = tex.bgrad_quantize(
dact, ctx.fc1_grad_output_quantizer
)
else:
fuse_gemm_and_bias_fc1_wgrad = (
True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1
......@@ -836,12 +957,20 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None)
# FC1 DGRAD: Unconditional
if ctx.fc1_weight_quantizer is not None and isinstance(
ctx.fc1_weight, QuantizedTensor
):
ctx.fc1_weight.update_usage(
rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage,
columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage,
)
fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm(
fc1_weight,
dact,
get_workspace(),
out=fc1_dgrad_bulk,
out_dtype=ctx.activation_dtype,
quantization_params=ctx.fc1_grad_input_quantizer,
layout="NN",
grad=True,
ub=ub_obj_fc1_dgrad,
......@@ -869,6 +998,8 @@ class _LayerNormMLP(torch.autograd.Function):
# FC1 WGRAD
fc1_wgrad = None
if ctx.fc1_weight_requires_grad:
# Synchronize tensor-parallel communication for FC1 GEMM input tensor
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer)
if ctx.fp8:
......@@ -880,34 +1011,41 @@ class _LayerNormMLP(torch.autograd.Function):
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
ln_out_total._create_transpose()
else:
if ln_out_total_work is not None:
# Synchronize tensor-parallel communication
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.fc1_input_quantizer is not None and not isinstance(
ln_out_total, QuantizedTensor
):
# An async gather performed in BF16 does not apply the quantizer
# to the gathered tensor, so quantize it here.
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
# Make sure GEMM inputs have expected data
# Make sure GEMM inputs have required data
if isinstance(ln_out_total, QuantizedTensor):
ln_out_total.update_usage(rowwise_usage=True, columnwise_usage=True)
ln_out_total.update_usage(columnwise_usage=True)
if isinstance(dact, QuantizedTensor):
dact.update_usage(rowwise_usage=True, columnwise_usage=True)
dact.update_usage(columnwise_usage=True)
# Output buffer for overlapping grad input
# reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf():
fc1_dgrad_rs_out = torch.empty(
fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
)
fc1_wgrad_outputs = general_gemm(
ln_out_total,
dact,
get_workspace(),
# wgrad GEMM
general_gemm_fc1_wgrad = functools.partial(
general_gemm,
out_dtype=(
origin_fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
quantization_params=ctx.fc1_grad_weight_quantizer,
grad=fuse_gemm_and_bias_fc1_wgrad,
bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
accumulate=accumulate_wgrad_into_param_main_grad,
......@@ -917,6 +1055,16 @@ class _LayerNormMLP(torch.autograd.Function):
extra_output=fc1_dgrad_rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
fc1_wgrad = None
if fuse_gemm_and_bias_fc1_wgrad:
fc1_bias_grad = None
else:
fc1_wgrad_outputs = general_gemm_fc1_wgrad(
ln_out_total,
dact,
)
clear_tensor_data(ln_out_total, dact)
......@@ -931,7 +1079,7 @@ class _LayerNormMLP(torch.autograd.Function):
else:
fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True)
# Synchronize tensor parallel communication
# Make sure all tensor-parallel communication is finished
if ln_out_total_work is not None:
ln_out_total_work.wait()
ln_out_total_work = None
......@@ -1040,15 +1188,20 @@ class _LayerNormMLP(torch.autograd.Function):
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # fc1_input_quantizer
None, # fc1_weight_quantizer
None, # fc2_input_quantizer
None, # fc2_weight_quantizer
None, # output_quantizer
None, # grad_fc2_output_quantizer
None, # grad_fc1_output_quantizer
None, # grad_input_quantizer
None, # fc1_input_quantizer,
None, # fc1_weight_quantizer,
None, # fc1_output_quantizer,
None, # fc1_grad_input_quantizer,
None, # fc1_grad_weight_quantizer,
None, # fc1_grad_output_quantizer,
None, # fc2_input_quantizer,
None, # fc2_weight_quantizer,
None, # fc2_output_quantizer,
None, # fc2_grad_input_quantizer,
None, # fc2_grad_weight_quantizer,
None, # fc2_grad_output_quantizer,
None, # cpu_offloading
None, # tp_group
None, # tp_size
......@@ -1074,6 +1227,8 @@ class _LayerNormMLP(torch.autograd.Function):
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # debug
)
......@@ -1126,6 +1281,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
......@@ -1168,6 +1325,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency-bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
is used.
"""
def __init__(
......@@ -1195,10 +1361,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
ub_overlap_ag: bool = False,
name: Optional[str] = None,
ub_overlap_rs: bool = False,
ub_overlap_rs_dgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
) -> None:
super().__init__()
......@@ -1217,6 +1386,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
)
self.set_parallel_mode = set_parallel_mode
self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type
# GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
self.gemm_gelu_fusion = (
......@@ -1224,6 +1394,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
and self.activation == "gelu"
and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
)
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if tp_group is None:
self.tp_size = tp_size
......@@ -1252,6 +1428,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
ub_bulk_dgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad
)
if self.symmetric_ar_type is not None:
assert torch_version() >= (
2,
7,
0,
), "Torch version must be at least 2.7 to use symmetric memory"
# Initialize params in FP8
with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
......@@ -1384,7 +1567,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
@no_torch_dynamo()
def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a feedforward network (MLP Block).
......@@ -1407,6 +1592,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
......@@ -1415,18 +1603,41 @@ class LayerNormMLP(TransformerEngineBaseModule):
if skip_fp8_weight_update is not None:
is_first_microbatch = False
fp8_output = False
if self.ub_overlap_rs:
if get_ub("fc2_fprop").is_fp8_ubuf():
fp8_output = True
with self.prepare_forward(inp, num_gemms=2) as inp:
quantizers = (
self._get_quantizers(fp8_output)
if not debug
else self._get_debug_quantizers(fp8_output)
)
if debug:
if not any_feature_enabled(quantizers):
quantizers = self._get_quantizers(fp8_output)
debug = False
if isinstance(self.fc1_weight, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
grad_fc1_output_quantizer,
grad_fc2_output_quantizer,
grad_input_quantizer,
) = self._get_quantizers()
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers
# Get weight tensors
fc1_weight = self.fc1_weight
......@@ -1462,15 +1673,20 @@ class LayerNormMLP(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_fc1_output_quantizer,
grad_fc2_output_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
......@@ -1479,7 +1695,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
torch.is_grad_enabled(),
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
......@@ -1492,10 +1708,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(*args)
......@@ -1513,17 +1731,21 @@ class LayerNormMLP(TransformerEngineBaseModule):
return out, ln_out
return out
def _get_quantizers(self):
def _get_quantizers(self, fp8_output):
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
grad_fc1_output_quantizer,
grad_fc2_output_quantizer,
grad_input_quantizer,
) = [None] * 8
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = [None] * 12
if self.fp8:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = False # temporary
......@@ -1531,32 +1753,59 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc1_weight_quantizer.internal = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
fc2_input_quantizer.set_usage(
rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer)
rowwise=True,
columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
)
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
if fp8_output:
fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT
]
if torch.is_grad_enabled():
grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][
fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
]
grad_fc2_output_quantizer.internal = True
grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][
fc2_grad_output_quantizer.internal = True
fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
]
grad_fc1_output_quantizer.internal = True
grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2]
grad_input_quantizer.internal = True
fc1_grad_output_quantizer.internal = True
return (
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
grad_fc1_output_quantizer,
grad_fc2_output_quantizer,
grad_input_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output):
from ...debug.pytorch.debug_quantization import DebugQuantizer
base_quantizers = list(self._get_quantizers(fp8_output))
assert TEDebugState.debug_enabled
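# Note (editor's addition): base_quantizers is ordered as the six FC1
# quantizers followed by the six FC2 quantizers, so `offset` selects the
# slice belonging to each GEMM.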
def make_debug(prefix, offset):
labels = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return [
DebugQuantizer(
f"{self.name}.{prefix}",
label,
None if label in ("dgrad", "wgrad") else base_quantizers[i + offset],
self.tp_group,
)
for i, label in enumerate(labels)
]
return tuple(make_debug("fc1", 0) + make_debug("fc2", 6))
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_mlp."""
......@@ -1602,14 +1851,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
# grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer
# fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# grad_fc1_output_quantizer: also set numerical configs for grad_fc1_output_quantizer
# fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
......@@ -1617,10 +1866,48 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_INPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
if self.sequence_parallel and self.set_parallel_mode:
# grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"):
(fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
if self.use_bias and self.fc1_bias.grad is None:
(fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop()
else:
(fc1_wgrad, *_), _ = self.wgrad_store.pop()
fc1_bias_grad = None
if self.use_bias:
if self.fc2_bias.grad is None:
if (
self.fp8
and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling()
and self.apply_bias
and not self.gemm_bias_unfused_add
):
grad_output = tensor_list_fc2[1]
# BGRAD not fused with GEMM for float8 blockwise gemm;
# the bias gradient is the sum of grad_output over the token dimension.
fc2_bias_grad_ = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
self.fc2_bias.grad = fc2_bias_grad_.to(self.fc2_bias.dtype)
if self.fc1_bias.grad is None:
self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype)
if not self.fuse_wgrad_accumulation:
if self.fc2_weight.grad is None:
self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
if self.fc1_weight.grad is None:
self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
del fc2_bias_grad_
del fc2_wgrad
del fc1_wgrad
del fc1_bias_grad
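# --- Usage sketch (editor's addition, not part of this commit's diff) ---
# Shows how the delay_wgrad_compute / symmetric_ar_type options documented
# above are expected to be combined with backward_dw(). The layer sizes,
# input tensor, and the chosen all-reduce variant are illustrative assumptions.
import torch
from transformer_engine.pytorch import LayerNormMLP

mlp = LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    delay_wgrad_compute=True,      # wgrad GEMMs are queued instead of run in backward
    symmetric_ar_type="two_shot",  # requires PyTorch >= 2.7; None falls back to allreduce
)
out = mlp(torch.randn(8, 1024, device="cuda"))
out.sum().backward()               # computes dgrad; weight gradients are still pending
mlp.backward_dw()                  # user-triggered delayed weight-gradient computation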
......@@ -7,36 +7,41 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools
import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
get_workspace,
get_ub,
TransformerEngineBaseModule,
get_dummy_wgrad,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import noop_cat, _fix_gathered_fp8_transpose
from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..utils import (
cast_if_needed,
clear_tensor_data,
divide,
init_method_constant,
requires_grad,
needs_quantized_gemm,
non_tn_fp8_gemm_supported,
assert_dim_for_fp8_exec,
nvtx_range_pop,
nvtx_range_push,
requires_grad,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
allreduce,
symmetric_all_reduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
......@@ -56,10 +61,13 @@ from ..tensor.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
__all__ = ["Linear"]
......@@ -78,11 +86,13 @@ class _Linear(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
......@@ -103,6 +113,8 @@ class _Linear(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
debug: Optional[bool] = False,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
......@@ -128,6 +140,10 @@ class _Linear(torch.autograd.Function):
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)
own_quantized_input = False
# TODO(kwyss): Support FP8 allgather for FP8 block quantization.
force_hp_input_gather = (
fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
......@@ -137,14 +153,22 @@ class _Linear(torch.autograd.Function):
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if with_input_all_gather_nccl:
if force_hp_input_gather:
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat_total, _ = gather_along_first_dim(
inputmat, tp_group, quantizer=input_quantizer
)
else:
if not isinstance(inputmat, QuantizedTensor):
columnwise_usage = backward_needs_input and isinstance(
input_quantizer, MXFP8Quantizer
)
# force_hp_input_gather should enforce this
assert not isinstance(input_quantizer, Float8BlockQuantizer)
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
......@@ -181,9 +205,9 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# Cast weight to expected dtype
if not fp8:
weightmat = cast_if_needed(weight, activation_dtype)
else:
weightmat = weight
if fp8 or debug:
# Configure quantizer
if weight_quantizer is not None:
columnwise_usage = is_grad_enabled and inp.requires_grad
......@@ -193,7 +217,6 @@ class _Linear(torch.autograd.Function):
and not in_fp8_activation_recompute_phase()
)
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace(
......@@ -203,11 +226,14 @@ class _Linear(torch.autograd.Function):
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
else:
weightmat = cast_if_needed(weightmat, activation_dtype)
# Cast bias to expected dtype
bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32:
if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
......@@ -262,6 +288,7 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.gemm")
if is_grad_enabled:
ctx.weight_quantizer = weight_quantizer
saved_inputmat = None
ctx.backward_input_needs_gather = (
......@@ -275,6 +302,8 @@ class _Linear(torch.autograd.Function):
# can be allgathered.
if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if force_hp_input_gather:
assert not isinstance(inputmat, QuantizedTensor)
saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
......@@ -282,11 +311,8 @@ class _Linear(torch.autograd.Function):
if isinstance(weightmat, QuantizedTensor):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading:
set_offloading_param(weight, "weight_offloading", True)
set_offloading_param(weightmat, "weight_offloading", True)
if saved_inputmat is not None:
set_offloading_param(saved_inputmat, "activation_offloading", True)
if cpu_offloading and saved_inputmat is not None:
mark_activation_offload(saved_inputmat)
# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
......@@ -321,15 +347,18 @@ class _Linear(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.activation_dtype = activation_dtype
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_input_gather = force_hp_input_gather
ctx.input_quantizer = input_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad
ctx.debug = debug
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = bias is not None
......@@ -353,6 +382,7 @@ class _Linear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
# Row Parallel Linear
if ub_overlap_rs_fprop:
......@@ -362,6 +392,9 @@ class _Linear(torch.autograd.Function):
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
if symmetric_ar_type is not None:
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
else:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
......@@ -471,14 +504,27 @@ class _Linear(torch.autograd.Function):
ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
# Configure quantizer for grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
rowwise_usage = True
columnwise_usage = True
if ctx.ub_overlap_ag and isinstance(
ctx.grad_output_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer),
):
# If data is in FP8 and communication is handled
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage = False
ctx.grad_output_quantizer.set_usage(
rowwise=rowwise_usage,
columnwise=columnwise_usage,
)
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
if ctx.grad_output_quantizer is not None:
# Reduce duplicated transpose, which is performed in grad_output.update_usage
if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
else:
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
(
grad_output,
......@@ -491,21 +537,26 @@ class _Linear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
# Prepare input tensor
# Note: Perform tensor-parallel communication if needed
# Launch tensor-parallel communication for input tensor
inputmat_total = None
inputmat_total_work = None
if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
quantizer = None
if ctx.fp8:
if ctx.fp8 or ctx.debug:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
quantizer.set_usage(rowwise=True, columnwise=False)
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
gather_quantizer = None if ctx.force_hp_input_gather else quantizer
inputmat_total, inputmat_total_work = gather_along_first_dim(
inputmat,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
quantizer=gather_quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
......@@ -527,7 +578,6 @@ class _Linear(torch.autograd.Function):
# Update quantizer
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
# dgrad GEMM
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
......@@ -538,6 +588,12 @@ class _Linear(torch.autograd.Function):
recipe.fp8_gemm_dgrad.use_split_accumulator
)
if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor):
weight_fp8.update_usage(
rowwise_usage=ctx.weight_quantizer.rowwise_usage,
columnwise_usage=ctx.weight_quantizer.columnwise_usage,
)
dgrad, *_, rs_out = general_gemm(
weight_fp8,
grad_output,
......@@ -573,6 +629,8 @@ class _Linear(torch.autograd.Function):
# Compute grad weight tensor
wgrad = None
if ctx.requires_wgrad:
# Synchronize tensor-parallel communication for input tensor
if ctx.ub_bulk_dgrad:
inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
if ctx.fp8:
......@@ -586,18 +644,32 @@ class _Linear(torch.autograd.Function):
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
inputmat_total._create_transpose()
else:
if inputmat_total_work is not None:
# Synchronize tensor-parallel communication
inputmat_total_work.wait()
inputmat_total_work = None
if ctx.input_quantizer is not None and not isinstance(
inputmat_total, QuantizedTensor
):
# An async gather performed in BF16 does not apply the quantizer
# to the gathered tensor, so quantize it here.
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmat_total = ctx.input_quantizer(inputmat_total)
# Make sure GEMM inputs have required data
if isinstance(inputmat_total, QuantizedTensor):
inputmat_total.update_usage(columnwise_usage=True)
if isinstance(grad_output, QuantizedTensor):
# This is a no-op if platform supports non-TN FP8 GEMM or the transpose
# already exists.
grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
grad_output.update_usage(columnwise_usage=True)
# Figure out whether to use split accumulator
use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_wgrad"):
use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
# Output buffer for overlapping grad input
# reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
......@@ -606,39 +678,29 @@ class _Linear(torch.autograd.Function):
# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_wgrad"):
wgrad_gemm_use_split_accumulator = (
recipe.fp8_gemm_wgrad.use_split_accumulator
)
wgrad, grad_bias_, _, rs_out = general_gemm(
inputmat_total,
grad_output,
get_workspace(),
layout="NT",
grad=True,
general_gemm_wgrad = functools.partial(
general_gemm,
out_dtype=(
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
grad=True,
bias=(bias if (grad_bias is None and not ctx.fp8) else None),
out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=wgrad_gemm_use_split_accumulator,
use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad,
ub_type=ub_type_wgrad,
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
dgrad = rs_out
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad)
else:
dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output)
if grad_bias is None:
grad_bias = grad_bias_
......@@ -647,12 +709,19 @@ class _Linear(torch.autograd.Function):
# Deallocate input tensor
if ctx.owns_input:
clear_tensor_data(inputmat_total)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
dgrad = rs_out
else:
dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)
# Don't return grad bias if not needed
if not ctx.use_bias:
grad_bias = None
# Synchronize tensor parallel communication
# Make sure all tensor-parallel communication is finished
if inputmat_total_work is not None:
inputmat_total_work.wait()
inputmat_total_work = None
......@@ -669,18 +738,15 @@ class _Linear(torch.autograd.Function):
):
weight.grad_added_to_main_grad = True
if getattr(weight, "zero_out_wgrad", False):
wgrad = torch.zeros(
weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
wgrad = get_dummy_wgrad(
list(weight.main_grad.shape),
weight.dtype,
zero=True,
)
else:
wgrad = torch.empty(
weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
wgrad = get_dummy_wgrad(
list(weight.main_grad.shape),
weight.dtype,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
......@@ -702,11 +768,13 @@ class _Linear(torch.autograd.Function):
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
None, # grad_output_quantizer
None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # fuse_wgrad_accumulation
None, # cpu_offloading
None, # tp_group
......@@ -727,6 +795,8 @@ class _Linear(torch.autograd.Function):
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # debug
)
......@@ -762,6 +832,8 @@ class Linear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
......@@ -797,7 +869,15 @@ class Linear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency-bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
is used.
"""
def __init__(
......@@ -823,6 +903,9 @@ class Linear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False,
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
name: Optional[str] = None,
) -> None:
super().__init__()
......@@ -835,6 +918,13 @@ class Linear(TransformerEngineBaseModule):
self.apply_bias = bias and not return_bias
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self.symmetric_ar_type = symmetric_ar_type
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if device == "meta":
assert parameters_split is None, "Cannot split module parameters on 'meta' device."
......@@ -900,6 +990,13 @@ class Linear(TransformerEngineBaseModule):
assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized."
self.ub_name = ub_name
if self.symmetric_ar_type is not None:
assert torch_version() >= (
2,
7,
0,
), "Torch version must be at least 2.7 to use symmetric memory"
# Initialize params in FP8
with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
......@@ -1078,6 +1175,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
......@@ -1085,6 +1186,13 @@ class Linear(TransformerEngineBaseModule):
if skip_fp8_weight_update is not None:
is_first_microbatch = False
if self.ub_overlap_rs_fprop:
if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
fp8_grad = True
with self.prepare_forward(
inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor),
......@@ -1106,13 +1214,28 @@ class Linear(TransformerEngineBaseModule):
else:
bias_tensor = None
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if not any_feature_enabled(quantizers):
# If no feature is used, then run faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
debug = False
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
) = self._get_quantizers(fp8_output, fp8_grad)
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
......@@ -1133,11 +1256,13 @@ class Linear(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
......@@ -1158,6 +1283,8 @@ class Linear(TransformerEngineBaseModule):
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = linear_fn(*args)
if self.gemm_bias_unfused_add:
......@@ -1169,8 +1296,9 @@ class Linear(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output, fp8_grad):
if not self.fp8:
return [None] * 5
return [None] * 6
grad_input_quantizer = None
grad_weight_quantizer = None
grad_output_quantizer = None
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
......@@ -1188,8 +1316,20 @@ class Linear(TransformerEngineBaseModule):
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output, fp8_grad):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer
names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
DebugQuantizer(self.name, name, q, self.tp_group)
for name, q in zip(names, original_quantizers)
)
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
......
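# --- Usage sketch (editor's addition, not part of this commit's diff) ---
# Minimal illustration of the Linear options added in this change. The feature
# sizes, the all-reduce variant, and the module name are assumptions.
import torch
from transformer_engine.pytorch import Linear

proj = Linear(
    in_features=1024,
    out_features=1024,
    delay_wgrad_compute=True,                 # queue the wgrad GEMM for later
    symmetric_ar_type="multimem_all_reduce",  # requires PyTorch >= 2.7; used only with tensor parallelism
    name="proj",                              # identifies the module in debug mode
)
y = proj(torch.randn(8, 1024, device="cuda"))
y.sum().backward()
proj.backward_dw()                            # run the delayed weight-gradient GEMM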
......@@ -13,6 +13,7 @@ import torch
import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer
from ...utils import clear_tensor_data, devices_match
from ..op import BasicOperation, OperationContext
from .._common import reshape
......@@ -37,8 +38,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
the first half of the input tensor, while PyTorch applies it to
the second half.
Parameters
----------
cache_quantized_input: bool, default = False
Quantize input tensor when caching for use in the backward
pass. This will typically reduce memory usage but require
extra compute and increase numerical error. This feature is
highly experimental.
"""
def __init__(self, *, cache_quantized_input: bool = False):
super().__init__()
self.cache_quantized_input: bool = cache_quantized_input
@abc.abstractmethod
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
"""Forward implementation
......@@ -100,9 +113,16 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
if y.dim() != x.dim():
y = y.reshape(list(x.shape[:-1]) + [-1])
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
quantizer.set_usage(rowwise=True, columnwise=False)
x = quantizer(x)
# Save state for backward pass
ctx.save_for_backward(x.detach())
ctx.fp8_enabled = fp8_enabled
ctx.dtype = dtype
ctx.prev_op = prev_op
return y
......@@ -116,10 +136,18 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Saved tensors from forward pass
(x,) = ctx.saved_tensors
# Check input tensor
if isinstance(x, QuantizedTensor):
x = x.dequantize(dtype=ctx.dtype)
elif x.dtype != ctx.dtype:
x = x.to(dtype=ctx.dtype)
if not x.is_contiguous():
x = x.contiguous()
# Check grad output tensor
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dy = dy.dequantize(dtype=ctx.dtype)
if not devices_match(dy.device, x.device) or dy.dtype != x.dtype:
dy = dy.to(device=x.device, dtype=x.dtype)
if not dy.is_contiguous():
......
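# --- Usage sketch (editor's addition, not part of this commit's diff) ---
# Illustrates the cache_quantized_input flag documented above: the activation
# input is cached in FP8 (current scaling) for the backward pass, trading some
# accuracy and extra quantization work for lower activation memory. Using
# te.ops.GELU here is an assumption about the fusible-operations API.
import torch
import transformer_engine.pytorch as te

gelu = te.ops.GELU(cache_quantized_input=True)
x = torch.randn(8, 4096, device="cuda", requires_grad=True)
y = gelu(x)
y.sum().backward()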
......@@ -23,6 +23,7 @@ from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
......@@ -412,7 +413,6 @@ class BasicLinear(BasicOperation):
x = None
x_async = None
with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel
own_quantized_x_local = False
if with_quantized_compute:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
......@@ -428,7 +428,6 @@ class BasicLinear(BasicOperation):
else:
if not isinstance(x_local, QuantizedTensor):
x_local = input_quantizer(x_local)
own_quantized_x_local = True
x = x_local
else:
if isinstance(x_local, QuantizedTensor):
......@@ -483,6 +482,12 @@ class BasicLinear(BasicOperation):
"Attempting to generate MXFP8 output tensor, "
"but GEMM with MXFP8 output is not supported"
)
if isinstance(output_quantizer, Float8BlockQuantizer):
raise RuntimeError(
"Attempting to generate Float8BlockQuantized output tensor, "
"but GEMM with Float8BlockQuantized output is not supported"
)
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -521,16 +526,16 @@ class BasicLinear(BasicOperation):
else:
torch.distributed.all_reduce(y, group=tensor_parallel_group)
# Configure input tensor for backward pass
if own_quantized_x_local:
x_local.update_usage(rowwise_usage=False)
# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
if x_local is input:
x_local = x_local.detach()
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
return y, x_local, w
@staticmethod
......@@ -679,7 +684,9 @@ class BasicLinear(BasicOperation):
quantizer=input_quantizer,
)
else:
if not isinstance(x_local, QuantizedTensor):
if isinstance(x_local, QuantizedTensor):
x_local.update_usage(columnwise_usage=True)
else:
x_local = input_quantizer(x_local)
x = x_local
else:
......@@ -706,14 +713,18 @@ class BasicLinear(BasicOperation):
raise ValueError("Weight tensor is required to compute input grad")
w = weight
w_is_quantized = isinstance(w, QuantizedTensor)
if with_quantized_compute and not w_is_quantized:
if with_quantized_compute:
if w_is_quantized:
w.update_usage(columnwise_usage=True)
else:
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
if not with_quantized_compute and w.dtype != dtype:
else:
if w_is_quantized:
w = w.dequantize(dtype=dtype)
elif w.dtype != dtype:
w = w.to(dtype=dtype)
# Synchronize tensor-parallel communication
......@@ -867,8 +878,8 @@ class BasicLinear(BasicOperation):
# Configure quantizers
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
input_quantizer.set_usage(columnwise=weight_requires_grad)
weight_quantizer.set_usage(columnwise=False)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
# Get autocast dtype if needed
dtype = None
......