Unverified commit 8a20d666, authored by Tim Moon, committed by GitHub

Support tensors with only column-wise data (#1505)



* Delete row-wise data in single-GPU linear forward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Python->C++ parsing of transpose-only Float8Tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug tensor shape calculation without row-wise data
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug correctness issues with only column-wise data
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Only cache column-wise input in LayerNormLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 all-gather with only column-wise data
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix MoE cases, lint, remove unused ctx
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CPU activation offloading and use consistent logic for save/restore
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove stray file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix distributed and cpp tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix norm cpp tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove stray file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove stray file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix MXFP8 all-gather
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix FP8 with sequence parallelism
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix UB bulk dgrad
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0e137883
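The thread running through the diff below is that a quantized tensor may now carry only column-wise data (the FP8 transpose for delayed scaling, or the column-scaled copy for MXFP8), and `update_usage(rowwise_usage=False)` drops the row-wise copy. A rough sketch of the intended usage, assuming Transformer Engine's PyTorch quantizer API and a CUDA device (the quantizer constructor arguments here are assumptions, not taken from this commit):

# Sketch only: assumes transformer_engine with a CUDA device available.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

# Hypothetical quantizer setup (scale/amax buffers are placeholders).
quantizer = Float8Quantizer(
    scale=torch.ones(1, device="cuda"),
    amax=torch.zeros(1, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
x = quantizer(torch.randn(128, 256, device="cuda"))

# After this commit, the row-wise FP8 copy can be dropped while keeping only
# the transpose; shape queries and later GEMMs still work off that copy.
x.update_usage(rowwise_usage=False)
print(x.size())  # still reports the logical (128, 256) shape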
@@ -92,7 +92,10 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training)
   input.to_cpu();
   auto scaling_mode = input.scaling_mode();
   assert(input.rowwise_shape().ndim == 2);
-  assert(input.columnwise_shape().ndim == 2);
+  if (is_training) {
+    assert(input.columnwise_shape().ndim == 2);
+  }
   dequantize_1x_kernel(input.rowwise_cpu_dptr<InputType>(),
                        input.rowwise_cpu_scale_inv_ptr<ScaleType>(),
......
@@ -83,9 +83,11 @@ size_t product(const NVTEShape &shape, size_t begin, size_t end) {
   }
   return ret;
 }

 size_t product(const NVTEShape &shape) {
   return product(shape, 0, shape.ndim);
 }

 size_t product(const std::vector<size_t> shape, size_t begin, size_t end) {
   size_t ret = 1;
   NVTE_CHECK(end <= shape.size());
@@ -193,6 +195,7 @@ Tensor::Tensor(const std::string& name,
   std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
                                             shape.data[shape.ndim - 1]};
   NVTEShape normalized_shape = convertShape(normalized_shape_v);
+  NVTEShape columnwise_shape{nullptr, 0};
   std::vector<size_t> columnwise_shape_vec;
   if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
@@ -207,7 +210,11 @@ Tensor::Tensor(const std::string& name,
       columnwise_shape_vec.emplace_back(shape.data[i]);
     }
   }
-  const NVTEShape columnwise_shape{columnwise_shape_vec.data(), columnwise_shape_vec.size()};
+  if (columnwise) {
+    columnwise_shape.data = columnwise_shape_vec.data();
+    columnwise_shape.ndim = columnwise_shape_vec.size();
+  }
   tensor_ = TensorWrapper(scaling_mode);
......
@@ -29,6 +29,9 @@
 namespace transformer_engine {

+std::string to_string(const DType type);
+std::string to_string(const NVTEScalingMode &mode);
+
 inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
   return mode == NVTE_DELAYED_TENSOR_SCALING;
 }
@@ -108,17 +111,8 @@ struct Tensor {
         scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}

   int numel() const {
-    NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr,
-               "Tensor does not hold any data!");
     size_t acc = 1;
-    if (data.dptr != nullptr) {
-      for (const auto &dim : data.shape) {
-        acc *= dim;
-      }
-      return acc;
-    }
-    // data is empty, use columnwise_data
-    for (const auto &dim : columnwise_data.shape) {
+    for (const auto dim : shape()) {
       acc *= dim;
     }
     return acc;
@@ -126,7 +120,10 @@ struct Tensor {
   bool has_data() const noexcept { return data.dptr != nullptr; }

-  bool has_columnwise_data() const noexcept { return columnwise_data.dptr != nullptr; }
+  // Check for size (not just pointer) for 0-dim or no token cases.
+  bool has_columnwise_data() const noexcept {
+    return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0;
+  }

   DType dtype() const {
     if (has_data()) return data.dtype;
@@ -135,24 +132,54 @@ struct Tensor {
     return data.dtype;
   }

+  std::vector<size_t> shape() const {
+    /* Note: We sometimes experience spurious compiler errors
+     * (-Wstringop-overflow) from this function. It appears that GCC
+     * has some bugs with std::vector (see
+     * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
+     */
+    switch (scaling_mode) {
+      case NVTE_DELAYED_TENSOR_SCALING:
+        if (!has_data() && has_columnwise_data()) {
+          std::vector<size_t> ret;
+          if (!columnwise_data.shape.empty()) {
+            for (size_t i = 1; i < columnwise_data.shape.size(); i++) {
+              ret.push_back(columnwise_data.shape[i]);
+            }
+            ret.push_back(columnwise_data.shape.front());
+          }
+          return ret;
+        } else {
+          return data.shape;
+        }
+        break;
+      case NVTE_MXFP8_1D_SCALING:
+        if (!has_data() && has_columnwise_data()) {
+          return columnwise_data.shape;
+        } else {
+          return data.shape;
+        }
+        break;
+      default:
+        NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
+        return {};
+    }
+  }
+
   /*! Matrix height after tensor is flattened to 2D
    *
    * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
    * as a (D1*D2*...*D(n-1), Dn) matrix.
    */
   size_t flat_first_dim() const {
-    if (!has_data() && has_columnwise_data()) {
-      const auto &data_shape = columnwise_data.shape;
-      if (data_shape.empty()) return 1;
-      if (is_tensor_scaling(scaling_mode)) {
-        return product(data_shape, 1, data_shape.size());
-      } else {
-        return product(data_shape, 0, data_shape.size() - 1);
-      }
-    }
-    const auto &data_shape = data.shape;
-    if (data_shape.empty()) return 1;
-    return product(data_shape, 0, data_shape.size() - 1);
+    const auto &full_shape = shape();
+    size_t ret = 1;
+    if (!full_shape.empty()) {
+      for (size_t i = 0; i < full_shape.size() - 1; i++) {
+        ret *= full_shape[i];
+      }
+    }
+    return ret;
   }

   /*! Matrix width after tensor is flattened to 2D
@@ -161,19 +188,13 @@ struct Tensor {
    * as a (D1*D2*...*D(n-1), Dn) matrix.
    */
   size_t flat_last_dim() const {
-    if (!has_data() && has_columnwise_data()) {
-      const auto &data_shape = columnwise_data.shape;
-      if (data_shape.empty()) return 1;
-      if (is_tensor_scaling(scaling_mode)) {
-        return data_shape.front();
-      } else {
-        return data_shape.back();
-      }
-    }
-    const auto &data_shape = data.shape;
-    if (data_shape.empty()) return 1;
-    return data_shape.back();
+    const auto &full_shape = shape();
+    if (full_shape.empty()) {
+      return 1;
+    } else {
+      return full_shape.back();
+    }
   }
 };

 struct QuantizationConfig {
@@ -477,9 +498,6 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
 bool is_fp8_dtype(const DType t);

-std::string to_string(const DType type);
-std::string to_string(const NVTEScalingMode &type);
-
 /*! \brief Update a tensor's FP8 scale-inverse
  *
  * The FP8 scale-inverse (dequantization scaling factor) is updated
......
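For delayed tensor scaling, the column-wise buffer holds the transpose of the 2D-flattened tensor, so `shape()` above rebuilds the logical shape by moving the leading dimension of `columnwise_data.shape` to the end (MXFP8 column-wise data keeps the original shape). A pure-Python restatement of that loop, for illustration only:

# Mirrors the NVTE_DELAYED_TENSOR_SCALING branch of Tensor::shape() above.
def shape_from_columnwise(columnwise_shape):
    if not columnwise_shape:
        return []
    # columnwise_data stores (Dn, D1, ..., D(n-1)); rotate Dn back to the end.
    return list(columnwise_shape[1:]) + [columnwise_shape[0]]

assert shape_from_columnwise([256, 16, 8]) == [16, 8, 256]
assert shape_from_columnwise([]) == []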
@@ -555,7 +555,7 @@ class TensorWrapper {
    *  \return Number of elements in the tensor.
    */
   size_t numel() const noexcept {
-    if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
+    if (tensor_ == nullptr) return 0;
     return nvte_tensor_numel(tensor_);
   }
......
@@ -212,37 +212,58 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
 }

 NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
-  if (tensor == nullptr) return {nullptr, 0};
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  if (tensor == nullptr) {
+    NVTE_ERROR("Invalid tensor");
+  }
   NVTEShape ret;

-  // FP8 tensor keeps shape in rowwise data
-  if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
-    ret.data = t.data.shape.data();
-    ret.ndim = t.data.shape.size();
-    return ret;
-  }
-
-  // Get shape based on what data is available
-  if (t.has_data()) {
-    ret.data = t.data.shape.data();
-    ret.ndim = t.data.shape.size();
-    return ret;
-  }
-  if (t.has_columnwise_data()) {
-    ret.data = t.columnwise_data.shape.data();
-    ret.ndim = t.columnwise_data.shape.size();
-    return ret;
-  }
-
-  // Tensor has no data
-  ret.data = t.data.shape.data();
-  ret.ndim = t.data.shape.size();
+  // Determine tensor shape depending on tensor format
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  switch (t.scaling_mode) {
+    case NVTE_DELAYED_TENSOR_SCALING: {
+      if (!t.has_data() && t.has_columnwise_data()) {
+        // We can infer tensor shape if FP8 tensor only has FP8 data
+        // transpose. However, NVTEShape only contains a pointer and
+        // cannot store temporary data. We hack around this by caching
+        // the tensor shape within the empty FP8 data.
+        auto &shape_cache = const_cast<std::vector<size_t> &>(t.data.shape);
+        shape_cache.clear();
+        if (!t.columnwise_data.shape.empty()) {
+          for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) {
+            shape_cache.push_back(t.columnwise_data.shape[i]);
+          }
+          shape_cache.push_back(t.columnwise_data.shape.front());
+        }
+        ret.data = shape_cache.data();
+        ret.ndim = shape_cache.size();
+      } else {
+        ret.data = t.data.shape.data();
+        ret.ndim = t.data.shape.size();
+      }
+      break;
+    }
+    case NVTE_MXFP8_1D_SCALING: {
+      if (!t.has_data() && t.has_columnwise_data()) {
+        ret.data = t.columnwise_data.shape.data();
+        ret.ndim = t.columnwise_data.shape.size();
+      } else {
+        ret.data = t.data.shape.data();
+        ret.ndim = t.data.shape.size();
+      }
+      break;
+    }
+    default:
+      NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
+                 transformer_engine::to_string(t.scaling_mode), "\"");
+  }
   return ret;
 }

 NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
-  if (tensor == nullptr) return {nullptr, 0};
+  if (tensor == nullptr) {
+    NVTE_ERROR("Invalid tensor");
+  }
   const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
   NVTEShape ret;
   ret.data = t.columnwise_data.shape.data();
@@ -250,25 +271,20 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
   return ret;
 }

-size_t nvte_tensor_ndim(const NVTETensor tensor) {
-  if (tensor == nullptr) return 0;
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  return t.data.shape.size();
-}
+size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }

 size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
-  if (tensor == nullptr) return 0;
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim);
-  return t.data.shape[dim];
+  const auto &shape = nvte_tensor_shape(tensor);
+  NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim,
+             " in a shape array with ", shape.ndim, " entries");
+  return shape.data[dim];
 }

 size_t nvte_tensor_numel(const NVTETensor tensor) {
-  if (tensor == nullptr) return 0;
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  const auto &shape = nvte_tensor_shape(tensor);
   size_t numel = 1;
-  for (auto size : t.data.shape) {
-    numel *= size;
+  for (size_t i = 0; i < shape.ndim; i++) {
+    numel *= shape.data[i];
   }
   return numel;
 }
......
@@ -419,11 +419,25 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
            tensor_list = [state]

        for tensor_on_device in tensor_list:
+            # `tensor_offloaded` is a hacky way of dealing with columnwise-only
+            # quantized tensors for CPU offloading. The complication is due to
+            # the `rowwise_data` being `None`. The offloading checker incorrectly
+            # returns `False` and the entire `state` ([None, columnwise_tensor])
+            # is added to the tensor tag state dict. A better design would change
+            # how quantized tensors are kept track of in the offload handler.
+            # Currently at every stage it is ensured that a quantized tensor is a
+            # list whereas a non-quantized tensor is standalone object, which is
+            # not good! TODO(@sanandaraj5597)
+            tensor_offloaded = False
            # if offload, return the reference to cpu copy
            if self.tensor_need_offloading_checker(tensor_on_device):
+                tensor_offloaded = True
                state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
            if is_quantized_tensor:
-                self.tensor_tag_to_state[tensor_tag].append(state)
+                if tensor_offloaded:
+                    self.tensor_tag_to_state[tensor_tag].append(state)
+                else:
+                    self.tensor_tag_to_state[tensor_tag].append(tensor_on_device)
            else:
                self.tensor_tag_to_state[tensor_tag] = state
......
@@ -4,6 +4,10 @@
  * See LICENSE for license information.
  ************************************************************************/

+#include <ATen/ATen.h>
+#include <pybind11/pybind11.h>
+#include <transformer_engine/transformer_engine.h>
+
 #include "common.h"
 #include "pybind.h"

@@ -11,67 +15,72 @@ namespace transformer_engine::pytorch {
 namespace detail {

 TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) {
-  const at::Tensor &data = tensor.attr("_data").cast<at::Tensor>();
-  const at::Tensor &scale_inv = tensor.attr("_scale_inv").cast<at::Tensor>();
-  float *scale_inv_dptr = reinterpret_cast<float *>(scale_inv.data_ptr());
-  const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
-
-  const auto &shape = getTensorShape(data);
-
-  bool transpose_valid = !tensor.attr("_transpose_invalid").cast<bool>();
-  std::optional<at::Tensor> transpose = std::nullopt;
-  if (transpose_valid) {
-    transpose = tensor.attr("_transpose").cast<std::optional<at::Tensor>>();
-  }
-
-  // In the case of being called under tex.dequantize, the quantizer will be NoneQuantizer
-  // whose scaling mode is defaulted to NVTE_DELAYED_TENSOR_SCALING
-  auto ret = TensorWrapper(quantizer->get_scaling_mode());
-
-  ret.set_rowwise_data(data.data_ptr(), dtype, shape);
-  if (transpose_valid && transpose != std::nullopt) {
-    const auto &transpose_shape = getTensorShape(*transpose);
-    ret.set_columnwise_data(transpose->data_ptr(), dtype, transpose_shape);
-  }
-
-  const auto scale_inv_dtype = GetTransformerEngineDType(scale_inv.scalar_type());
-  const auto scale_inv_shape = getTensorShape(scale_inv);
-  ret.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
-  ret.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
+  auto ret = TensorWrapper(quantizer->get_scaling_mode());
+
+  bool data_exists = !tensor.attr("_data").is_none();
+  bool transpose_exists =
+      !tensor.attr("_transpose_invalid").cast<bool>() && !tensor.attr("_transpose").is_none();
+  NVTE_CHECK(data_exists || transpose_exists, "No data found for FP8 Tensor.");
+
+  // FP8 data
+  const DType fp8_dtype = tensor.attr("_fp8_dtype").cast<DType>();
+  if (data_exists) {
+    const auto &data = tensor.attr("_data").cast<at::Tensor>();
+    ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
+  }
+
+  // FP8 data transpose
+  if (transpose_exists) {
+    const auto &data_transpose = tensor.attr("_transpose").cast<at::Tensor>();
+    ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, getTensorShape(data_transpose));
+  }
+
+  // Scale-inverse
+  {
+    const auto &scale_inv = tensor.attr("_scale_inv").cast<at::Tensor>();
+    float *dptr = reinterpret_cast<float *>(scale_inv.data_ptr());
+    const auto &dtype = GetTransformerEngineDType(scale_inv.scalar_type());
+    const auto &shape = getTensorShape(scale_inv);
+    ret.set_rowwise_scale_inv(dptr, dtype, shape);
+    ret.set_columnwise_scale_inv(dptr, dtype, shape);
+  }
+
+  // Quantizer state
   quantizer->set_quantization_params(&ret);
+
   return ret;
 }

 TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) {
-  const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
   auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING);

   bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
   bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
+  NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for MXFP8 Tensor.");

+  // Row-scaled data
+  const DType fp8_dtype = tensor.attr("_fp8_dtype").cast<DType>();
   if (rowwise_usage) {
-    const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
-    const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
-    void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr();
-    const auto &shape = getTensorShape(data_rowwise);
-    ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape);
-    const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
-    ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat8E8M0, scale_inv_rowwise_shape);
+    const auto &data = tensor.attr("_rowwise_data").cast<at::Tensor>();
+    const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
+    ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
+    ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv));
   }

+  // Column-scaled data
   if (columnwise_usage) {
-    const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
-    const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
-    void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr();
-    const auto &shape = getTensorShape(data_colwise);
-    ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape);
-    const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
-    ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat8E8M0,
-                                 scale_inv_colwise_shape);
+    const auto &data = tensor.attr("_columnwise_data").cast<at::Tensor>();
+    const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
+    ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
+    ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0,
+                                 getTensorShape(scale_inv));
   }

+  // Quantizer state
   quantizer->set_quantization_params(&ret);
+
   return ret;
 }
......
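The reworked parser above accepts a Float8Tensor that carries either the `_data` buffer, the `_transpose` buffer, or both, instead of unconditionally reading `_data`. A pure-Python sketch of the same validity check, using the attribute names from the diff (illustration only):

# `t` stands for any Float8Tensor-like object with the attributes named above.
def float8_views_available(t):
    data_exists = t._data is not None
    transpose_exists = (not t._transpose_invalid) and t._transpose is not None
    if not (data_exists or transpose_exists):
        raise ValueError("No data found for FP8 Tensor.")
    return data_exists, transpose_exists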
@@ -22,7 +22,7 @@ from .utils import safely_set_viewless_tensor_data
 from .constants import dist_group_type
 from .fp8 import FP8GlobalStateManager
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
-from .tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor
+from .tensor.mxfp8_tensor import MXFP8Quantizer
 from .tensor.quantized_tensor import QuantizedTensor, Quantizer
 from .tensor._internal.float8_tensor_base import Float8TensorBase
 from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
@@ -819,30 +819,30 @@ class CudaRNGStatesTracker:

 def reduce_scatter_along_first_dim(
-    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
+    inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
 ) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
     """Reduce-scatter the input tensor across model parallel group."""
     world_size = get_distributed_world_size(tp_group)
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
-        return input_, None
+        return inp, None

-    dim_size = list(input_.size())
+    dim_size = list(inp.size())
     assert (
         dim_size[0] % world_size == 0
     ), "First dimension of the tensor should be divisible by tensor parallel size"
     dim_size[0] = dim_size[0] // world_size
-    output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+    output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device())
     handle = torch.distributed.reduce_scatter_tensor(
-        output, input_.contiguous(), group=tp_group, async_op=async_op
+        output, inp.contiguous(), group=tp_group, async_op=async_op
     )
     return output, handle
 def _all_gather_fp8(
-    input_: torch.Tensor,
+    inp: torch.Tensor,
     process_group: dist_group_type,
     *,
     async_op: bool = False,
@@ -854,18 +854,18 @@ def _all_gather_fp8(
     # Output tensor dims
     if out_shape is None:
-        out_shape = list(input_.size())
+        out_shape = list(inp.size())
         out_shape[0] *= world_size

     # Quantize input tensor if needed
-    if not isinstance(input_, Float8TensorBase):
+    if not isinstance(inp, Float8TensorBase):
         assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
         # we cannot directly gather the transposed fp8 tensor
         # so we need to disable columnwise usage for the quantizer
         # and then set it back to the original value after quantizing
         init_columnwise_usage = quantizer.columnwise_usage
         quantizer.set_usage(columnwise=False)
-        input_ = quantizer(input_)
+        inp = quantizer(inp)
         quantizer.set_usage(columnwise=init_columnwise_usage)

     # Construct output tensor
@@ -873,30 +873,30 @@ def _all_gather_fp8(
     if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
         dtype = torch.float32
         device = "cuda"
-        if isinstance(input_, Float8Tensor):
-            dtype = input_.dtype
-            device = input_.device
+        if isinstance(inp, Float8Tensor):
+            dtype = inp.dtype
+            device = inp.device
         out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
-    elif isinstance(input_, Float8Tensor):
-        out = input_.make_like(input_, shape=out_shape)
+    elif isinstance(inp, Float8Tensor):
+        out = inp.make_like(inp, shape=out_shape)
         out._data = torch.empty_like(
             out_shape,
             dtype=torch.uint8,
-            device=input_.device,
+            device=inp.device,
         )
         out._transpose = None
         out._transpose_invalid = True
     else:
         raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")
-    # For delayed scaling, scale_inv is from history, so we can pass it from input_ to out
+    # For delayed scaling, scale_inv is from history, so we can pass it from inp to out
     # For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same scale_inv,
-    # so we can just pass it from input_ to out
-    out._scale_inv = input_._scale_inv
+    # so we can just pass it from inp to out
+    out._scale_inv = inp._scale_inv

     # Perform communication
     handle = torch.distributed.all_gather_into_tensor(
         out._data,
-        input_._data.contiguous(),
+        inp._data.contiguous(),
         group=process_group,
         async_op=async_op,
     )
@@ -914,7 +914,7 @@ def _all_gather_fp8(

 def _all_gather_mxfp8(
-    input_: torch.Tensor,
+    inp: torch.Tensor,
     process_group: dist_group_type,
     *,
     async_op: bool = False,
@@ -925,27 +925,56 @@ def _all_gather_mxfp8(
     # Tensor dims
     world_size = get_distributed_world_size(process_group)
-    in_shape = list(input_.size())
+    in_shape = list(inp.size())
     if out_shape is None:
         out_shape = [in_shape[0] * world_size] + in_shape[1:]

-    # Gather MXFP8 data for row-wise usage
-    if quantizer.rowwise_usage and not quantizer.columnwise_usage:
-
-        # Cast input tensor to MXFP8 if needed
-        if not isinstance(input_, MXFP8TensorBase):
-            input_ = quantizer(input_)
-
-        # Construct MXFP8 output tensor
-        dtype = torch.float32
-        device = "cuda"
-        if isinstance(input_, MXFP8Tensor):
-            dtype = input_.dtype
-            device = input_.device
-        out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+    # For cases where inp has dimensions that cannot be quantized,
+    # we gather in high precision followed by a cast to FP8.
+    if (
+        not isinstance(inp, MXFP8TensorBase)
+        and quantizer is not None
+        and not quantizer.is_quantizable(inp)
+    ):
+        out = torch.empty(
+            out_shape,
+            dtype=inp.dtype,
+            device=inp.device,
+            memory_format=torch.contiguous_format,
+        )
+        torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
+        out = quantizer(out)
+        return out, None
+
+    inp_dtype = inp.dtype
+    inp_device = inp.device
+
+    # Cast input tensor to MXFP8 with required data
+    if not isinstance(inp, MXFP8TensorBase):
+        inp = quantizer(inp)
+    elif (
+        inp.rowwise_data is None
+        and quantizer.rowwise_usage
+        or inp.columnwise_data is None
+        and quantizer.columnwise_usage
+    ):
+        warnings.warn(
+            "Input and quantizer do not have matching usages. "
+            "Dequantizing and requantizing to MXFP8."
+        )
+        inp = quantizer(inp.dequantize())
+
+    # Construct MXFP8 output tensor
+    out = quantizer.make_empty(out_shape, dtype=inp_dtype, device=inp_device)
+
+    # Async op handle
+    handle = None
+
+    # Gather MXFP8 data for row-wise usage
+    if quantizer.rowwise_usage:

         # Remove padding from MXFP8 scale-inverses
-        in_scale_inv = input_._rowwise_scale_inv
+        in_scale_inv = inp._rowwise_scale_inv
         out_scale_inv = out._rowwise_scale_inv
         flattened_in_shape0 = math.prod(in_shape[:-1])
         if in_scale_inv.size(0) != flattened_in_shape0:
@@ -954,40 +983,52 @@ def _all_gather_mxfp8(
             out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

         # Launch all-gathers
-        with torch.distributed._coalescing_manager(
-            group=process_group,
-            device=device,
-            async_ops=async_op,
-        ) as coalescing_manager:
-            torch.distributed.all_gather_into_tensor(
-                out_scale_inv,
-                in_scale_inv,
-                group=process_group,
-            )
-            torch.distributed.all_gather_into_tensor(
-                out._rowwise_data,
-                input_._rowwise_data,
-                group=process_group,
-            )
-        handle = coalescing_manager if async_op else None
-        return out, handle
-
-    # Gather in high precision and quantize for column-wise usage
-    if isinstance(input_, QuantizedTensor):
-        input_ = input_.dequantize(dtype=torch.bfloat16)
-    out = torch.empty(
-        out_shape,
-        dtype=input_.dtype,
-        device=input_.device,
-        memory_format=torch.contiguous_format,
-    )
-    torch.distributed.all_gather_into_tensor(out, input_, group=process_group)
-    out = quantizer(out)
-    return out, None
+        if handle is not None:
+            handle.wait()
+        torch.distributed.all_gather_into_tensor(
+            out_scale_inv,
+            in_scale_inv,
+            group=process_group,
+        )
+        handle = torch.distributed.all_gather_into_tensor(
+            out._rowwise_data,
+            inp._rowwise_data,
+            group=process_group,
+            async_op=async_op,
+        )
+
+    # Gather MXFP8 data for column-wise usage
+    if quantizer.columnwise_usage:
+
+        # Remove padding from MXFP8 scale-inverses
+        in_scale_inv = inp._columnwise_scale_inv
+        out_scale_inv = out._columnwise_scale_inv
+        flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
+        if in_scale_inv.size(0) != flattened_in_shape0:
+            in_scale_inv = in_scale_inv[:flattened_in_shape0]
+            out_scale_inv[flattened_in_shape0 * world_size :].zero_()
+            out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
+
+        # Launch all-gathers
+        if handle is not None:
+            handle.wait()
+        torch.distributed.all_gather_into_tensor(
+            out_scale_inv,
+            in_scale_inv,
+            group=process_group,
+        )
+        handle = torch.distributed.all_gather_into_tensor(
+            out._columnwise_data,
+            inp._columnwise_data,
+            group=process_group,
+            async_op=async_op,
+        )
+
+    return out, handle
 def gather_along_first_dim(
-    input_: torch.Tensor,
+    inp: torch.Tensor,
     process_group: dist_group_type,
     async_op: bool = False,
     quantizer: Optional[Quantizer] = None,
@@ -997,20 +1038,20 @@ def gather_along_first_dim(
     # Return immediately if no communication is required
     world_size = get_distributed_world_size(process_group)
     if world_size == 1:
-        if quantizer is not None and not isinstance(input_, QuantizedTensor):
-            input_ = quantizer(input_)
-        return input_, None
+        if quantizer is not None and not isinstance(inp, QuantizedTensor):
+            inp = quantizer(inp)
+        return inp, None

     # Output tensor dims
-    out_shape = list(input_.size())
+    out_shape = list(inp.size())
     out_shape[0] *= world_size

     # FP8 case: delayed scaling or current scaling
-    if isinstance(input_, Float8TensorBase) or isinstance(
+    if isinstance(inp, Float8TensorBase) or isinstance(
         quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
     ):
         return _all_gather_fp8(
-            input_,
+            inp,
             process_group,
             async_op=async_op,
             quantizer=quantizer,
@@ -1018,10 +1059,10 @@ def gather_along_first_dim(
         )

     # MXFP8 case
-    if isinstance(input_, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
+    if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
         assert isinstance(quantizer, MXFP8Quantizer)
         return _all_gather_mxfp8(
-            input_,
+            inp,
             process_group,
             async_op=async_op,
             quantizer=quantizer,
@@ -1034,36 +1075,36 @@ def gather_along_first_dim(
             "Attempting to all-gather an unsupported quantized tensor. "
             "Falling back to high-precision all-gather."
         )
-        if isinstance(input_, QuantizedTensor):
-            input_ = input_.dequantize()
+        if isinstance(inp, QuantizedTensor):
+            inp = inp.dequantize()
         out = torch.empty(
             out_shape,
-            dtype=input_.dtype,
-            device=input_.device,
+            dtype=inp.dtype,
+            device=inp.device,
             memory_format=torch.contiguous_format,
         )
-        torch.distributed.all_gather_into_tensor(out, input_, group=process_group)
+        torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
         out = quantizer(out)
         return out, None

     # Dequantize quantized tensor if not supported
-    if isinstance(input_, QuantizedTensor):
+    if isinstance(inp, QuantizedTensor):
         warnings.warn(
             "Attempting to all-gather an unsupported quantized tensor. "
             "Falling back to high-precision all-gather."
         )
-        input_ = input_.dequantize()
+        inp = inp.dequantize()

     # Communication for plain PyTorch tensors
     out = torch.empty(
         out_shape,
-        dtype=input_.dtype,
-        device=input_.device,
+        dtype=inp.dtype,
+        device=inp.device,
         memory_format=torch.contiguous_format,
     )
     handle = torch.distributed.all_gather_into_tensor(
         out,
-        input_.contiguous(),
+        inp.contiguous(),
         group=process_group,
         async_op=async_op,
     )
@@ -1071,7 +1112,7 @@ def gather_along_first_dim(

 def allreduce(
-    input_: torch.Tensor,
+    inp: torch.Tensor,
     tp_group: Optional[dist_group_type] = None,
     async_op: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
@@ -1079,12 +1120,12 @@ def allreduce(
     # Bypass the function if we are using only 1 GPU.
     if get_distributed_world_size(tp_group) == 1:
-        return input_, None
+        return inp, None

     # All-reduce.
-    handle = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op)
+    handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op)

-    return input_, handle
+    return inp, handle


 def _fsdp_scatter_tensors(
......
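The MXFP8 all-gather above now handles both usages: it trims the padded scale-inverse buffers so each rank's valid rows line up after the gather, using one scale row per data row for the row-wise copy and one per 32 rows for the column-wise copy (32 being the MXFP8 block size implied by the `// 32` in the diff). A small arithmetic sketch of that trimming:

# Illustration of the scale-inverse trimming in _all_gather_mxfp8 above.
import math

def valid_scale_rows(in_shape, world_size, columnwise=False, block=32):
    rows = math.prod(in_shape[:-1])
    if columnwise:
        rows //= block
    # (valid rows on this rank, valid rows in the gathered buffer)
    return rows, rows * world_size

assert valid_scale_rows([128, 256], world_size=4) == (128, 512)
assert valid_scale_rows([128, 256], world_size=4, columnwise=True) == (4, 16)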
@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
     restore_from_saved,
 )
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
+from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
 from ..cpp_extensions import (
@@ -326,6 +327,19 @@ class _LayerNormLinear(torch.autograd.Function):
             clear_tensor_data(ln_out, ln_out_total)

         if is_grad_enabled:
+            ctx.ln_out_needs_gather = (
+                weight.requires_grad and parallel_mode == "column" and sequence_parallel
+            )
+
+            # Input with column-wise usage is needed for dgrad GEMM.
+            if backward_needs_input:
+                if isinstance(ln_out, QuantizedTensor):
+                    # For sequence parallel in vanilla FP8, rowwise data is
+                    # to gather the input. For MXFP8, columnwise only data
+                    # can be allgathered.
+                    if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
+                        ln_out.update_usage(rowwise_usage=False)
+
             if cpu_offloading:
                 if fp8 and weightmat is not None:
                     set_offloading_param(weightmat, "weight_offloading", True)
@@ -556,12 +570,7 @@ class _LayerNormLinear(torch.autograd.Function):
         # Note: Perform tensor-parallel communication if needed
         ln_out_total = None
         ln_out_total_work = None
-        if (
-            ctx.requires_wgrad
-            and ctx.parallel_mode == "column"
-            and ctx.sequence_parallel
-            and not ctx.ub_bulk_dgrad
-        ):
+        if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
             quantizer = None
             if ctx.fp8:
                 quantizer = ctx.input_quantizer
......
@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
+from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
@@ -155,6 +156,7 @@ class _Linear(torch.autograd.Function):
                 )
                 if not isinstance(inputmat, QuantizedTensor):
                     inputmat = input_quantizer(inputmat)
+                    own_quantized_input = True
                 elif backward_needs_input:
                     inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
             inputmat_total = inputmat
@@ -251,8 +253,17 @@ class _Linear(torch.autograd.Function):
         if is_grad_enabled:
             saved_inputmat = None
+            ctx.backward_input_needs_gather = (
+                weight.requires_grad and parallel_mode == "column" and sequence_parallel
+            )
             if backward_needs_input:
                 if own_quantized_input and isinstance(inputmat, QuantizedTensor):
-                    inputmat.update_usage(rowwise_usage=False)
+                    # For sequence parallel in vanilla FP8, rowwise data is
+                    # to gather the input. For MXFP8, columnwise only data
+                    # can be allgathered.
+                    if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
+                        inputmat.update_usage(rowwise_usage=False)
                 saved_inputmat = inputmat
@@ -311,7 +322,6 @@ class _Linear(torch.autograd.Function):
         ctx.requires_wgrad = weight.requires_grad
         ctx.reduce_and_update_bwd_fp8_tensors = False
         ctx.owns_input = saved_inputmat is not inp
-        ctx.is_input_fp8 = not own_quantized_input
         if ctx.fp8 and requires_grad(inp, weight, bias):
             _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
             ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
@@ -452,12 +462,7 @@ class _Linear(torch.autograd.Function):
         # Note: Perform tensor-parallel communication if needed
         inputmat_total = None
         inputmat_total_work = None
-        if (
-            ctx.requires_wgrad
-            and ctx.parallel_mode == "column"
-            and ctx.sequence_parallel
-            and not ctx.ub_bulk_dgrad
-        ):
+        if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
             quantizer = None
             if ctx.fp8:
                 quantizer = ctx.input_quantizer
......
@@ -523,9 +523,7 @@ class BasicLinear(BasicOperation):

             # Configure input tensor for backward pass
             if own_quantized_x_local:
-                ### TODO Restore once column-wise usage is supported by itself  # pylint: disable=fixme
-                # x_local.update_usage(rowwise_usage=False)
-                pass
+                x_local.update_usage(rowwise_usage=False)

             # Detach input tensor if needed
             # Note: PyTorch autograd produces esoteric errors if we save
......
@@ -5,6 +5,7 @@
 """Mixin class holding data specific for Float8Tensor"""

 from __future__ import annotations
+import math
 from typing import Any, Dict, Optional, Tuple

 import torch
@@ -120,7 +121,10 @@ class Float8TensorBase:

     def size(self, *args, **kwargs):
         # pylint: disable=missing-function-docstring
-        return self._data.size(*args, **kwargs)
+        if self._data is not None:
+            return self._data.size(*args, **kwargs)
+        size = self._transpose.size(*args, **kwargs)
+        return torch.Size([size[-1], math.prod(size[:-1])])

     def __repr__(self):
         return (
......
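With only the transpose cached, `size()` above reconstructs the flattened logical shape from the transpose's dimensions: the last transpose dimension becomes the leading dimension and the rest collapse into the trailing one. A quick check of that arithmetic (illustration only):

# Pure-Python check of the size() fallback above.
import math
import torch

transpose_shape = torch.Size([256, 128])  # transpose of a (128, 256) tensor
logical = torch.Size([transpose_shape[-1], math.prod(transpose_shape[:-1])])
assert logical == torch.Size([128, 256])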
@@ -115,7 +115,9 @@ class MXFP8TensorBase:

     def size(self, *args, **kwargs):
         # pylint: disable=missing-function-docstring
-        return self._rowwise_data.size(*args, **kwargs)
+        if self._rowwise_data is not None:
+            return self._rowwise_data.size(*args, **kwargs)
+        return self._columnwise_data.size(*args, **kwargs)

     def __repr__(self):
         data_rowwise = self.dequantize()
......
@@ -428,21 +428,40 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
         self._transpose_invalid = False

-    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
-        assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor"
-        if rowwise_usage:
-            assert self._data is not None, "Rowwise usage of the tensor was already disabled"
-        else:
-            if not non_tn_fp8_gemm_supported():
-                if self._transpose is None or self._transpose_invalid:
-                    self._create_transpose()
-            self._data = None
-        if columnwise_usage:
-            if self._transpose is None or self._transpose_invalid:
-                assert self._data is not None, "The tensor does not hold any data anymore"
-                if not non_tn_fp8_gemm_supported():
-                    self._create_transpose()
-        else:
-            self._transpose = None
-            self._transpose_invalid = True
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
+
+        # Figure out what data is available and what is required
+        has_data = self._data is not None
+        has_data_transpose = self._transpose is not None and not self._transpose_invalid
+        needs_data = has_data
+        needs_data_transpose = has_data_transpose
+        if non_tn_fp8_gemm_supported():
+            if rowwise_usage is not None and rowwise_usage:
+                needs_data = True
+            if columnwise_usage is not None and columnwise_usage:
+                needs_data = True
+            needs_data_transpose = False
+        else:
+            if rowwise_usage is not None:
+                needs_data = rowwise_usage
+            if columnwise_usage is not None:
+                needs_data_transpose = columnwise_usage
+
+        # Generate data that is required
+        if needs_data and not has_data:
+            raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
+        if needs_data_transpose and not has_data_transpose:
+            if not has_data:
+                raise RuntimeError("FP8 data is required to generate FP8 data transpose")
+            self._create_transpose()
+
+        # Delete data that is not required
+        if not needs_data:
+            self._data = None
+        if not needs_data_transpose:
+            self._transpose = None
+            self._transpose_invalid = True
......
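The rewritten `Float8Tensor.update_usage` above treats `None` as "keep whatever is already there", generates the transpose on demand, and frees whichever copy is no longer needed. A condensed sketch of the decision logic for the common case where the FP8 GEMM requires a TN layout (`non_tn_fp8_gemm_supported()` returning False); this is an illustration, not the library code:

# Simplified restatement of the else-branch of update_usage above.
def fp8_usage_plan(has_data, has_transpose, rowwise_usage=None, columnwise_usage=None):
    needs_data = has_data if rowwise_usage is None else rowwise_usage
    needs_transpose = has_transpose if columnwise_usage is None else columnwise_usage
    if needs_data and not has_data:
        raise RuntimeError("cannot recreate FP8 data from the transpose alone")
    if needs_transpose and not has_transpose and not has_data:
        raise RuntimeError("FP8 data is required to build the transpose")
    return needs_data, needs_transpose

# Dropping row-wise usage keeps only the transpose:
assert fp8_usage_plan(True, True, rowwise_usage=False) == (False, True)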
@@ -66,6 +66,16 @@ class MXFP8Quantizer(Quantizer):

         return dst

+    def is_quantizable(self, inp: torch.Tensor) -> bool:
+        """Returns whether or not given inp can be quantized"""
+        if inp.ndim < 2:
+            return False
+        if inp.shape[-1] % MXFP8_BLOCK_SCALING_SIZE != 0:
+            return False
+        if math.prod(inp.shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0:
+            return False
+        return True
+
     def make_empty(
         self,
         shape: Iterable[int],
@@ -207,36 +217,50 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
         # TODO(ksivamani): Fix the detach bug
         return MXFP8Tensor.make_like(self)

-    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
         """
         For MXFP8, columnwise scaled output is only produced by x2
         scaling kernels, so this function only disables usages.
         """
-        assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor."

-        if columnwise_usage and rowwise_usage:
-            assert (
-                self._rowwise_data is not None
-                and self._rowwise_scale_inv is not None
-                and self._columnwise_data is not None
-                and self._columnwise_scale_inv is not None
-            ), "Cannot update to rowwise and columnwise usage."
-            return
+        # Default usage is based on available data
+        if rowwise_usage is None:
+            rowwise_usage = self._rowwise_data is not None
+        if columnwise_usage is None:
+            columnwise_usage = self._columnwise_data is not None

+        # Update row-scaled data
         if rowwise_usage:
-            assert (
-                self._rowwise_data is not None and self._rowwise_scale_inv is not None
-            ), "Cannot update to rowwise usage."
-            self._columnwise_data = None
-            self._columnwise_scale_inv = None
-            return
-
-        assert (
-            self._columnwise_data is not None and self._columnwise_scale_inv is not None
-        ), "Cannot update to columnwise usage."
-        self._rowwise_data = None
-        self._rowwise_scale_inv = None
-        return
+            if self._rowwise_data is None:
+                raise RuntimeError(
+                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data"
+                )
+            if self._rowwise_scale_inv is None:
+                raise RuntimeError(
+                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses"
+                )
+        else:
+            self._rowwise_data = None
+            self._rowwise_scale_inv = None
+
+        # Update column-scaled data
+        if columnwise_usage:
+            if self._columnwise_data is None:
+                raise RuntimeError(
+                    "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data"
+                )
+            if self._columnwise_scale_inv is None:
+                raise RuntimeError(
+                    "Requested column-wise usage, "
+                    "but MXFP8Tensor is missing column-scaled scale-inverses"
+                )
+        else:
+            self._columnwise_data = None
+            self._columnwise_scale_inv = None

     def clone(self) -> MXFP8Tensor:
         # pylint: disable=missing-function-docstring
......
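`MXFP8Quantizer.is_quantizable` above is what lets the all-gather fall back to high precision for shapes MXFP8 cannot represent: both the last dimension and the product of the leading dimensions must be multiples of the block size. A pure-Python restatement, assuming the block size of 32 used elsewhere in the diff:

# Illustration of the is_quantizable checks above (block size assumed 32).
import math

def mxfp8_quantizable(shape, block=32):
    if len(shape) < 2:
        return False
    if shape[-1] % block != 0:
        return False
    if math.prod(shape[:-1]) % block != 0:
        return False
    return True

assert mxfp8_quantizable([128, 256])
assert not mxfp8_quantizable([100, 256])  # leading dims not a multiple of 32
assert not mxfp8_quantizable([128, 100])  # last dim not a multiple of 32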
@@ -21,13 +21,10 @@ def prepare_for_saving(
     """Prepare tensors for saving. Needed because save_for_backward accepts only
     torch.Tensor/torch.nn.Parameter types, while we want to be able to save
     the internal TensorBase types too."""
-    # pylint: disable=unidiomatic-typecheck
+    # Using type instead of isinstance to check exact type
     tensor_list, tensor_objects_list = [], []
     for tensor in tensors:
-        if tensor is None:
-            tensor_list.append(None)
-            tensor_objects_list.append(None)
-        elif isinstance(tensor, torch.Tensor):
+        if tensor is None or isinstance(tensor, torch.Tensor):
             tensor_list.append(tensor)
             tensor_objects_list.append(None)
         else:
@@ -44,7 +41,7 @@ def restore_from_saved(
     """Recombine the tensor data and metadata during backward pass."""
     tensor_objects = []
     for tensor in tensors:
-        if tensor is None:
+        if tensor is None or isinstance(tensor, torch.Tensor):
            tensor_objects.append(saved_tensors[0])
            saved_tensors = saved_tensors[1:]
        else:
@@ -289,7 +286,11 @@ class QuantizedTensor(torch.Tensor):
             f"{self.__class__.__name__} class does not implement detach function"
         )

-    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
         """Indicate to the tensor how it is going to be used

         This enables optimizations to memory usage in some cases
......