Unverified commit 8a20d666 authored by Tim Moon, committed by GitHub

Support tensors with only column-wise data (#1505)

* Delete row-wise data in single-GPU linear forward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Python->C++ parsing of transpose-only Float8Tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug tensor shape calculation without row-wise data
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug correctness issues with only column-wise data
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Only cache column-wise input in LayerNormLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 all-gather with only column-wise data
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix moe cases, lint, rm unused ctx
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CPU activation offloading and use consistent logic for save/restore
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* RM stray file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix distributed and cpp tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix norm cpp tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Rm stray file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* RM stray file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix MXFP8 AG
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix FP8 with sequence parallelism
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix UB bulk dgrad
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0e137883
......@@ -92,7 +92,10 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training)
input.to_cpu();
auto scaling_mode = input.scaling_mode();
assert(input.rowwise_shape().ndim == 2);
-assert(input.columnwise_shape().ndim == 2);
+if (is_training) {
+  assert(input.columnwise_shape().ndim == 2);
+}
dequantize_1x_kernel(input.rowwise_cpu_dptr<InputType>(),
input.rowwise_cpu_scale_inv_ptr<ScaleType>(),
......
......@@ -83,9 +83,11 @@ size_t product(const NVTEShape &shape, size_t begin, size_t end) {
}
return ret;
}
size_t product(const NVTEShape &shape) {
return product(shape, 0, shape.ndim);
}
size_t product(const std::vector<size_t> shape, size_t begin, size_t end) {
size_t ret = 1;
NVTE_CHECK(end <= shape.size());
......@@ -193,6 +195,7 @@ Tensor::Tensor(const std::string& name,
std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
shape.data[shape.ndim - 1]};
NVTEShape normalized_shape = convertShape(normalized_shape_v);
NVTEShape columnwise_shape{nullptr, 0};
std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
......@@ -207,7 +210,11 @@ Tensor::Tensor(const std::string& name,
columnwise_shape_vec.emplace_back(shape.data[i]);
}
}
-  const NVTEShape columnwise_shape{columnwise_shape_vec.data(), columnwise_shape_vec.size()};
+  if (columnwise) {
+    columnwise_shape.data = columnwise_shape_vec.data();
+    columnwise_shape.ndim = columnwise_shape_vec.size();
+  }
tensor_ = TensorWrapper(scaling_mode);
......
......@@ -29,6 +29,9 @@
namespace transformer_engine {
std::string to_string(const DType type);
std::string to_string(const NVTEScalingMode &mode);
inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
......@@ -108,17 +111,8 @@ struct Tensor {
scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
int numel() const {
-    NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr,
-               "Tensor does not hold any data!");
     size_t acc = 1;
-    if (data.dptr != nullptr) {
-      for (const auto &dim : data.shape) {
-        acc *= dim;
-      }
-      return acc;
-    }
-    // data is empty, use columnwise_data
-    for (const auto &dim : columnwise_data.shape) {
+    for (const auto dim : shape()) {
       acc *= dim;
     }
     return acc;
......@@ -126,7 +120,10 @@ struct Tensor {
bool has_data() const noexcept { return data.dptr != nullptr; }
-  bool has_columnwise_data() const noexcept { return columnwise_data.dptr != nullptr; }
+  // Check for size (not just pointer) for 0-dim or no token cases.
+  bool has_columnwise_data() const noexcept {
+    return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0;
+  }
DType dtype() const {
if (has_data()) return data.dtype;
......@@ -135,24 +132,54 @@ struct Tensor {
return data.dtype;
}
std::vector<size_t> shape() const {
/* Note: We sometimes experience spurious compiler errors
* (-Wstringop-overflow) from this function. It appears that GCC
* has some bugs with std::vector (see
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
*/
switch (scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING:
if (!has_data() && has_columnwise_data()) {
std::vector<size_t> ret;
if (!columnwise_data.shape.empty()) {
for (size_t i = 1; i < columnwise_data.shape.size(); i++) {
ret.push_back(columnwise_data.shape[i]);
}
ret.push_back(columnwise_data.shape.front());
}
return ret;
} else {
return data.shape;
}
break;
case NVTE_MXFP8_1D_SCALING:
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape;
} else {
return data.shape;
}
break;
default:
NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
return {};
}
}
/*! Matrix height after tensor is flattened to 2D
*
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
* as a (D1*D2*...*D(n-1), Dn) matrix.
*/
size_t flat_first_dim() const {
-    if (!has_data() && has_columnwise_data()) {
-      const auto &data_shape = columnwise_data.shape;
-      if (data_shape.empty()) return 1;
-      if (is_tensor_scaling(scaling_mode)) {
-        return product(data_shape, 1, data_shape.size());
-      } else {
-        return product(data_shape, 0, data_shape.size() - 1);
+    const auto &full_shape = shape();
+    size_t ret = 1;
+    if (!full_shape.empty()) {
+      for (size_t i = 0; i < full_shape.size() - 1; i++) {
+        ret *= full_shape[i];
       }
     }
-    const auto &data_shape = data.shape;
-    if (data_shape.empty()) return 1;
-    return product(data_shape, 0, data_shape.size() - 1);
+    return ret;
}
/*! Matrix width after tensor is flattened to 2D
......@@ -161,18 +188,12 @@ struct Tensor {
* as a (D1*D2*...*D(n-1), Dn) matrix.
*/
size_t flat_last_dim() const {
-    if (!has_data() && has_columnwise_data()) {
-      const auto &data_shape = columnwise_data.shape;
-      if (data_shape.empty()) return 1;
-      if (is_tensor_scaling(scaling_mode)) {
-        return data_shape.front();
-      } else {
-        return data_shape.back();
-      }
+    const auto &full_shape = shape();
+    if (full_shape.empty()) {
+      return 1;
+    } else {
+      return full_shape.back();
     }
-    const auto &data_shape = data.shape;
-    if (data_shape.empty()) return 1;
-    return data_shape.back();
}
};
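As an aside, the 2D-flattening convention documented above is easy to state in a few lines of Python. This is a hedged sketch with hypothetical helper names, not part of the library API:

```python
import math

def flat_first_dim(shape):
    # Height when (D1, ..., Dn) is viewed as a (D1*...*D(n-1), Dn) matrix.
    return math.prod(shape[:-1]) if shape else 1

def flat_last_dim(shape):
    # Width of the same flattened matrix.
    return shape[-1] if shape else 1

assert flat_first_dim((8, 4, 16)) == 32  # 8 * 4
assert flat_last_dim((8, 4, 16)) == 16
```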
......@@ -477,9 +498,6 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
-bool is_fp8_dtype(const DType t);
-std::string to_string(const DType type);
-std::string to_string(const NVTEScalingMode &type);
/*! \brief Update a tensor's FP8 scale-inverse
*
* The FP8 scale-inverse (dequantization scaling factor) is updated
......
......@@ -555,7 +555,7 @@ class TensorWrapper {
* \return Number of elements in the tensor.
*/
size_t numel() const noexcept {
-    if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
+    if (tensor_ == nullptr) return 0;
return nvte_tensor_numel(tensor_);
}
......
......@@ -212,37 +212,58 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
}
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
-  if (tensor == nullptr) return {nullptr, 0};
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  NVTEShape ret;
-  // FP8 tensor keeps shape in rowwise data
-  if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
-    ret.data = t.data.shape.data();
-    ret.ndim = t.data.shape.size();
-    return ret;
-  }
+  if (tensor == nullptr) {
+    NVTE_ERROR("Invalid tensor");
+  }
+  NVTEShape ret;
-  // Get shape based on what data is available
-  if (t.has_data()) {
-    ret.data = t.data.shape.data();
-    ret.ndim = t.data.shape.size();
-    return ret;
-  }
-  if (t.has_columnwise_data()) {
-    ret.data = t.columnwise_data.shape.data();
-    ret.ndim = t.columnwise_data.shape.size();
-    return ret;
-  }
+  // Determine tensor shape depending on tensor format
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  switch (t.scaling_mode) {
+    case NVTE_DELAYED_TENSOR_SCALING: {
+      if (!t.has_data() && t.has_columnwise_data()) {
+        // We can infer tensor shape if FP8 tensor only has FP8 data
+        // transpose. However, NVTEShape only contains a pointer and
+        // cannot store temporary data. We hack around this by caching
+        // the tensor shape within the empty FP8 data.
+        auto &shape_cache = const_cast<std::vector<size_t> &>(t.data.shape);
+        shape_cache.clear();
+        if (!t.columnwise_data.shape.empty()) {
+          for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) {
+            shape_cache.push_back(t.columnwise_data.shape[i]);
+          }
+          shape_cache.push_back(t.columnwise_data.shape.front());
+        }
+        ret.data = shape_cache.data();
+        ret.ndim = shape_cache.size();
+      } else {
+        ret.data = t.data.shape.data();
+        ret.ndim = t.data.shape.size();
+      }
+      break;
+    }
+    case NVTE_MXFP8_1D_SCALING: {
+      if (!t.has_data() && t.has_columnwise_data()) {
+        ret.data = t.columnwise_data.shape.data();
+        ret.ndim = t.columnwise_data.shape.size();
+      } else {
+        ret.data = t.data.shape.data();
+        ret.ndim = t.data.shape.size();
+      }
+      break;
+    }
+    default:
+      NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
+                 transformer_engine::to_string(t.scaling_mode), "\"");
+  }
-  // Tensor has no data
-  ret.data = t.data.shape.data();
-  ret.ndim = t.data.shape.size();
   return ret;
}
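The shape recovery for transpose-only FP8 tensors boils down to rotating the column-wise (transposed) shape by one position. A minimal Python sketch of the same loop (hypothetical function name, assuming the column-wise shape stores the last logical dimension first):

```python
def logical_shape_from_columnwise(columnwise_shape):
    # Rotate (Dn, D1, ..., D(n-1)) back to the logical (D1, ..., D(n-1), Dn).
    if not columnwise_shape:
        return []
    return list(columnwise_shape[1:]) + [columnwise_shape[0]]

# A (16, 32, 64) tensor stored only as its transpose has
# column-wise shape (64, 16, 32):
assert logical_shape_from_columnwise((64, 16, 32)) == [16, 32, 64]
```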
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
-  if (tensor == nullptr) return {nullptr, 0};
+  if (tensor == nullptr) {
+    NVTE_ERROR("Invalid tensor");
+  }
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret;
ret.data = t.columnwise_data.shape.data();
......@@ -250,25 +271,20 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
return ret;
}
-size_t nvte_tensor_ndim(const NVTETensor tensor) {
-  if (tensor == nullptr) return 0;
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  return t.data.shape.size();
-}
+size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
-  if (tensor == nullptr) return 0;
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim);
-  return t.data.shape[dim];
+  const auto &shape = nvte_tensor_shape(tensor);
+  NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim,
+             " in a shape array with ", shape.ndim, " entries");
+  return shape.data[dim];
}
size_t nvte_tensor_numel(const NVTETensor tensor) {
-  if (tensor == nullptr) return 0;
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  const auto &shape = nvte_tensor_shape(tensor);
   size_t numel = 1;
-  for (auto size : t.data.shape) {
-    numel *= size;
+  for (size_t i = 0; i < shape.ndim; i++) {
+    numel *= shape.data[i];
   }
return numel;
}
......
......@@ -419,11 +419,25 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
tensor_list = [state]
for tensor_on_device in tensor_list:
+            # `tensor_offloaded` is a hacky way of dealing with columnwise-only
+            # quantized tensors for CPU offloading. The complication is due to
+            # the `rowwise_data` being `None`. The offloading checker incorrectly
+            # returns `False`, and the entire `state` ([None, columnwise_tensor])
+            # is added to the tensor tag state dict. A better design would change
+            # how quantized tensors are tracked in the offload handler.
+            # Currently, every stage ensures that a quantized tensor is a list,
+            # whereas a non-quantized tensor is a standalone object, which is
+            # not good! TODO(@sanandaraj5597)
+            tensor_offloaded = False
             # if offload, return the reference to cpu copy
             if self.tensor_need_offloading_checker(tensor_on_device):
+                tensor_offloaded = True
                 state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
             if is_quantized_tensor:
-                self.tensor_tag_to_state[tensor_tag].append(state)
+                if tensor_offloaded:
+                    self.tensor_tag_to_state[tensor_tag].append(state)
+                else:
+                    self.tensor_tag_to_state[tensor_tag].append(tensor_on_device)
else:
self.tensor_tag_to_state[tensor_tag] = state
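To make the failure mode concrete, here is a toy Python sketch (all names hypothetical) of why a per-entry `tensor_offloaded` flag is needed when the row-wise slot is `None`:

```python
# A columnwise-only quantized tensor is tracked as a list whose
# row-wise slot is None:
state = [None, "columnwise_data"]

def needs_offloading(entry):
    # Stand-in for tensor_need_offloading_checker: None is never offloaded.
    return entry is not None

# Recording the per-entry decision lets the handler append the CPU copy
# only for the entries that were actually offloaded:
tracked = []
for entry in state:
    tensor_offloaded = needs_offloading(entry)
    tracked.append("cpu_copy" if tensor_offloaded else entry)
assert tracked == [None, "cpu_copy"]
```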
......
......@@ -4,6 +4,10 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <transformer_engine/transformer_engine.h>
#include "common.h"
#include "pybind.h"
......@@ -11,67 +15,72 @@ namespace transformer_engine::pytorch {
namespace detail {
TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) {
const at::Tensor &data = tensor.attr("_data").cast<at::Tensor>();
const at::Tensor &scale_inv = tensor.attr("_scale_inv").cast<at::Tensor>();
float *scale_inv_dptr = reinterpret_cast<float *>(scale_inv.data_ptr());
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
auto ret = TensorWrapper(quantizer->get_scaling_mode());
bool data_exists = !tensor.attr("_data").is_none();
bool transpose_exists =
!tensor.attr("_transpose_invalid").cast<bool>() && !tensor.attr("_transpose").is_none();
const auto &shape = getTensorShape(data);
NVTE_CHECK(data_exists || transpose_exists, "No data found for FP8 Tensor.");
bool transpose_valid = !tensor.attr("_transpose_invalid").cast<bool>();
std::optional<at::Tensor> transpose = std::nullopt;
if (transpose_valid) {
transpose = tensor.attr("_transpose").cast<std::optional<at::Tensor>>();
// FP8 data
const DType fp8_dtype = tensor.attr("_fp8_dtype").cast<DType>();
if (data_exists) {
const auto &data = tensor.attr("_data").cast<at::Tensor>();
ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
}
// In the case of being called under tex.dequantize, the quantizer will be NoneQuantizer
// whose scaling mode is defaulted to NVTE_DELAYED_TENSOR_SCALING
auto ret = TensorWrapper(quantizer->get_scaling_mode());
ret.set_rowwise_data(data.data_ptr(), dtype, shape);
if (transpose_valid && transpose != std::nullopt) {
const auto &transpose_shape = getTensorShape(*transpose);
ret.set_columnwise_data(transpose->data_ptr(), dtype, transpose_shape);
// FP8 data transpose
if (transpose_exists) {
const auto &data_transpose = tensor.attr("_transpose").cast<at::Tensor>();
ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, getTensorShape(data_transpose));
}
const auto scale_inv_dtype = GetTransformerEngineDType(scale_inv.scalar_type());
const auto scale_inv_shape = getTensorShape(scale_inv);
ret.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
ret.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
// Scale-inverse
{
const auto &scale_inv = tensor.attr("_scale_inv").cast<at::Tensor>();
float *dptr = reinterpret_cast<float *>(scale_inv.data_ptr());
const auto &dtype = GetTransformerEngineDType(scale_inv.scalar_type());
const auto &shape = getTensorShape(scale_inv);
ret.set_rowwise_scale_inv(dptr, dtype, shape);
ret.set_columnwise_scale_inv(dptr, dtype, shape);
}
// Quantizer state
quantizer->set_quantization_params(&ret);
return ret;
}
TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) {
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING);
bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for MXFP8 Tensor.");
// Row-scaled data
const DType fp8_dtype = tensor.attr("_fp8_dtype").cast<DType>();
if (rowwise_usage) {
const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr();
const auto &shape = getTensorShape(data_rowwise);
ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape);
const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat8E8M0, scale_inv_rowwise_shape);
const auto &data = tensor.attr("_rowwise_data").cast<at::Tensor>();
const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv));
}
// Column-scaled data
if (columnwise_usage) {
const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr();
const auto &shape = getTensorShape(data_colwise);
ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape);
const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat8E8M0,
scale_inv_colwise_shape);
const auto &data = tensor.attr("_columnwise_data").cast<at::Tensor>();
const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0,
getTensorShape(scale_inv));
}
// Quantizer state
quantizer->set_quantization_params(&ret);
return ret;
}
......
......@@ -22,7 +22,7 @@ from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
-from .tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor
+from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
......@@ -819,30 +819,30 @@ class CudaRNGStatesTracker:
def reduce_scatter_along_first_dim(
-    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
+    inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_distributed_world_size(tp_group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
-        return input_, None
+        return inp, None
-    dim_size = list(input_.size())
+    dim_size = list(inp.size())
assert (
dim_size[0] % world_size == 0
), "First dimension of the tensor should be divisible by tensor parallel size"
dim_size[0] = dim_size[0] // world_size
-    output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+    output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device())
handle = torch.distributed.reduce_scatter_tensor(
-        output, input_.contiguous(), group=tp_group, async_op=async_op
+        output, inp.contiguous(), group=tp_group, async_op=async_op
)
return output, handle
def _all_gather_fp8(
-    input_: torch.Tensor,
+    inp: torch.Tensor,
process_group: dist_group_type,
*,
async_op: bool = False,
......@@ -854,18 +854,18 @@ def _all_gather_fp8(
# Output tensor dims
if out_shape is None:
-        out_shape = list(input_.size())
+        out_shape = list(inp.size())
out_shape[0] *= world_size
# Quantize input tensor if needed
-    if not isinstance(input_, Float8TensorBase):
+    if not isinstance(inp, Float8TensorBase):
assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
        # We cannot directly gather the transposed FP8 tensor, so we
        # disable column-wise usage for the quantizer and restore the
        # original value after quantizing.
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(columnwise=False)
-        input_ = quantizer(input_)
+        inp = quantizer(inp)
quantizer.set_usage(columnwise=init_columnwise_usage)
# Construct output tensor
......@@ -873,30 +873,30 @@ def _all_gather_fp8(
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
dtype = torch.float32
device = "cuda"
-        if isinstance(input_, Float8Tensor):
-            dtype = input_.dtype
-            device = input_.device
+        if isinstance(inp, Float8Tensor):
+            dtype = inp.dtype
+            device = inp.device
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
-    elif isinstance(input_, Float8Tensor):
-        out = input_.make_like(input_, shape=out_shape)
+    elif isinstance(inp, Float8Tensor):
+        out = inp.make_like(inp, shape=out_shape)
out._data = torch.empty_like(
out_shape,
dtype=torch.uint8,
-            device=input_.device,
+            device=inp.device,
)
out._transpose = None
out._transpose_invalid = True
else:
raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")
-    # For delayed scaling, scale_inv is from history, so we can pass it from input_ to out
+    # For delayed scaling, scale_inv is from history, so we can pass it from inp to out
     # For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same scale_inv,
-    # so we can just pass it from input_ to out
-    out._scale_inv = input_._scale_inv
+    # so we can just pass it from inp to out
+    out._scale_inv = inp._scale_inv
# Perform communication
handle = torch.distributed.all_gather_into_tensor(
out._data,
-        input_._data.contiguous(),
+        inp._data.contiguous(),
group=process_group,
async_op=async_op,
)
......@@ -914,7 +914,7 @@ def _all_gather_fp8(
def _all_gather_mxfp8(
-    input_: torch.Tensor,
+    inp: torch.Tensor,
process_group: dist_group_type,
*,
async_op: bool = False,
......@@ -925,27 +925,56 @@ def _all_gather_mxfp8(
# Tensor dims
world_size = get_distributed_world_size(process_group)
-    in_shape = list(input_.size())
+    in_shape = list(inp.size())
     if out_shape is None:
         out_shape = [in_shape[0] * world_size] + in_shape[1:]
-    # Gather MXFP8 data for row-wise usage
-    if quantizer.rowwise_usage and not quantizer.columnwise_usage:
+    # For cases where inp has dimensions that cannot be quantized,
+    # we gather in high precision followed by a cast to FP8.
+    if (
+        not isinstance(inp, MXFP8TensorBase)
+        and quantizer is not None
+        and not quantizer.is_quantizable(inp)
+    ):
+        out = torch.empty(
+            out_shape,
+            dtype=inp.dtype,
+            device=inp.device,
+            memory_format=torch.contiguous_format,
+        )
+        torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
+        out = quantizer(out)
+        return out, None
+    inp_dtype = inp.dtype
+    inp_device = inp.device
+    # Cast input tensor to MXFP8 with required data
+    if not isinstance(inp, MXFP8TensorBase):
+        inp = quantizer(inp)
+    elif (
+        inp.rowwise_data is None
+        and quantizer.rowwise_usage
+        or inp.columnwise_data is None
+        and quantizer.columnwise_usage
+    ):
+        warnings.warn(
+            "Input and quantizer do not have matching usages. "
+            "Dequantizing and requantizing to MXFP8."
+        )
+        inp = quantizer(inp.dequantize())
-        # Cast input tensor to MXFP8 if needed
-        if not isinstance(input_, MXFP8TensorBase):
-            input_ = quantizer(input_)
+    # Construct MXFP8 output tensor
+    out = quantizer.make_empty(out_shape, dtype=inp_dtype, device=inp_device)
-        # Construct MXFP8 output tensor
-        dtype = torch.float32
-        device = "cuda"
-        if isinstance(input_, MXFP8Tensor):
-            dtype = input_.dtype
-            device = input_.device
-        out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+    # Async op handle
+    handle = None
+    # Gather MXFP8 data for row-wise usage
+    if quantizer.rowwise_usage:
         # Remove padding from MXFP8 scale-inverses
-        in_scale_inv = input_._rowwise_scale_inv
+        in_scale_inv = inp._rowwise_scale_inv
         out_scale_inv = out._rowwise_scale_inv
         flattened_in_shape0 = math.prod(in_shape[:-1])
         if in_scale_inv.size(0) != flattened_in_shape0:
......@@ -954,40 +983,52 @@ def _all_gather_mxfp8(
             out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
         # Launch all-gathers
-        with torch.distributed._coalescing_manager(
-            group=process_group,
-            device=device,
-            async_ops=async_op,
-        ) as coalescing_manager:
-            torch.distributed.all_gather_into_tensor(
-                out._rowwise_data,
-                input_._rowwise_data,
-                group=process_group,
-            )
-            torch.distributed.all_gather_into_tensor(
-                out_scale_inv,
-                in_scale_inv,
-                group=process_group,
-            )
-        handle = coalescing_manager if async_op else None
-        return out, handle
+        if handle is not None:
+            handle.wait()
+        torch.distributed.all_gather_into_tensor(
+            out_scale_inv,
+            in_scale_inv,
+            group=process_group,
+        )
+        handle = torch.distributed.all_gather_into_tensor(
+            out._rowwise_data,
+            inp._rowwise_data,
+            group=process_group,
+            async_op=async_op,
+        )
-    # Gather in high precision and quantize for column-wise usage
-    if isinstance(input_, QuantizedTensor):
-        input_ = input_.dequantize(dtype=torch.bfloat16)
-    out = torch.empty(
-        out_shape,
-        dtype=input_.dtype,
-        device=input_.device,
-        memory_format=torch.contiguous_format,
-    )
-    torch.distributed.all_gather_into_tensor(out, input_, group=process_group)
-    out = quantizer(out)
-    return out, None
+    # Gather MXFP8 data for column-wise usage
+    if quantizer.columnwise_usage:
+        # Remove padding from MXFP8 scale-inverses
+        in_scale_inv = inp._columnwise_scale_inv
+        out_scale_inv = out._columnwise_scale_inv
+        flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
+        if in_scale_inv.size(0) != flattened_in_shape0:
+            in_scale_inv = in_scale_inv[:flattened_in_shape0]
+            out_scale_inv[flattened_in_shape0 * world_size :].zero_()
+            out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
+        # Launch all-gathers
+        if handle is not None:
+            handle.wait()
+        torch.distributed.all_gather_into_tensor(
+            out_scale_inv,
+            in_scale_inv,
+            group=process_group,
+        )
+        handle = torch.distributed.all_gather_into_tensor(
+            out._columnwise_data,
+            inp._columnwise_data,
+            group=process_group,
+            async_op=async_op,
+        )
+    return out, handle
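A hedged usage sketch of the new column-wise gather path. It assumes an initialized process group; the `MXFP8Quantizer` constructor arguments (`fp8_dtype`, `rowwise`, `columnwise`) follow the TE source at the time of this PR but should be treated as assumptions:

```python
import torch
import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.pytorch.distributed import gather_along_first_dim
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

# Assumes dist.init_process_group(...) has already run on every rank.
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, rowwise=False, columnwise=True)
x = torch.randn(128, 256, device="cuda")  # both dims divisible by 32, so quantizable
out, handle = gather_along_first_dim(x, dist.group.WORLD, quantizer=quantizer)
# `out` holds only column-scaled MXFP8 data; `handle` is None because
# async_op defaults to False.
```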
def gather_along_first_dim(
-    input_: torch.Tensor,
+    inp: torch.Tensor,
process_group: dist_group_type,
async_op: bool = False,
quantizer: Optional[Quantizer] = None,
......@@ -997,20 +1038,20 @@ def gather_along_first_dim(
# Return immediately if no communication is required
world_size = get_distributed_world_size(process_group)
if world_size == 1:
-        if quantizer is not None and not isinstance(input_, QuantizedTensor):
-            input_ = quantizer(input_)
-        return input_, None
+        if quantizer is not None and not isinstance(inp, QuantizedTensor):
+            inp = quantizer(inp)
+        return inp, None
# Output tensor dims
-    out_shape = list(input_.size())
+    out_shape = list(inp.size())
out_shape[0] *= world_size
# FP8 case: delayed scaling or current scaling
-    if isinstance(input_, Float8TensorBase) or isinstance(
+    if isinstance(inp, Float8TensorBase) or isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
return _all_gather_fp8(
-            input_,
+            inp,
process_group,
async_op=async_op,
quantizer=quantizer,
......@@ -1018,10 +1059,10 @@ def gather_along_first_dim(
)
# MXFP8 case
-    if isinstance(input_, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
+    if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
assert isinstance(quantizer, MXFP8Quantizer)
return _all_gather_mxfp8(
-            input_,
+            inp,
process_group,
async_op=async_op,
quantizer=quantizer,
......@@ -1034,36 +1075,36 @@ def gather_along_first_dim(
"Attempting to all-gather an unsupported quantized tensor. "
"Falling back to high-precision all-gather."
)
-        if isinstance(input_, QuantizedTensor):
-            input_ = input_.dequantize()
+        if isinstance(inp, QuantizedTensor):
+            inp = inp.dequantize()
out = torch.empty(
out_shape,
-            dtype=input_.dtype,
-            device=input_.device,
+            dtype=inp.dtype,
+            device=inp.device,
memory_format=torch.contiguous_format,
)
-        torch.distributed.all_gather_into_tensor(out, input_, group=process_group)
+        torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
out = quantizer(out)
return out, None
# Dequantize quantized tensor if not supported
-    if isinstance(input_, QuantizedTensor):
+    if isinstance(inp, QuantizedTensor):
warnings.warn(
"Attempting to all-gather an unsupported quantized tensor. "
"Falling back to high-precision all-gather."
)
-        input_ = input_.dequantize()
+        inp = inp.dequantize()
# Communication for plain PyTorch tensors
out = torch.empty(
out_shape,
-        dtype=input_.dtype,
-        device=input_.device,
+        dtype=inp.dtype,
+        device=inp.device,
memory_format=torch.contiguous_format,
)
handle = torch.distributed.all_gather_into_tensor(
out,
-        input_.contiguous(),
+        inp.contiguous(),
group=process_group,
async_op=async_op,
)
......@@ -1071,7 +1112,7 @@ def gather_along_first_dim(
def allreduce(
-    input_: torch.Tensor,
+    inp: torch.Tensor,
tp_group: Optional[dist_group_type] = None,
async_op: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
......@@ -1079,12 +1120,12 @@ def allreduce(
# Bypass the function if we are using only 1 GPU.
if get_distributed_world_size(tp_group) == 1:
-        return input_, None
+        return inp, None
# All-reduce.
-    handle = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op)
+    handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op)
-    return input_, handle
+    return inp, handle
def _fsdp_scatter_tensors(
......
......@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
restore_from_saved,
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..cpp_extensions import (
......@@ -326,6 +327,19 @@ class _LayerNormLinear(torch.autograd.Function):
clear_tensor_data(ln_out, ln_out_total)
if is_grad_enabled:
ctx.ln_out_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)
# Input with column-wise usage is needed for dgrad GEMM.
if backward_needs_input:
if isinstance(ln_out, QuantizedTensor):
                    # For sequence parallel with vanilla FP8, row-wise data is
                    # needed to gather the input. For MXFP8, column-wise-only
                    # data can be all-gathered.
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
ln_out.update_usage(rowwise_usage=False)
if cpu_offloading:
if fp8 and weightmat is not None:
set_offloading_param(weightmat, "weight_offloading", True)
......@@ -556,12 +570,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Note: Perform tensor-parallel communication if needed
ln_out_total = None
ln_out_total_work = None
-        if (
-            ctx.requires_wgrad
-            and ctx.parallel_mode == "column"
-            and ctx.sequence_parallel
-            and not ctx.ub_bulk_dgrad
-        ):
+        if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
quantizer = None
if ctx.fp8:
quantizer = ctx.input_quantizer
......
......@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
......@@ -155,6 +156,7 @@ class _Linear(torch.autograd.Function):
)
if not isinstance(inputmat, QuantizedTensor):
inputmat = input_quantizer(inputmat)
own_quantized_input = True
elif backward_needs_input:
inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
inputmat_total = inputmat
......@@ -251,9 +253,18 @@ class _Linear(torch.autograd.Function):
if is_grad_enabled:
saved_inputmat = None
ctx.backward_input_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)
if backward_needs_input:
if own_quantized_input and isinstance(inputmat, QuantizedTensor):
-                    inputmat.update_usage(rowwise_usage=False)
+                    # For sequence parallel with vanilla FP8, row-wise data is
+                    # needed to gather the input. For MXFP8, column-wise-only
+                    # data can be all-gathered.
+                    if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
+                        inputmat.update_usage(rowwise_usage=False)
saved_inputmat = inputmat
if cpu_offloading:
......@@ -311,7 +322,6 @@ class _Linear(torch.autograd.Function):
ctx.requires_wgrad = weight.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
ctx.owns_input = saved_inputmat is not inp
ctx.is_input_fp8 = not own_quantized_input
if ctx.fp8 and requires_grad(inp, weight, bias):
_first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
......@@ -452,12 +462,7 @@ class _Linear(torch.autograd.Function):
# Note: Perform tensor-parallel communication if needed
inputmat_total = None
inputmat_total_work = None
-        if (
-            ctx.requires_wgrad
-            and ctx.parallel_mode == "column"
-            and ctx.sequence_parallel
-            and not ctx.ub_bulk_dgrad
-        ):
+        if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
quantizer = None
if ctx.fp8:
quantizer = ctx.input_quantizer
......
......@@ -523,9 +523,7 @@ class BasicLinear(BasicOperation):
# Configure input tensor for backward pass
if own_quantized_x_local:
-            ### TODO Restore once column-wise usage is supported by itself # pylint: disable=fixme
-            # x_local.update_usage(rowwise_usage=False)
-            pass
+            x_local.update_usage(rowwise_usage=False)
# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
......
......@@ -5,6 +5,7 @@
"""Mixin class holding data specific for Float8Tensor"""
from __future__ import annotations
import math
from typing import Any, Dict, Optional, Tuple
import torch
......@@ -120,7 +121,10 @@ class Float8TensorBase:
def size(self, *args, **kwargs):
# pylint: disable=missing-function-docstring
-        return self._data.size(*args, **kwargs)
+        if self._data is not None:
+            return self._data.size(*args, **kwargs)
+        size = self._transpose.size(*args, **kwargs)
+        return torch.Size([size[-1], math.prod(size[:-1])])
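The transpose-only branch recovers a 2D view by swapping and flattening dimensions. A standalone sketch of the same arithmetic (hypothetical helper name):

```python
import math
import torch

def size_from_transpose(transpose_shape):
    # The FP8 transpose of an (M, N) matrix is (N, M); leading transpose
    # dimensions are flattened, so only a 2D view is recovered.
    return torch.Size([transpose_shape[-1], math.prod(transpose_shape[:-1])])

assert size_from_transpose(torch.Size([64, 32])) == torch.Size([32, 64])
```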
def __repr__(self):
return (
......
......@@ -115,7 +115,9 @@ class MXFP8TensorBase:
def size(self, *args, **kwargs):
# pylint: disable=missing-function-docstring
-        return self._rowwise_data.size(*args, **kwargs)
+        if self._rowwise_data is not None:
+            return self._rowwise_data.size(*args, **kwargs)
+        return self._columnwise_data.size(*args, **kwargs)
def __repr__(self):
data_rowwise = self.dequantize()
......
......@@ -428,21 +428,40 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
self._transpose_invalid = False
-    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
-        assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor"
-        if rowwise_usage:
-            assert self._data is not None, "Rowwise usage of the tensor was already disabled"
-        else:
-            if not non_tn_fp8_gemm_supported():
-                if self._transpose is None or self._transpose_invalid:
-                    self._create_transpose()
-            self._data = None
-        if columnwise_usage:
-            if self._transpose is None or self._transpose_invalid:
-                assert self._data is not None, "The tensor does not hold any data anymore"
-                if not non_tn_fp8_gemm_supported():
-                    self._create_transpose()
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
+        # Figure out what data is available and what is required
+        has_data = self._data is not None
+        has_data_transpose = self._transpose is not None and not self._transpose_invalid
+        needs_data = has_data
+        needs_data_transpose = has_data_transpose
+        if non_tn_fp8_gemm_supported():
+            if rowwise_usage is not None and rowwise_usage:
+                needs_data = True
+            if columnwise_usage is not None and columnwise_usage:
+                needs_data = True
+                needs_data_transpose = False
+        else:
+            if rowwise_usage is not None:
+                needs_data = rowwise_usage
+            if columnwise_usage is not None:
+                needs_data_transpose = columnwise_usage
+        # Generate data that is required
+        if needs_data and not has_data:
+            raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
+        if needs_data_transpose and not has_data_transpose:
+            if not has_data:
+                raise RuntimeError("FP8 data is required to generate FP8 data transpose")
+            self._create_transpose()
+        # Delete data that is not required
+        if not needs_data:
+            self._data = None
+        if not needs_data_transpose:
+            self._transpose = None
+            self._transpose_invalid = True
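Under the new semantics, `None` leaves a usage as-is while `True`/`False` explicitly requires or drops it. A hedged usage sketch, assuming `t` is a `Float8Tensor` that currently holds row-wise data only, on hardware without non-TN FP8 GEMM support:

```python
t.update_usage(columnwise_usage=True)   # creates the FP8 transpose from _data
t.update_usage(rowwise_usage=False)     # drops _data, keeping only the transpose
# t.update_usage(rowwise_usage=True) would now raise RuntimeError, since FP8
# data cannot be regenerated from the transpose alone.
```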
......
......@@ -66,6 +66,16 @@ class MXFP8Quantizer(Quantizer):
return dst
def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
if inp.ndim < 2:
return False
if inp.shape[-1] % MXFP8_BLOCK_SCALING_SIZE != 0:
return False
if math.prod(inp.shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0:
return False
return True
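For concreteness, here is a self-contained restatement of the check; `MXFP8_BLOCK_SCALING_SIZE` is 32 in the MXFP8 format, as also assumed by the scale-inverse arithmetic in `_all_gather_mxfp8` above:

```python
import math

MXFP8_BLOCK_SCALING_SIZE = 32

def is_quantizable(shape):
    if len(shape) < 2:
        return False
    if shape[-1] % MXFP8_BLOCK_SCALING_SIZE != 0:
        return False
    if math.prod(shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0:
        return False
    return True

assert is_quantizable((64, 128))
assert not is_quantizable((64, 100))  # last dim not a multiple of 32
assert not is_quantizable((3, 128))   # flattened leading dims not a multiple of 32
assert not is_quantizable((128,))     # fewer than 2 dims
```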
def make_empty(
self,
shape: Iterable[int],
......@@ -207,36 +217,50 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
# TODO(ksivamani): Fix the detach bug
return MXFP8Tensor.make_like(self)
-    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
         """
         For MXFP8, columnwise scaled output is only produced by x2
         scaling kernels, so this function only disables usages.
         """
-        assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor."
-        if columnwise_usage and rowwise_usage:
-            assert (
-                self._rowwise_data is not None
-                and self._rowwise_scale_inv is not None
-                and self._columnwise_data is not None
-                and self._columnwise_scale_inv is not None
-            ), "Cannot update to rowwise and columnwise usage."
-            return
+        # Default usage is based on available data
+        if rowwise_usage is None:
+            rowwise_usage = self._rowwise_data is not None
+        if columnwise_usage is None:
+            columnwise_usage = self._columnwise_data is not None
+        # Update row-scaled data
         if rowwise_usage:
-            assert (
-                self._rowwise_data is not None and self._rowwise_scale_inv is not None
-            ), "Cannot update to rowwise usage."
-            self._columnwise_data = None
-            self._columnwise_scale_inv = None
-            return
-        assert (
-            self._columnwise_data is not None and self._columnwise_scale_inv is not None
-        ), "Cannot update to columnwise usage."
-        self._rowwise_data = None
-        self._rowwise_scale_inv = None
-        return
+            if self._rowwise_data is None:
+                raise RuntimeError(
+                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data"
+                )
+            if self._rowwise_scale_inv is None:
+                raise RuntimeError(
+                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses"
+                )
+        else:
+            self._rowwise_data = None
+            self._rowwise_scale_inv = None
+        # Update column-scaled data
+        if columnwise_usage:
+            if self._columnwise_data is None:
+                raise RuntimeError(
+                    "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data"
+                )
+            if self._columnwise_scale_inv is None:
+                raise RuntimeError(
+                    "Requested column-wise usage, "
+                    "but MXFP8Tensor is missing column-scaled scale-inverses"
+                )
+        else:
+            self._columnwise_data = None
+            self._columnwise_scale_inv = None
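A hedged sketch of the new call semantics, assuming `t` is an `MXFP8Tensor` that currently holds both row- and column-scaled data:

```python
t.update_usage(rowwise_usage=False)  # drop row-scaled data and scale-inverses
t.update_usage()                     # None defaults: keep whatever data exists
# t.update_usage(rowwise_usage=True) would now raise RuntimeError, since MXFP8
# column-scaled data cannot be converted back to row-scaled data here.
```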
def clone(self) -> MXFP8Tensor:
# pylint: disable=missing-function-docstring
......
......@@ -21,13 +21,10 @@ def prepare_for_saving(
"""Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
the internal TensorBase types too."""
+    # pylint: disable=unidiomatic-typecheck # Using type instead of isinstance to check exact type
     tensor_list, tensor_objects_list = [], []
     for tensor in tensors:
-        if tensor is None:
-            tensor_list.append(None)
-            tensor_objects_list.append(None)
-        elif isinstance(tensor, torch.Tensor):
+        if tensor is None or isinstance(tensor, torch.Tensor):
             tensor_list.append(tensor)
             tensor_objects_list.append(None)
else:
......@@ -44,7 +41,7 @@ def restore_from_saved(
"""Recombine the tensor data and metadata during backward pass."""
tensor_objects = []
for tensor in tensors:
-        if tensor is None:
+        if tensor is None or isinstance(tensor, torch.Tensor):
tensor_objects.append(saved_tensors[0])
saved_tensors = saved_tensors[1:]
else:
......@@ -289,7 +286,11 @@ class QuantizedTensor(torch.Tensor):
f"{self.__class__.__name__} class does not implement detach function"
)
-    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
"""Indicate to the tensor how it is going to be used
This enables optimizations to memory usage in some cases
......