Commit 87e3e56e authored by yuguo

Merge commit '734bcedd' of...

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
......@@ -15,11 +15,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, wd, momentum, dampening, lr, nesterov, first_run,
wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream());
wd_after_momentum, scale, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -108,11 +108,17 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
}
}
TensorWrapper unquantized_out_cu;
py::object unquantized_out;
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
std::tie(unquantized_out_cu, unquantized_out) =
my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
} else {
NoneQuantizer q{none};
py::object unquantized_out;
std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
}
}
TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
// Query workspace size
......@@ -139,45 +145,12 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
QuantizationConfigWrapper quant_config;
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr =
my_quantizer_cs->amax_reduction_group;
// construct torch tensor from NVTEBasicTensor without reallocating memory
at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
if (my_quantizer_bw->all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
} else {
my_quantizer->quantize(unquantized_out_cu, out_cu);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
}
return {out, py::cast(mu), py::cast(rsigma)};
......@@ -269,11 +242,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
}
}
TensorWrapper unquantized_out_cu;
py::object unquantized_out;
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
std::tie(unquantized_out_cu, unquantized_out) =
my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
} else {
NoneQuantizer q{none};
py::object unquantized_out;
std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
}
}
TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
// Query workspace size
......@@ -300,45 +279,12 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
QuantizationConfigWrapper quant_config;
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr =
my_quantizer_cs->amax_reduction_group;
// construct torch tensor from NVTEBasicTensor without reallocating memory
at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
if (my_quantizer_bw->all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
} else {
my_quantizer->quantize(unquantized_out_cu, out_cu);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
}
return {out, py::none(), py::cast(rsigma)};
......
......@@ -111,7 +111,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
......@@ -228,6 +229,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
py::call_guard<py::gil_scoped_release>());
m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims,
"Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"),
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend,
"Get Fused Attention backend", py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
......@@ -394,6 +398,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
"Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
// Comm+GEMM Overlap
m.def("bulk_overlap_ag_with_external_gemm",
&transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm,
"Bulk overlap All-Gather with a GEMM operation launched by another communicator",
py::call_guard<py::gil_scoped_release>(), py::arg("allgather_communicator"),
py::arg("send_stream"), py::arg("recv_stream"));
// Data structures
py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta")
.def(py::init<>())
......
......@@ -18,31 +18,64 @@ namespace pytorch {
at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
init_extension();
const auto dim = input.dim();
NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose.");
if (input.dim() > 2) {
input = input.view({-1, input.size(dim - 1)});
// Tensor dimensions
const auto shape = getTensorShape(input);
std::vector<int64_t> transpose_shape_int64;
if (shape.size() > 0) {
transpose_shape_int64.push_back(shape.back());
for (size_t i = 0; i < shape.size() - 1; ++i) {
transpose_shape_int64.push_back(shape[i]);
}
}
const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1;
const size_t N = shape.size() > 0 ? shape.back() : 1;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
// Output tensor
at::Tensor out;
if (output.has_value()) {
out = *output;
} else {
out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
out = at::empty(transpose_shape_int64, opts);
}
// Return immediately if tensor is empty
if (M == 0 || N == 0) {
return out;
}
if (M == 0 || N == 0) return out;
// Compute transpose
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, N}, otype);
auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{N, M}, otype);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return out;
}
at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out) {
init_extension();
// Make sure input is contiguous
const auto &input = tensor.contiguous();
// Allocate output tensor if needed
if (!out) {
auto in_shape = getTensorShape(input);
NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")");
std::vector<int64_t> out_shape_int64(in_shape.begin(), in_shape.end());
out_shape_int64[0] = static_cast<int64_t>(in_shape[1]);
out_shape_int64[1] = static_cast<int64_t>(in_shape[0]);
auto opts = at::TensorOptions().dtype(input.dtype()).device(input.device());
out = at::empty(out_shape_int64, opts);
}
// Launch kernel
const TensorWrapper te_input = makeTransformerEngineTensor(input);
TensorWrapper te_output = makeTransformerEngineTensor(*out);
nvte_swap_first_dims(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
return std::move(*out);
}
} // namespace pytorch
} // namespace transformer_engine
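A rough usage sketch for the new swap_first_dims helper (a sketch only; it assumes the usual extension headers and a CUDA device, and the shapes are made up):
// Swap the first two dimensions of a (4, 2, 8) tensor.
at::Tensor x = at::randn({4, 2, 8}, at::device(at::kCUDA).dtype(torch::kFloat32));
// Without a preallocated output, the helper allocates a (2, 4, 8) tensor itself.
at::Tensor y = transformer_engine::pytorch::swap_first_dims(x, std::nullopt);
// Alternatively, pass a preallocated buffer through the optional second argument.
at::Tensor out = at::empty({2, 4, 8}, x.options());
transformer_engine::pytorch::swap_first_dims(x, out);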
......@@ -12,6 +12,27 @@
namespace transformer_engine::pytorch {
namespace {
/*! @brief Transposed tensor shape
*
* The tensor is interpreted as a 2D matrix by flattening all but the
* last dimension, and then transposed.
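 * For example, shape {4, 6, 8} is treated as a 24x8 matrix, so the
 * transposed shape is {8, 4, 6}.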
*/
template <typename T = size_t, typename S = T>
std::vector<T> make_transpose_shape(const std::vector<S>& shape) {
std::vector<T> ret;
if (shape.size() > 0) {
ret.push_back(shape.back());
for (size_t i = 0; i < shape.size() - 1; ++i) {
ret.push_back(shape[i]);
}
}
return ret;
}
} // namespace
constexpr size_t MXFP8_BLOCK_SIZE = 32;
Quantizer::Quantizer(const py::handle& quantizer) {
......@@ -37,24 +58,36 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti
this->dtype = type;
}
std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
at::TensorOptions opts;
opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA);
std::vector<int64_t> torch_shape;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
}
at::Tensor ret;
if (rowwise_data.has_value()) {
ret = std::move(*rowwise_data);
} else {
ret = at::empty(torch_shape, opts);
}
std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype) const {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA);
return create_tensor(shape, dtype, at::empty(shape_int64, opts));
}
std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype,
at::Tensor data) const {
TensorWrapper out_cpp;
out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape);
set_quantization_params(&out_cpp);
return {std::move(out_cpp), py::cast(data)};
}
TensorWrapper tensor;
tensor.set_rowwise_data(ret.data_ptr(), dtype, shape);
return {std::move(tensor), py::cast(ret)};
std::pair<TensorWrapper, py::object> NoneQuantizer::convert_and_update_tensor(
py::object tensor) const {
auto tensor_pyt = tensor.cast<at::Tensor>();
TensorWrapper out_cpp;
out_cpp.set_rowwise_data(tensor_pyt.data_ptr(),
GetTransformerEngineDType(tensor_pyt.scalar_type()),
getTensorShape(tensor_pyt));
set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
}
void NoneQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
NVTE_ERROR("NoneQuantizer does not support quantization");
}
void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
......@@ -76,68 +109,180 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
}
std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
const std::vector<size_t>& shape, DType dtype) const {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
at::Tensor scale_inv = at::empty(std::vector<int64_t>{1}, opts);
return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv));
}
std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data,
std::optional<at::Tensor> transpose, std::optional<at::Tensor> scale_inv) const {
using namespace pybind11::literals;
std::vector<int64_t> rowwise_torch_shape;
std::vector<int64_t> columnwise_torch_shape;
if (!shape.empty()) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape.back()));
}
for (size_t i = 0; i < shape.size(); ++i) {
if (i < shape.size() - 1) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
rowwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
at::TensorOptions opts;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(rowwise_torch_shape, opts);
// Initialize data tensor
const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
if (with_data && !data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
data = at::empty(shape_int64, opts);
} else if (!with_data && data) {
data.reset();
}
py::object data_py = with_data ? py::cast(*data) : py::none();
// Initialize transpose tensor
const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (with_transpose && !transpose) {
const auto transpose_shape = make_transpose_shape<int64_t>(shape);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
transpose = at::empty(transpose_shape, opts);
} else if (!with_transpose && transpose) {
transpose.reset();
}
const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts);
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}
const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
opts = opts.dtype(torch::kFloat32);
// TODO: Replace with an empty tensor.
at::Tensor scale_inv = at::reciprocal(scale);
py::object ret;
// Construct Python FP8 tensor
py::object out_py;
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
} else {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype),
"data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
}
TensorWrapper tensor(this->get_scaling_mode());
if (rowwise_usage) {
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
// Construct C++ FP8 tensor
TensorWrapper out_cpp(this->get_scaling_mode());
if (with_data) {
out_cpp.set_rowwise_data(data->data_ptr(), this->dtype, shape);
out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
if (with_transpose) {
const auto transpose_shape = make_transpose_shape(shape);
out_cpp.set_columnwise_data(transpose->data_ptr(), this->dtype, transpose_shape);
out_cpp.set_columnwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> Float8Quantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor.");
// Expected buffers
const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer.");
// Extract buffers from Python tensor
auto data_py = tensor.attr("_data");
auto transpose_py = tensor.attr("_transpose");
const bool has_data = !data_py.is_none();
const bool has_transpose = !transpose_py.is_none();
NVTE_CHECK(has_data || has_transpose, "Float8Tensor has no data.");
std::optional<at::Tensor> data_tensor, transpose_tensor;
if (has_data) {
data_tensor = data_py.cast<at::Tensor>();
}
if (create_transpose) {
std::vector<size_t> transposed_shape;
for (auto s : columnwise_torch_shape) {
transposed_shape.emplace_back(static_cast<size_t>(s));
if (has_transpose) {
transpose_tensor = transpose_py.cast<at::Tensor>();
}
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape);
tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast<at::Tensor>();
// Tensor dimensions
std::vector<size_t> shape;
if (has_transpose) {
const auto transpose_shape = getTensorShape(*transpose_tensor);
if (transpose_shape.size() > 0) {
for (size_t i = 1; i < transpose_shape.size(); ++i) {
shape.push_back(transpose_shape[i]);
}
this->set_quantization_params(&tensor);
return {std::move(tensor), std::move(ret)};
shape.push_back(transpose_shape.front());
}
if (has_data) {
auto expected_shape = getTensorShape(*data_tensor);
NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
") and transpose (shape=", transpose_shape, ") do not match");
}
} else { // Already checked has_data == true
shape = getTensorShape(*data_tensor);
}
// Coerce data tensor
if (has_data && !need_data) {
data_tensor.reset();
data_py = py::none();
tensor.attr("_data") = data_py;
} else if (!has_data && need_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
data_tensor = at::empty(shape_int64, opts);
data_py = py::cast(data_tensor);
tensor.attr("_data") = data_py;
}
// Coerce transpose tensor
if (has_transpose && !need_transpose) {
transpose_tensor.reset();
transpose_py = py::none();
tensor.attr("_transpose") = transpose_py;
} else if (!has_transpose && need_transpose) {
const auto transpose_shape = make_transpose_shape<int64_t>(shape);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
transpose_tensor = at::empty(transpose_shape, opts);
transpose_py = py::cast(transpose_tensor);
tensor.attr("_transpose") = transpose_py;
}
tensor.attr("_transpose_invalid") = !need_transpose;
// Coerce other attrs
tensor.attr("_fp8_dtype") = dtype;
// Construct C++ FP8 tensor
TensorWrapper out_cpp;
if (data_tensor) {
out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape);
out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
if (transpose_tensor) {
const auto transpose_shape = make_transpose_shape(shape);
out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape);
out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
}
void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
if (input.numel() == 0) {
return;
}
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer)
......@@ -187,71 +332,223 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso
}
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
const std::vector<size_t>& shape, DType dtype) const {
using namespace pybind11::literals;
std::vector<int64_t> rowwise_torch_shape;
std::vector<int64_t> columnwise_torch_shape;
std::vector<int64_t> scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv
if (!shape.empty()) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape.back()));
}
for (size_t i = 0; i < shape.size(); ++i) {
if (i < shape.size() - 1) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
// Initialize data tensor
at::Tensor data_tensor;
const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
if (with_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
data_tensor = at::empty(shape_int64, opts);
}
rowwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
at::TensorOptions opts;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(rowwise_torch_shape, opts);
}
}
const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts);
// Initialize transpose tensor
at::Tensor transpose_tensor;
const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (with_transpose) {
const auto transpose_shape = make_transpose_shape<int64_t>(shape);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
transpose_tensor = at::empty(transpose_shape, opts);
}
const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
// In current scaling, the scale is not yet known, so it is initialized to 1 to avoid division by zero; if the scale has already been computed, the reciprocal taken here is correct.
at::Tensor scale_inv = at::reciprocal(scale);
// Initialize scale-inverse tensor
at::Tensor scale_inv_tensor;
{
const std::vector<int64_t> scale_inv_shape = {1};
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
scale_inv_tensor = at::empty(scale_inv_shape, opts);
}
py::object ret;
// Construct Python FP8 tensor
py::object out_py;
py::object data_py = with_data ? py::cast(data_tensor) : py::none();
py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none();
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
} else {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype),
"data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
}
TensorWrapper tensor(this->get_scaling_mode());
if (rowwise_usage) {
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
// Construct C++ FP8 tensor
TensorWrapper out_cpp(this->get_scaling_mode());
if (with_data) {
out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape);
out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
if (with_transpose) {
const auto transpose_shape = make_transpose_shape(shape);
out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape);
out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_hp_tensor_with_amax(
const std::vector<size_t>& shape, DType dtype) {
amax.zero_();
auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);
out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
getTensorShape(amax));
return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()),
"Float8CurrentScalingQuantizer must output to Float8Tensor.");
// Expected buffers
const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages.");
// Extract buffers from Python tensor
auto data_py = tensor.attr("_data");
auto transpose_py = tensor.attr("_transpose");
const bool has_data = !data_py.is_none();
const bool has_transpose = !transpose_py.is_none();
NVTE_CHECK(has_data || has_transpose, "Tensor has no data.");
std::optional<at::Tensor> data_tensor, transpose_tensor;
if (has_data) {
data_tensor = data_py.cast<at::Tensor>();
}
if (has_transpose) {
transpose_tensor = transpose_py.cast<at::Tensor>();
}
if (create_transpose) {
std::vector<size_t> transposed_shape;
for (auto s : columnwise_torch_shape) {
transposed_shape.emplace_back(static_cast<size_t>(s));
at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast<at::Tensor>();
// Tensor dimensions
std::vector<size_t> shape;
if (has_transpose) {
const auto transpose_shape = getTensorShape(*transpose_tensor);
if (transpose_shape.size() > 0) {
for (size_t i = 1; i < transpose_shape.size(); ++i) {
shape.push_back(transpose_shape[i]);
}
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape);
tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
shape.push_back(transpose_shape.front());
}
if (has_data) {
auto expected_shape = getTensorShape(*data_tensor);
NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
") and transpose (shape=", transpose_shape, ") do not match");
}
} else { // Already checked has_data == true
shape = getTensorShape(*data_tensor);
}
this->set_quantization_params(&tensor);
return {std::move(tensor), std::move(ret)};
// Coerce data tensor in Python tensor
if (has_data && !need_data) {
data_tensor.reset();
data_py = py::none();
tensor.attr("_data") = data_py;
} else if (!has_data && need_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
data_tensor = at::empty(shape_int64, opts);
data_py = py::cast(data_tensor);
tensor.attr("_data") = data_py;
}
// Coerce transpose tensor
if (has_transpose && !need_transpose) {
transpose_tensor.reset();
transpose_py = py::none();
tensor.attr("_transpose") = transpose_py;
} else if (!has_transpose && need_transpose) {
const auto transpose_shape = make_transpose_shape<int64_t>(shape);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
transpose_tensor = at::empty(transpose_shape, opts);
transpose_py = py::cast(transpose_tensor);
tensor.attr("_transpose") = transpose_py;
}
tensor.attr("_transpose_invalid") = !need_transpose;
// Coerce other attrs
tensor.attr("_fp8_dtype") = dtype;
// Construct C++ FP8 tensor
TensorWrapper out_cpp;
if (data_tensor) {
out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape);
out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
if (transpose_tensor) {
const auto transpose_shape = make_transpose_shape(shape);
out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape);
out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
}
void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag,
bool compute_amax) {
auto stream = at::cuda::getCurrentCUDAStream();
// Nothing to be done if input is empty
if (input.numel() == 0) {
return;
}
// Quantization configs
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
// Compute amax
if (compute_amax) {
NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); });
}
// Perform amax reduction if needed
if (with_amax_reduction) {
// allreduce amax tensor
c10d::AllreduceOptions opts;
opts.reduceOp = c10d::ReduceOp::MAX;
std::vector<at::Tensor> tensors = {amax};
NVTE_SCOPED_GIL_RELEASE({ amax_reduction_group->allreduce(tensors, opts)->wait(); });
}
// Compute scaling factor
NVTE_SCOPED_GIL_RELEASE({ nvte_compute_scale_from_amax(out.data(), quant_config, stream); });
// Cast to FP8
out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates
NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
}
void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
this->quantize_impl(input, out, noop_flag, true);
}
void Float8CurrentScalingQuantizer::quantize_with_amax(
TensorWrapper& input, TensorWrapper& out, const std::optional<TensorWrapper>& noop_flag) {
NVTE_CHECK(input.get_amax().data_ptr == amax.data_ptr(),
"Input does not use the appropriate amax tensor");
input.set_amax(nullptr, DType::kFloat32, input.defaultShape);
this->quantize_impl(input, out, noop_flag, false);
}
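Combined with the layernorm_fwd / rmsnorm_fwd changes above, the intended call pattern is roughly the following sketch; my_quantizer, shape, and out_dtype stand in for the surrounding code, and the producer kernel launch is elided:
auto* quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
// 1. High-precision output whose amax pointer is wired up so the producer
//    kernel records the amax while writing its result.
auto [hp_cu, hp_py] = quantizer_cs->create_hp_tensor_with_amax(shape, out_dtype);
// 2. FP8 output tensor owned by the quantizer.
auto [fp8_cu, fp8_py] = quantizer_cs->create_tensor(shape, out_dtype);
// 3. ... launch the kernel (e.g. the norm) that fills hp_cu and its amax ...
// 4. Cast to FP8, reusing the recorded amax instead of recomputing it.
quantizer_cs->quantize_with_amax(hp_cu, fp8_cu);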
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
......@@ -280,7 +577,7 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const
}
std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
const std::vector<size_t>& shape, DType dtype) const {
using namespace pybind11::literals;
std::vector<int64_t> torch_shape;
for (auto s : shape) {
......@@ -299,11 +596,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
: Float8BlockScaleTensorFormat::GEMM_READY);
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data_rowwise = std::move(*rowwise_data);
} else {
data_rowwise = at::empty(torch_shape, opts);
}
auto scale_shape = get_scale_shape(shape, false);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
......@@ -373,6 +666,177 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
return {std::move(tensor), std::move(ret)};
}
std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_tensor(
py::object tensor) const {
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
// Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
auto attr_py = tensor.attr(name);
if (attr_py.is_none()) {
return std::nullopt;
}
return attr_py.cast<at::Tensor>();
};
auto rowwise_data = get_tensor("_rowwise_data");
auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
auto columnwise_data = get_tensor("_columnwise_data");
auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
NVTE_CHECK(rowwise_data || columnwise_data, "FP8BlockwiseTensor has no data.");
// Tensor options and dimensions
at::TensorOptions opts;
at::TensorOptions scale_opts;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector<size_t> {
if (!columnwise_data) {
return std::vector<size_t>();
}
if (all_gather_usage) {
return getTensorShape(*columnwise_data);
}
std::vector<size_t> shape = getTensorShape(*columnwise_data);
std::vector<size_t> shape_transposed(shape.size());
for (size_t i = 0; i + 1 < shape.size(); ++i) {
shape_transposed[i] = shape[i + 1];
}
if (shape.size() > 0) {
shape_transposed[shape.size() - 1] = shape[0];
}
return shape_transposed;
};
std::vector<size_t> shape;
if (rowwise_data) {
shape = getTensorShape(*rowwise_data);
if (columnwise_data) {
auto expected_shape = get_columnwise_shape(all_gather_usage);
NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
") and column-wise data (shape=", expected_shape, ") do not match");
}
} else {
shape = get_columnwise_shape(all_gather_usage);
}
std::vector<int64_t> torch_shape;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
}
// Coerce row-wise data
if (rowwise_usage) {
if (!rowwise_data) {
rowwise_data = at::empty(torch_shape, opts);
tensor.attr("_rowwise_data") = *rowwise_data;
}
if (!rowwise_scale_inv) {
auto scale_shape = get_scale_shape(shape, false);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
rowwise_scale_inv =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
}
} else { // rowwise_usage == false
if (rowwise_data) {
rowwise_data.reset();
tensor.attr("_rowwise_data") = py::none();
}
if (rowwise_scale_inv) {
rowwise_scale_inv.reset();
tensor.attr("_rowwise_scale_inv") = py::none();
}
}
// Coerce column-wise data
if (columnwise_usage) {
std::vector<size_t> columnwise_shape;
std::vector<int64_t> torch_columnwise_shape;
if (torch_shape.size() > 0) {
if (!all_gather_usage) {
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]);
for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
} else {
// assert we are doing 1D scaling
NVTE_CHECK(block_scaling_dim == 1,
"Compact columnwise format is not supported for 128x128 2D block scaling.");
torch_columnwise_shape = torch_shape;
columnwise_shape = shape;
}
}
if (!columnwise_data) {
columnwise_data = at::empty(torch_columnwise_shape, opts);
tensor.attr("_columnwise_data") = *columnwise_data;
}
if (!columnwise_scale_inv) {
auto scale_shape = get_scale_shape(shape, true);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
columnwise_scale_inv =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
}
} else { // columnwise_usage == false
if (columnwise_data) {
columnwise_data.reset();
tensor.attr("_columnwise_data") = py::none();
}
if (columnwise_scale_inv) {
columnwise_scale_inv.reset();
tensor.attr("_columnwise_scale_inv") = py::none();
}
}
auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
if (rowwise_usage) {
const at::Tensor& data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
const at::Tensor& scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
void* scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr();
const auto& rowwise_shape = getTensorShape(data_rowwise);
ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape);
const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape);
}
if (columnwise_usage) {
const at::Tensor& data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
const at::Tensor& scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
void* scale_inv_colwise_dptr = scale_inv_colwise.data_ptr();
const auto& shape = getTensorShape(data_colwise);
ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape);
const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
}
set_quantization_params(&ret);
return {std::move(ret), std::move(tensor)};
}
void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
if (input.numel() == 0) {
return;
}
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
if (all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape,
bool columnwise) const {
size_t numel = 1;
......@@ -465,71 +929,204 @@ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype) const {
using namespace pybind11::literals;
std::vector<int64_t> torch_shape;
size_t numel = 1;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
numel *= s;
}
TensorWrapper tensor(NVTE_MXFP8_1D_SCALING);
at::TensorOptions opts;
at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv,
columnwise_scale_inv; // TODO(pgadzinski) - change
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(torch_shape, opts);
// Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1;
if (shape.size() > 0) {
for (size_t i = 0; i < shape.size() - 1; ++i) {
flat_first_dim *= shape[i];
}
auto scale_shape = get_scale_shape(shape, false);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
rowwise_scale_inv = at::zeros({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, opts);
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(
rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
}
const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0,
"MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
" (got shape=", shape, ")");
const auto rowwise_scale_inv_shape = get_scale_shape(shape, false);
const auto columnwise_scale_inv_shape = get_scale_shape(shape, true);
// Allocate tensors
at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor;
at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor;
const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
if (rowwise_usage) {
const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
rowwise_scale_inv_shape.end());
rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
}
if (columnwise_usage) {
auto scale_shape = get_scale_shape(shape, true);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
columnwise_data = at::empty(torch_shape, opts);
columnwise_scale_inv =
at::zeros({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, opts);
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
tensor.set_columnwise_scale_inv(
columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
const std::vector<int64_t> scale_inv_shape_int64(columnwise_scale_inv_shape.begin(),
columnwise_scale_inv_shape.end());
columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
}
this->set_quantization_params(&tensor);
py::object ret;
// Convert tensors to Python
auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object {
return need_cast ? py::cast(tensor) : py::none();
};
auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage);
auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage);
auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage);
auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage);
// Construct Python MXFP8 tensor
py::object out_py;
if (internal) {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorBasePythonClass));
ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data,
"rowwise_scale_inv"_a = rowwise_scale_inv,
"columnwise_scale_inv"_a = columnwise_scale_inv,
out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py,
"columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
} else {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = data, "columnwise_data"_a = columnwise_data,
"rowwise_scale_inv"_a = rowwise_scale_inv,
"columnwise_scale_inv"_a = columnwise_scale_inv,
out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = rowwise_data_py,
"columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
}
return {std::move(tensor), std::move(ret)};
// Construct C++ MXFP8 tensor
TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING);
if (rowwise_usage) {
out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), this->dtype, shape);
out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0,
rowwise_scale_inv_shape);
}
if (columnwise_usage) {
out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), this->dtype, shape);
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0,
columnwise_scale_inv_shape);
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor.");
// Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
auto attr_py = tensor.attr(name);
if (attr_py.is_none()) {
return std::nullopt;
}
return attr_py.cast<at::Tensor>();
};
auto rowwise_data = get_tensor("_rowwise_data");
auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
auto columnwise_data = get_tensor("_columnwise_data");
auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data.");
// Tensor dimensions
std::vector<size_t> shape;
if (columnwise_data) {
shape = getTensorShape(*columnwise_data);
if (rowwise_data) {
auto expected_shape = getTensorShape(*rowwise_data);
NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape,
") and column-wise data (shape=", shape, ") do not match");
}
} else {  // Already checked rowwise_data == true
shape = getTensorShape(*rowwise_data);
}
// Coerce row-wise data
if (rowwise_usage) {
if (!rowwise_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
rowwise_data = at::empty(shape_int64, opts);
tensor.attr("_rowwise_data") = *rowwise_data;
}
if (!rowwise_scale_inv) {
const auto scale_inv_shape = get_scale_shape(shape, false);
const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
scale_inv_shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
}
} else { // rowwise_usage == false
if (rowwise_data) {
rowwise_data.reset();
tensor.attr("_rowwise_data") = py::none();
}
if (rowwise_scale_inv) {
rowwise_scale_inv.reset();
tensor.attr("_rowwise_scale_inv") = py::none();
}
}
// Coerce column-wise data
if (columnwise_usage) {
if (!columnwise_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
columnwise_data = at::empty(shape_int64, opts);
tensor.attr("_columnwise_data") = *columnwise_data;
}
if (!columnwise_scale_inv) {
const auto scale_inv_shape = get_scale_shape(shape, true);
const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
scale_inv_shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
}
} else { // columnwise_usage == false
if (columnwise_data) {
columnwise_data.reset();
tensor.attr("_columnwise_data") = py::none();
}
if (columnwise_scale_inv) {
columnwise_scale_inv.reset();
tensor.attr("_columnwise_scale_inv") = py::none();
}
}
// Coerce other attrs
tensor.attr("_fp8_dtype") = dtype;
// Construct C++ MXFP8 tensor
TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING);
if (rowwise_usage) {
out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, shape);
out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
getTensorShape(*rowwise_scale_inv));
}
if (columnwise_usage) {
out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, shape);
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
getTensorShape(*columnwise_scale_inv));
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
}
void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
if (input.numel() == 0) {
return;
}
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& shape,
......
......@@ -75,3 +75,98 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
return swizzled_scale_inv;
}
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper>& tensors, bool rowwise) {
using namespace transformer_engine::pytorch;
if (tensors.empty()) {
return std::nullopt;
}
bool all_same_scaling_mode = std::all_of(
tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) {
return val.scaling_mode() == tensors.front().scaling_mode();
});
NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same.");
if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) {
return std::nullopt;
}
std::vector<transformer_engine::TensorWrapper> wrappers;
std::vector<NVTETensor> input_tensors, output_tensors;
// Collect scale_inv shapes and calculate buffer size and offsets for scale_invs
std::vector<std::vector<size_t>> scale_inv_shapes;
std::vector<void*> scale_inv_dptrs;
size_t buffer_size = 0;
std::vector<size_t> scale_inv_offsets;
constexpr size_t scale_elem_size = 1;
for (auto& tensor : tensors) {
NVTEBasicTensor scale_inv;
if (rowwise) {
scale_inv = tensor.get_rowwise_scale_inv();
} else {
scale_inv = tensor.get_columnwise_scale_inv();
}
auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_inv_offsets.push_back(buffer_size);
buffer_size += product(scale_inv_shape) * scale_elem_size;
scale_inv_shapes.emplace_back(scale_inv_shape);
scale_inv_dptrs.push_back(scale_inv.data_ptr);
}
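// For example, scale-inv sizes of 100 and 40 bytes give offsets {0, 112} and a
// 152-byte buffer, since each offset is rounded up to a 16-byte boundary.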
// Allocate full buffer
auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8));
for (size_t i = 0; i < tensors.size(); ++i) {
auto& tensor = tensors[i];
void* scale_inv_dptr = scale_inv_dptrs[i];
void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
auto input_shape = nvte_shape_to_vector(tensor.shape());
// Reconstruct input only to avoid swizzling both directions if not needed.
// Any 8-bit dtype works here; the exact type is irrelevant.
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
if (rowwise) {
input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shapes[i]);
output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor.
tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shapes[i]);
} else {
input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shapes[i]);
output_cu.set_columnwise_data(tensor.columnwise_dptr(),
transformer_engine::DType::kFloat8E4M3, input_shape);
output_cu.set_columnwise_scale_inv(
swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor.
tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
}
input_tensors.emplace_back(input_cu.data());
output_tensors.emplace_back(output_cu.data());
wrappers.emplace_back(std::move(input_cu));
wrappers.emplace_back(std::move(output_cu));
}
// Launch kernel
nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(),
input_tensors.size(), at::cuda::getCurrentCUDAStream());
return buffer;
}
......@@ -13,11 +13,18 @@
#include "transformer_engine/transformer_engine.h"
/* Swizzle the scaling factor of the input tensor.
/*! \brief Swizzle the scaling factor of the input tensor.
*
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
*/
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
bool trans);
bool rowwise);
/*! \brief Swizzle the scaling factor of the input tensors.
*
* The returned swizzled scaling factor tensors should be kept alive during the GEMMs.
*/
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper> &inputs, bool rowwise);
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
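As a lifetime reminder for callers, a minimal sketch (launch_grouped_gemm and tensor_wrappers are placeholders, not functions in this diff):
// Keep the returned buffer alive until the GEMMs that consume the swizzled
// scale-inverses have been enqueued on the stream.
std::optional<at::Tensor> swizzle_buf =
    multi_tensor_swizzle_scaling_factors(tensor_wrappers, /*rowwise=*/true);
launch_grouped_gemm(tensor_wrappers);  // placeholder for the actual GEMM launch
// swizzle_buf may be released only after this point.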
......@@ -981,6 +981,15 @@ def _all_gather_fp8(
return out, handle
def _get_quantizer_format(quantizer: Quantizer) -> Optional[bool]:
"""Get quantizer format."""
if isinstance(quantizer, DebugQuantizer):
quantizer = quantizer.parent_quantizer
if isinstance(quantizer, Float8BlockQuantizer):
return quantizer.all_gather_usage
return None
def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
"""Make quantizer compact"""
_quantizer = quantizer
......@@ -1129,6 +1138,10 @@ def _all_gather_fp8_blockwise(
"Dequantizing and requantizing to Float8BlockwiseQTensor."
)
inp = quantizer(inp.dequantize())
# Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
quantizer.all_gather_usage = orig_all_gather_usage
# About to start network communication; make sure data is in the compact format
......@@ -1138,9 +1151,6 @@ def _all_gather_fp8_blockwise(
f"but found data_format={inp._data_format}"
)
# Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
# Coalesce NCCL collectives
with torch.distributed._coalescing_manager(
group=process_group,
......@@ -1216,14 +1226,12 @@ def _all_gather_mxfp8(
if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device
dtype = inp._rowwise_data.dtype
elif inp._columnwise_data is not None:
in_shape = inp._columnwise_data.size()
device = inp._columnwise_data.device
dtype = inp._columnwise_data.dtype
else:
raise ValueError("Got MXFP8 input tensor without any data")
dtype = torch.bfloat16
dtype = torch.bfloat16 # Guess high-precision dtype.
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
......@@ -1343,6 +1351,44 @@ def gather_along_first_dim(
inp = quantizer(inp)
return inp, None
# Debug case - call gather_along_first_dim on each tensor
if isinstance(inp, DebugQuantizedTensor):
out_obj = DebugQuantizedTensor(
rowwise_gemm_tensor=inp.rowwise_gemm_tensor,
columnwise_gemm_tensor=inp.columnwise_gemm_tensor,
quantizer=inp.quantizer,
layer_name=inp._layer_name,
tensor_name=inp._tensor_name,
)
rowwise = inp.get_tensor(False)
columnwise = inp.get_tensor(True)
# shapes
final_quantizer = (
None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
)
rowwise_total = None
if rowwise is not None:
rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[
0
]
out_obj.rowwise_gemm_tensor = rowwise_total
if rowwise is not columnwise:
final_quantizer_columnwise = (
None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
)
columnwise_total = None
if columnwise is not None:
columnwise_total, _ = gather_along_first_dim(
columnwise, process_group, False, final_quantizer_columnwise
)
out_obj.columnwise_gemm_tensor = columnwise_total
else:
# Sometimes the same object is used both for rowwise and columnwise gemms,
# and we want to avoid double all-gathers.
out_obj.columnwise_gemm_tensor = out_obj.rowwise_gemm_tensor
return out_obj, None
# Output tensor dims
out_shape = list(inp.size())
out_shape[0] *= world_size
......@@ -1380,34 +1426,6 @@ def gather_along_first_dim(
out_shape=out_shape,
)
# Debug case - call gather_along_first_dim on each tensor
if isinstance(inp, DebugQuantizedTensor):
out_obj = inp
rowwise = inp.get_tensor(False)
columnwise = inp.get_tensor(True)
final_quantizer = (
None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
)
# Temporary fix for TP communication of Float8BlockwiseQTensorBase
if isinstance(rowwise, Float8BlockwiseQTensorBase):
rowwise = inp._original_tensor
rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
out_obj.rowwise_gemm_tensor = rowwise_total
if rowwise is not columnwise:
final_quantizer_columnwise = (
None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
)
# Temporary fix for TP communication of Float8BlockwiseQTensorBase
if isinstance(columnwise, Float8BlockwiseQTensorBase):
columnwise = inp._original_tensor
columnwise_total, _ = gather_along_first_dim(
columnwise, process_group, False, final_quantizer_columnwise
)
out_obj.columnwise_gemm_tensor = columnwise_total
else:
out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor
return out_obj, None
# High-precision communication for quantized tensors
if quantizer is not None:
warnings.warn(
......@@ -1418,6 +1436,7 @@ def gather_along_first_dim(
inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format
compact = _get_quantizer_format(quantizer)
_set_quantizer_format(quantizer, compact=False)
out = torch.empty(
out_shape,
......@@ -1427,6 +1446,7 @@ def gather_along_first_dim(
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
out = quantizer(out)
_set_quantizer_format(quantizer, compact=compact)
return out, None
# Dequantize quantized tensor if not supported
......
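The fallback path above saves the Float8BlockQuantizer's output format, forces GEMM_READY output for the re-quantization after the high-precision all-gather, and then restores the saved format. A minimal sketch of that save/restore pattern, assuming only the `_get_quantizer_format` / `_set_quantizer_format` helpers introduced in this change are in scope (`requantize_gathered` is a hypothetical name):

```python
# Sketch only: re-quantize a gathered high-precision tensor with GEMM_READY output,
# then restore the quantizer's previous all_gather_usage setting.
def requantize_gathered(out_hp, quantizer):
    saved = _get_quantizer_format(quantizer)         # None unless a Float8BlockQuantizer
    _set_quantizer_format(quantizer, compact=False)  # force GEMM_READY format
    try:
        return quantizer(out_hp)
    finally:
        _set_quantizer_format(quantizer, compact=saved)
```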
......@@ -80,14 +80,26 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
def check_recipe_support(recipe: Recipe) -> None:
"""Check if the given recipe is supported."""
recipe_supported = True
unsupported_reason = ""
if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)):
recipe_supported, unsupported_reason = check_fp8_support()
elif isinstance(recipe, Float8BlockScaling):
recipe_supported, unsupported_reason = check_fp8_block_scaling_support()
elif isinstance(recipe, MXFP8BlockScaling):
recipe_supported, unsupported_reason = check_mxfp8_support()
assert recipe_supported, unsupported_reason
def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args."""
if check_mxfp8_support()[0]:
# This is a temporary restriction until MXFP8 is supported for all
# gemm layouts.
if get_device_compute_capability() >= (12, 0):
return Float8BlockScaling()
return MXFP8BlockScaling()
if get_device_compute_capability() >= (12, 0):
# This is a temporary restriction until MXFP8 is supported for all gemm layouts.
return Float8CurrentScaling()
return DelayedScaling()
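For illustration, a hedged sketch of how the new checks surface to users (assumes Transformer Engine is installed with a CUDA device; the recipe actually returned depends on compute capability):

```python
# Sketch: the default recipe is picked from device support, and an unsupported
# recipe now fails fast when fp8_autocast is entered with enabled=True.
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import get_default_fp8_recipe

recipe = get_default_fp8_recipe()  # MXFP8BlockScaling, Float8BlockScaling,
                                   # Float8CurrentScaling, or DelayedScaling
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    ...  # check_recipe_support(recipe) is asserted on entry
```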
......@@ -664,6 +676,8 @@ def fp8_autocast(
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
"""
if enabled:
check_recipe_support(fp8_recipe)
fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
FP8GlobalStateManager.fp8_autocast_enter(
enabled=enabled,
......
......@@ -21,6 +21,8 @@ from .fp8 import (
from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule
from .ops.op import BasicOperation
from .ops import Sequential
from .ops.fuser import OperationFuser
from .utils import make_weak_ref
__all__ = ["make_graphed_callables"]
......@@ -44,7 +46,7 @@ def set_capture_end() -> None:
_IS_GRAPH_CAPTURING = False
def is_graph_capturing() -> None:
def is_graph_capturing() -> bool:
"""Return whether within `make_graphed_callables`."""
return _IS_GRAPH_CAPTURING
......@@ -177,24 +179,17 @@ def _make_graphed_callables(
assert isinstance(
sample_args, list
), "sample_args must be a list for _reuse_graph_input_output_buffers."
len_args = len(sample_args[0])
for i, arg in enumerate(sample_args):
assert len_args == len(
arg
), "Arguments must have same length and shape for `_reuse_graph_input_output_buffers`."
len_kwargs = len(sample_kwargs[0])
assert isinstance(
sample_kwargs, list
), "sample_kwargs must be a list for _reuse_graph_input_output_buffers."
for i, kwarg in enumerate(sample_kwargs):
assert len_kwargs == len(kwarg), (
"Keyword arguments must have same length and shape for"
" `_reuse_graph_input_output_buffers`."
)
# Reorganize args and kwargs for input tensor reuse.
# fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples.
# Each tuple contains the sample key signature and its fwd_idx. When we finish a backward
# chunk, we pop the corresponding fwd_idx and push to the consumed_sample_q.
# consumed_sample_q is keyed by the sample key signature. The value is a queue of the
# fwd_idx whose backward has been called so that we can reuse the same static buffers.
# In this way, we can reuse the same static input buffers for the non-overlapping samples
# with the same input signature.
fwd_sample_qs = {}
consumed_sample_q = []
consumed_sample_q = {}
fwd_idx = [0] * num_model_chunks
for c_id in _order:
m_chunk = abs(c_id) - 1
......@@ -206,10 +201,21 @@ def _make_graphed_callables(
fwd_sample_idx = [
sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk])
]
fwd_sample_qs[m_chunk] = fwd_sample_qs.get(m_chunk, []) + fwd_sample_idx
if m_chunk not in fwd_sample_qs:
fwd_sample_qs[m_chunk] = []
for per_callable_fwd_idx in fwd_sample_idx:
if consumed_sample_q:
reuse_fwd_idx = consumed_sample_q.pop(0)
sample_args_keys = tuple(
(t.shape, t.dtype, t.layout) for t in sample_args[per_callable_fwd_idx]
)
sample_kwargs_keys = tuple(
(k, v.shape, v.dtype, v.layout)
for k, v in sorted(sample_kwargs[per_callable_fwd_idx].items())
)
sample_keys = sample_args_keys + sample_kwargs_keys
fwd_sample_qs[m_chunk].append((sample_keys, per_callable_fwd_idx))
if consumed_sample_q.get(sample_keys, []):
reuse_fwd_idx = consumed_sample_q[sample_keys].pop(0)
sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx]
sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx]
fwd_idx[m_chunk] += 1
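The reuse above keys each callable's static input buffers by the shape, dtype, and layout of every sample arg plus the sorted kwarg names. A self-contained sketch of that keying (tensor sizes are illustrative only):

```python
# Sketch: two callables may share static input buffers only when their sample
# signatures (per-tensor shape/dtype/layout, plus sorted kwarg names) match.
import torch

def sample_key(args, kwargs):
    arg_keys = tuple((t.shape, t.dtype, t.layout) for t in args)
    kwarg_keys = tuple((k, v.shape, v.dtype, v.layout) for k, v in sorted(kwargs.items()))
    return arg_keys + kwarg_keys

a = sample_key((torch.empty(32, 1024),), {"mask": torch.empty(32, 32)})
b = sample_key((torch.empty(32, 1024),), {"mask": torch.empty(32, 32)})
assert a == b  # identical signatures -> eligible for buffer reuse
```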
......@@ -217,7 +223,12 @@ def _make_graphed_callables(
num_consumed_samples = min(
len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk]
)
consumed_sample_q += fwd_sample_qs[m_chunk][:num_consumed_samples]
for sample_keys, per_callable_fwd_idx in fwd_sample_qs[m_chunk][
:num_consumed_samples
]:
if sample_keys not in consumed_sample_q:
consumed_sample_q[sample_keys] = []
consumed_sample_q[sample_keys].append(per_callable_fwd_idx)
fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:]
if fp8_weight_caching:
......@@ -338,6 +349,16 @@ def _make_graphed_callables(
def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
if isinstance(module, TransformerEngineBaseModule):
visited_te_modules.add(module)
# If forward is called on a BasicOperation directly the hook will run
elif isinstance(module, BasicOperation):
visited_te_modules.add(module)
# If forward is called on a te.ops.Sequential it is not called on its constituent ops
elif isinstance(module, Sequential):
assert module._module_groups is not None, "Should have been initialized by warmup"
for module_group in module._module_groups:
if isinstance(module_group, OperationFuser):
for basic_op in module_group._basic_ops:
visited_te_modules.add(basic_op)
# Run warmup and do the above filtering.
with torch.cuda.stream(torch.cuda.Stream()):
......@@ -410,8 +431,8 @@ def _make_graphed_callables(
per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
static_grad_outputs = None
previous_per_callable_bwd_idx = None
static_grad_outputs_dict = {}
previous_chunk_last_callable_bwd_idx = None
for c_id in _order:
if c_id > 0:
# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
......@@ -434,6 +455,7 @@ def _make_graphed_callables(
else:
# Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
m_chunk = -c_id - 1
previous_per_callable_bwd_idx = None
for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
......@@ -442,9 +464,21 @@ def _make_graphed_callables(
static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx]
# For now, assumes all static_outputs require grad
if not _reuse_graph_input_output_buffers or static_grad_outputs is None:
if _reuse_graph_input_output_buffers:
# Note for _reuse_graph_input_output_buffers: grad output is only used
# within backward, so we can reuse the same static buffers every time.
static_grad_outputs_keys = tuple(
(o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
)
if static_grad_outputs_keys in static_grad_outputs_dict:
static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None
for o in static_outputs
)
static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
......@@ -484,19 +518,29 @@ def _make_graphed_callables(
per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref(
static_outputs
)
# Weak ref the static grad inputs of the previous backward pass.
# Note: After a backward pass, we assume Mcore will send the
# grad input to another pipeline parallel rank and that the
# communication is finished before the end of the next backward
# pass.
# Weak ref the static grad inputs of the previous backward pass within the
# same chunk.
if previous_per_callable_bwd_idx is not None:
per_callable_static_grad_inputs[previous_per_callable_bwd_idx] = (
make_weak_ref(
per_callable_static_grad_inputs[previous_per_callable_bwd_idx]
)
idx = previous_per_callable_bwd_idx
per_callable_static_grad_inputs[idx] = make_weak_ref(
per_callable_static_grad_inputs[idx]
)
previous_per_callable_bwd_idx = per_callable_bwd_idx
# Weak ref the static grad inputs of the previous chunk's last backward
# pass.
# Note: After a chunk's backward pass, we assume Mcore will send the grad
# input to another pipeline parallel rank and that the communication is
# finished before the end of the next chunk's backward pass.
if l_no == 0:
if previous_chunk_last_callable_bwd_idx is not None:
idx = previous_chunk_last_callable_bwd_idx
per_callable_static_grad_inputs[idx] = make_weak_ref(
per_callable_static_grad_inputs[idx]
)
previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx
bwd_idx[m_chunk] += 1
else:
# Capture forward graphs
......@@ -674,15 +718,13 @@ def _make_graphed_callables(
# run the graph, otherwise run the original forward method
if func.training == graph_training_state:
# Set the FP8 group from global amax reduction.
if FP8GlobalStateManager.is_fp8_enabled():
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
for m in func.modules():
if (
isinstance(m, TransformerEngineBaseModule)
and FP8GlobalStateManager.is_fp8_enabled()
):
if m not in visited_te_modules:
# Only Set the FP8 meta for the modules included by forward
continue
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if isinstance(m, TransformerEngineBaseModule):
from transformer_engine.pytorch.attention.dot_product_attention import (
DotProductAttention,
)
......@@ -699,6 +741,18 @@ def _make_graphed_callables(
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m.fp8_meta,
)
elif isinstance(m, BasicOperation):
for mode in ("forward", "backward"):
if m.num_quantizers(mode):
m._fp8_metas[mode][
"fp8_group"
] = FP8GlobalStateManager.get_fp8_group()
m._fp8_metas[mode][
"recipe"
] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m._fp8_metas[mode],
)
return graphed(*user_args, **user_kwargs)
return orig_fwd(*user_args, **user_kwargs)
......@@ -721,7 +775,7 @@ def _make_graphed_callables(
def save_fp8_tensors(
modules: Iterable[torch.nn.Module],
fp8_recipe: Recipe,
fp8_recipe: Optional[Recipe],
) -> Optional[List[Any]]:
"""
Returns the FP8 tensors for all modules
......@@ -740,7 +794,7 @@ def save_fp8_tensors(
m.adjust_amax_history_length(fp8_recipe.amax_history_len)
module_tensors = m.get_fp8_meta_tensors()
elif isinstance(m, BasicOperation):
m.pre_first_forward(recipe=fp8_recipe)
m.reset_recipe_state(recipe=fp8_recipe)
module_tensors = m._save_fp8_metas()
fp8_tensors.append(module_tensors)
return fp8_tensors
......@@ -777,7 +831,7 @@ def make_graphed_callables(
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
fp8_enabled: bool = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None,
fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None,
......@@ -828,7 +882,7 @@ def make_graphed_callables(
data of fp8 tensors even when executing without fp8 enabled. This is
useful for saving an inference ready fp8 checkpoint while training
using a higher precision.
fp8_recipe: recipe.DelayedScaling, default = `None`
fp8_recipe: Recipe, default = `None`
recipe used for FP8 training.
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the fp8 tensors
......@@ -844,7 +898,10 @@ def make_graphed_callables(
"""
set_capture_start()
fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
if fp8_enabled and fp8_recipe is None:
fp8_recipe = get_default_fp8_recipe()
elif not fp8_enabled:
fp8_recipe = None
# Handle single module.
just_one_callable = False
......
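A hedged usage sketch of the updated recipe handling (module, shapes, and device are illustrative; requires CUDA):

```python
# Sketch: fp8_recipe is only resolved when fp8_enabled=True; with FP8 disabled
# it is forced to None instead of silently instantiating a default recipe.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
sample_args = (torch.randn(32, 1024, device="cuda"),)

graphed = te.make_graphed_callables(layer, sample_args, fp8_enabled=False)
out = graphed(*sample_args)
```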
......@@ -136,30 +136,43 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
@jit_fuser
def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor:
"""L2 normalization fused - inference version"""
x_squared = x.pow(2)
x_fp32 = x.float()
x_squared = x_fp32.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
return x * rsqrt_norm
y_fp32 = x_fp32 * rsqrt_norm
return y_fp32.to(x.dtype)
@jit_fuser
def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
"""L2 normalization fused - training version that returns intermediate values"""
x_squared = x.pow(2)
x_fp32 = x.float()
x_squared = x_fp32.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
y = x * rsqrt_norm
l2_norm_squared_eps = l2_norm_squared + eps
rsqrt_norm = torch.rsqrt(l2_norm_squared_eps)
y_fp32 = x_fp32 * rsqrt_norm
y = y_fp32.to(x.dtype)
return y, rsqrt_norm
@jit_fuser
def l2normalization_backward_fused_(
grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
grad_output: torch.Tensor,
x: torch.Tensor,
rsqrt_norm: torch.Tensor,
eps: float,
) -> torch.Tensor:
"""L2 normalization backward fused"""
x_dy_sum = (x * grad_output).sum(dim=-1, keepdim=True)
x_norm_squared = x.pow(2).sum(dim=-1, keepdim=True) + eps
return rsqrt_norm * (grad_output - x * x_dy_sum / x_norm_squared)
x_fp32 = x.float()
grad_output_fp32 = grad_output.float()
x_dy_sum = (x_fp32 * grad_output_fp32).sum(dim=-1, keepdim=True)
x_squared = x_fp32.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
x_norm_squared = l2_norm_squared + eps
dx_fp32 = rsqrt_norm * (grad_output_fp32 - x_fp32 * x_dy_sum / x_norm_squared)
return dx_fp32.to(x.dtype)
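A small sanity sketch of the fp32-upcast semantics (standalone reference only; comparing against the fused helpers above assumes they are importable, e.g. from transformer_engine.pytorch.jit):

```python
# Sketch: the fused kernels now accumulate in fp32 and cast back to the input
# dtype, so a plain fp32 reference should match closely even for bf16 inputs.
import torch

def l2norm_reference(x: torch.Tensor, eps: float) -> torch.Tensor:
    x32 = x.float()
    y32 = x32 * torch.rsqrt(x32.pow(2).sum(dim=-1, keepdim=True) + eps)
    return y32.to(x.dtype)

x = torch.randn(4, 8, dtype=torch.bfloat16)
y_ref = l2norm_reference(x, 1e-6)
# torch.testing.assert_close(y_ref, l2normalization_fused_(x, 1e-6))  # if the fused helper is in scope
```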
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
......@@ -193,7 +206,10 @@ def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor
def l2normalization_backward_fused(
grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
grad_output: torch.Tensor,
x: torch.Tensor,
rsqrt_norm: torch.Tensor,
eps: float,
) -> torch.Tensor:
"""Disable native AMP for l2normalization_backward_fused_"""
with gpu_autocast_ctx(enabled=False):
......
......@@ -42,11 +42,12 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import torch_get_autocast_gpu_dtype
from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled
from torch.utils.cpp_extension import IS_HIP_EXTENSION
__all__ = ["initialize_ub", "destroy_ub"]
......@@ -173,7 +174,7 @@ def initialize_ub(
```
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
"fc2_fprop", "fc2_dgrad"]`.
"fc2_fprop", "fc2_wgrad"]`.
bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are
......@@ -272,9 +273,11 @@ def initialize_ub(
"qkv_fprop",
"qkv_dgrad",
"proj_dgrad",
"proj_wgrad",
"fc1_fprop",
"fc1_dgrad",
"fc2_dgrad",
"fc2_wgrad",
]
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
......@@ -284,29 +287,34 @@ def initialize_ub(
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop", "fc2_fprop"],
"pipeline": [],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
"external": ["proj_wgrad", "fc2_wgrad"],
}
elif bool(int(os.getenv("NVTE_PROJ_NO_PIPELINE_OVERLAP", "0"))):
methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop"],
"pipeline": ["fc2_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
"external": ["proj_wgrad", "fc2_wgrad"],
}
elif bool(int(os.getenv("NVTE_FC2_NO_PIPELINE_OVERLAP", "0"))):
methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "fc2_fprop"],
"pipeline": ["proj_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
"external": ["proj_wgrad", "fc2_wgrad"],
}
else:
methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
"pipeline": ["proj_fprop", "fc2_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
"external": ["proj_wgrad", "fc2_wgrad"],
}
# AG-RS overlap pairs of layers forming a tensor-parallel block
ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"}
global layers_atomic_ring_exchange
layers_atomic_ring_exchange = []
......@@ -360,7 +368,7 @@ def initialize_ub(
"Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
)
assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
if method == "bulk":
if method in ("bulk", "external"):
warnings.warn(
f"At {name}, atoimic GEMM not is supported for a bulk overlap."
"Defaulting to `atomic_gemm=False`."
......@@ -389,6 +397,16 @@ def initialize_ub(
if atomic_gemm and method == "ring_exchange":
assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message
if name in external_gemm_to_overlap:
assert method == "external", (
f"At {name}, `external` overlap method is specified, but the selected method is"
f" {method}"
)
assert external_gemm_to_overlap[name] in methods["ring_exchange"], (
f"At {name}, `external` overlap method is specified, but the external gemm"
f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
)
buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
if method == "ring_exchange":
ub_obj = tex.CommOverlapP2P(
......@@ -437,7 +455,9 @@ def initialize_ub(
new_method = ub_cfgs[name]["method"]
methods[new_method].append(name)
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
for name in (
methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
):
if name in remove_ag_gemm_dgrad:
continue
ub_cfg = get_default_config(name)
......@@ -610,6 +630,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.name = None
self.next_iter_when_debug_should_be_run = 0
self.fp8_initialized = False
self.fp8 = False
self.fp8_calibration = False
......@@ -628,6 +649,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fsdp_group = None
self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
self.activation_dtype: Optional[torch.dtype] = None
self.wgrad_accumulation_and_reduce_hooks = []
if not TEDebugState.debug_enabled:
TEDebugState.initialize()
......@@ -1339,22 +1361,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Try getting workspace from cache
out = None
if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None)
if quantizer is not None and isinstance(out, MXFP8TensorBase):
# Reset cache if workspace is invalid
if out is not None and quantizer is not None:
reset_cache = False
if isinstance(out, Float8TensorBase):
if (
not is_non_tn_fp8_gemm_supported()
and quantizer.columnwise_usage
and out._transpose is None
):
reset_cache = True
elif isinstance(out, MXFP8TensorBase):
if quantizer.rowwise_usage and out._rowwise_data is None:
out = None
del self._fp8_workspaces[cache_name]
reset_cache = True
elif quantizer.columnwise_usage and out._columnwise_data is None:
reset_cache = True
if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
reset_cache = True
if reset_cache:
out = None
del self._fp8_workspaces[cache_name]
is_debug = isinstance(quantizer, DebugQuantizer)
is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
if is_debug != is_out_debug_tensor:
out = None
# Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# for models initialized with Fp8 primary weights.
......@@ -1421,6 +1451,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook):
"""
This method is used to manually control weight gradient accumulation and reduction.
It should be called before the backward() method.
Set skip_wgrad_accumulation_and_reduce to True to skip weight gradient accumulation
and reduction in backward(), and register the wgrad_accumulation_and_reduce hook
to be called in the backward_dw() method.
"""
self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook)
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
......@@ -1431,14 +1471,58 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
(wgrad, bgrad), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation:
unfused_weights = [getattr(self, name) for name in self.weight_names]
weight_tensor = noop_cat(unfused_weights)
weight_tensor = noop_cat(self._get_weight_tensors())
if weight_tensor.grad is None:
weight_tensor.grad = wgrad.to(weight_tensor.dtype)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
if bias_tensor.grad is None:
bias_tensor.grad = bgrad.to(bias_tensor.dtype)
del wgrad
del bgrad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
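A hedged usage sketch of the new hook mechanism (the hook body is hypothetical, and enabling delayed wgrad compute via the `delay_wgrad_compute` argument is an assumption based on the surrounding code):

```python
# Sketch: with delayed wgrad compute, a user-supplied hook (hypothetical
# `reduce_grads`) runs at the end of backward_dw(), after the delayed weight
# gradients have been materialized.
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024, delay_wgrad_compute=True)

def reduce_grads():
    ...  # e.g. launch gradient accumulation/reduction for this layer's weights

layer.register_wgrad_accumulation_and_reduce_hooks(reduce_grads)
# ... forward and backward as usual, then:
layer.backward_dw()  # pops the stored wgrad, fills .grad, then calls the hooks
```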
def is_debug_iter(self) -> bool:
"""
This function checks whether debug should be enabled for this layer.
"""
debug = TEDebugState.debug_enabled
if not debug:
return False
self._validate_name()
# If the layer is run for the first time in a new iteration,
# we need to check whether debug should be enabled for this layer -
# in previous iterations the debug features may have reported that
# no feature will be active for this layer for several upcoming iterations.
started_new_iteration = TEDebugState.get_iteration() != getattr(
self, "debug_last_iteration", None
)
if started_new_iteration:
if self.next_iter_when_debug_should_be_run is None:
debug = False
else:
debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
self.debug_last_iteration = TEDebugState.get_iteration()
return debug
def no_debug_features_active(self, quantizers):
"""
Checks if any debug feature is active for this layer.
"""
run_current = any_feature_enabled(quantizers)
# Some features report that they will not be enabled for a particular layer
# for several upcoming iterations.
self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)
if not run_current:
return True
if self.primary_weights_in_fp8:
raise RuntimeError("FP8 weights are not supported in debug mode.")
return False
def _validate_name(self):
"""
......@@ -1446,6 +1530,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
If no name is assigned, it creates a default name with layer count as the variable.
"""
if self.name is not None:
return
assert TEDebugState.debug_enabled
import nvdlfw_inspect.api as debug_api
......@@ -1494,29 +1580,3 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
" Please check the recipes assigned during fp8_model_init() and"
" fp8_autocast() calls."
)
def _turn_off_unsupported_features_in_debug(self):
if (
getattr(self, "ub_bulk_wgrad", False)
or getattr(self, "ub_bulk_dgrad", False)
or getattr(self, "ub_overlap_ag", False)
or getattr(self, "ub_overlap_rs_dgrad", False)
or getattr(self, "ub_overlap_rs", False)
):
import nvdlfw_inspect.api as debug_api
debug_api.log_message(
"UserBuffers are not supported in debug module. "
"Using UB optimization will not affect the debug module. ",
level=logging.WARNING,
)
if hasattr(self, "ub_bulk_wgrad"):
self.ub_bulk_wgrad = None
if hasattr(self, "ub_bulk_dgrad"):
self.ub_bulk_dgrad = None
if hasattr(self, "ub_overlap_ag"):
self.ub_overlap_ag = None
if hasattr(self, "ub_overlap_rs_dgrad"):
self.ub_overlap_rs_dgrad = None
if hasattr(self, "ub_overlap_rs"):
self.ub_overlap_rs = None
......@@ -667,6 +667,12 @@ class GroupedLinear(TransformerEngineBaseModule):
self.reset_parameters(defer_init=device == "meta")
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
for i in range(self.num_gemms):
if name in (f"weight{i}", f"bias{i}"):
param.skip_backward_post_hook = True
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe)
......@@ -747,7 +753,9 @@ class GroupedLinear(TransformerEngineBaseModule):
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
......@@ -822,19 +830,21 @@ class GroupedLinear(TransformerEngineBaseModule):
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
wgrad_list = tensor_list[2]
weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fuse_wgrad_accumulation:
for i in range(self.num_gemms):
weight_param = getattr(self, f"weight{i}")
if weight_param.grad is None:
weight_param.grad = wgrad_list[i].to(weight_param.dtype)
if weight_params[i].grad is None:
weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype)
if self.use_bias:
for i in range(self.num_gemms):
bias_param = getattr(self, f"bias{i}")
if bias_param.grad is None:
bias_param.grad = grad_biases_[i].to(bias_param.dtype)
if bias_params[i].grad is None:
bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype)
del grad_biases_
del wgrad_list
del tensor_list
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
......
......@@ -62,9 +62,7 @@ from ..tensor.quantized_tensor import (
restore_from_saved,
)
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
......@@ -169,6 +167,13 @@ class _LayerNormLinear(torch.autograd.Function):
with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Configure Userbuffers communication (comm+GEMM overlap)
if debug: # turn off userbuffers in debug mode
ub_overlap_ag_fprop = False
ub_overlap_rs_fprop = False
ub_overlap_ag_dgrad = False
ub_overlap_rs_dgrad = False
ub_bulk_wgrad = False
ub_bulk_dgrad = False
ub_obj = None
ub_type = None
ub_overlap_ag_fprop = (
......@@ -186,9 +191,7 @@ class _LayerNormLinear(torch.autograd.Function):
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if with_input_all_gather and isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
......@@ -645,7 +648,7 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer = None
if ctx.input_quantizer is not None:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if quantizer.supports_only_rowwise_all_gather():
# If data is in FP8, we compute FP8 transposes manually
quantizer.set_usage(rowwise=True, columnwise=False)
else:
......@@ -762,27 +765,36 @@ class _LayerNormLinear(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
# overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output, mxfp8_grad_output_work = gather_along_first_dim(
# We use the send stream to copy into the userbuffers.
# This is the same stream that we will use to access the data in the AG,
# so we don't need to add any syncs yet.
with torch.cuda.stream(dgrad_send_stream):
grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_overlap_wgrad,
grad_outputs[0],
ctx.grad_output_quantizer,
ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
)
# Synchronize with the main stream
mxfp8_grad_output_work.wait()
# Allgather grad_outputs[0] using the dgrad streams so we can overlap with the dgrad GEMM
tex.bulk_overlap_ag_with_external_gemm(
ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
)
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
......@@ -1177,8 +1189,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if tp_group is None:
self.tp_size = tp_size
......@@ -1396,6 +1406,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
if name in self.weight_names or name in self.bias_names:
param.skip_backward_post_hook = True
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe)
......@@ -1480,9 +1495,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
"""
if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
debug = self.is_debug_iter()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
......@@ -1498,7 +1512,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
fp8_grad = True
with self.prepare_forward(
with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
......@@ -1511,13 +1527,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if not any_feature_enabled(quantizers):
# If no feature is used, then run faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
if self.no_debug_features_active(quantizers):
debug = False
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
quantizers = self._get_quantizers(fp8_output, fp8_grad)
(
input_quantizer,
......
......@@ -69,7 +69,6 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
Quantizer,
prepare_for_saving,
......@@ -79,7 +78,6 @@ from ..cpp_extensions import (
general_gemm,
)
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ...debug.pytorch.utils import any_feature_enabled
from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["LayerNormMLP"]
......@@ -224,6 +222,12 @@ class _LayerNormMLP(torch.autograd.Function):
device = inp.device
# Configure Userbuffers communication (comm+GEMM overlap)
if debug: # turn off userbuffers in debug mode
ub_overlap_ag = False
ub_overlap_rs = False
ub_overlap_rs_dgrad = False
ub_bulk_wgrad = False
ub_bulk_dgrad = False
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled
......@@ -239,9 +243,7 @@ class _LayerNormMLP(torch.autograd.Function):
if fc1_input_quantizer is None:
raise ValueError("Missing quantizer for FC1 input tensor")
fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input)
if sequence_parallel and isinstance(
fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
if sequence_parallel and fc1_input_quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False)
......@@ -850,26 +852,37 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
# overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = (
ub_obj_fc2_dgrad.get_communication_stream()
)
ub_obj_fc2_wgrad = get_ub("fc2_wgrad")
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_fc2_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output, mxfp8_fc2_grad_output_work = gather_along_first_dim(
# We use the send stream to copy into the userbuffers.
# This is the same stream that we will use to access the data in the AG,
# so we don't need to add any syncs yet.
with torch.cuda.stream(dgrad_send_stream):
grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_fc2_wgrad,
grad_outputs[0],
ctx.fc2_grad_output_quantizer,
ctx.tp_group,
async_op=True,
quantizer=ctx.fc2_grad_output_quantizer,
)
# Synchronize with the main stream
mxfp8_fc2_grad_output_work.wait()
# Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm
tex.bulk_overlap_ag_with_external_gemm(
ub_obj_fc2_wgrad, dgrad_send_stream, dgrad_recv_stream
)
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
......@@ -1541,9 +1554,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
)
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if tp_group is None:
......@@ -1660,6 +1670,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
warmup_jit_bias_gelu_all_dtypes(
self.size_per_partition, seq_length, micro_batch_size
)
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
if name in ["fc1_weight", "fc2_weight", "fc1_bias", "fc2_bias"]:
param.skip_backward_post_hook = True
# These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
......@@ -1742,9 +1756,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
"""
if is_in_onnx_export_mode():
return self.onnx_forward(inp)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
debug = self.is_debug_iter()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
......@@ -1758,7 +1771,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
if get_ub("fc2_fprop").is_fp8_ubuf():
fp8_output = True
with self.prepare_forward(inp, num_gemms=2) as inp:
with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=2) as inp:
quantizers = (
self._get_quantizers(fp8_output)
......@@ -1766,12 +1781,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
else self._get_debug_quantizers(fp8_output)
)
if debug:
if not any_feature_enabled(quantizers):
quantizers = self._get_quantizers(fp8_output)
if self.no_debug_features_active(quantizers):
debug = False
if isinstance(self.fc1_weight, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
quantizers = self._get_quantizers(fp8_output)
# Get quantizers
(
......@@ -2169,3 +2181,5 @@ class LayerNormMLP(TransformerEngineBaseModule):
del fc2_wgrad
del fc1_wgrad
del fc1_bias_grad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
......@@ -66,12 +66,9 @@ from ..tensor.quantized_tensor import (
)
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
__all__ = ["Linear"]
......@@ -140,6 +137,12 @@ class _Linear(torch.autograd.Function):
)
# Configure Userbuffers communication (comm+GEMM overlap)
if debug: # turn off userbuffers in debug mode
ub_overlap_rs_fprop = False
ub_overlap_ag_fprop = False
ub_overlap_rs_dgrad = False
ub_bulk_wgrad = False
ub_bulk_dgrad = False
ub_obj = None
ub_type = None
if ub_overlap_rs_fprop:
......@@ -171,16 +174,19 @@ class _Linear(torch.autograd.Function):
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage(
rowwise=True, columnwise=backward_needs_input and not save_original_input
)
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
if save_original_input:
# No need for column-wise data since this
# tensor will not be cached for backward pass
input_quantizer.set_usage(columnwise=False)
own_quantized_input = False
inputmat = input_quantizer(inputmat)
own_quantized_input = True
else:
inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
......@@ -345,23 +351,30 @@ class _Linear(torch.autograd.Function):
inputmat = inp
ctx.weight_quantizer = weight_quantizer
saved_inputmat = None
ctx.backward_input_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)
if backward_needs_input:
if not save_original_input:
if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
# For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data
# can be allgathered.
# Discard unneeded data in input tensor
if (
isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
or not ctx.backward_input_needs_gather
backward_needs_input
and own_quantized_input
and isinstance(inputmat, QuantizedTensorBase)
):
if (
ctx.backward_input_needs_gather
and weight_quantizer.supports_only_rowwise_all_gather()
):
# All-gather is not supported with FP8 column-wise data
inputmat.update_usage(rowwise_usage=True, columnwise_usage=False)
else:
# Discard row-wise data since it is not needed in backward pass
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
# Cached input tensor
saved_inputmat = None
if backward_needs_input:
saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
......@@ -547,6 +560,19 @@ class _Linear(torch.autograd.Function):
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
# Adjust the quantization direction approach depending
# on whether wgrad calculations will be performed.
# NOTE: If requires_dgrad is False, disabling `rowwise` quantization and keeping `columnwise` quantization
# results in `Assertion failed: output_tensor->has_data(). Quantizing in only the columnwise direction not supported yet!`
# NOTE: For `ctx.bias is True`, selected quantize kernel errors with
# `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.`
if (
not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
):
ctx.grad_output_quantizer.set_usage(columnwise=False)
# Prepare grad output tensor
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
(
......@@ -573,11 +599,18 @@ class _Linear(torch.autograd.Function):
inputmat_total = None
inputmat_total_work = None
if ctx.requires_wgrad:
input_is_quantized = isinstance(inputmat, QuantizedTensorBase)
if ctx.fp8 or ctx.debug:
if not input_is_quantized:
if isinstance(inputmat, QuantizedTensorBase):
# Input tensor is already quantized
pass
elif ctx.debug:
# Debug quantizer will be applied immediately before wgrad GEMM
pass
else:
# Quantize input tensor
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
quantizer.set_usage(
rowwise=True,
columnwise=not ctx.backward_input_needs_gather,
......@@ -586,7 +619,7 @@ class _Linear(torch.autograd.Function):
quantizer.set_usage(rowwise=False, columnwise=True)
inputmat = quantizer(inputmat)
else:
if input_is_quantized:
if isinstance(inputmat, QuantizedTensorBase):
inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
else:
inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
......@@ -594,7 +627,7 @@ class _Linear(torch.autograd.Function):
quantizer = None
if ctx.fp8 or ctx.debug:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if quantizer.supports_only_rowwise_all_gather():
# If data is in FP8, we compute FP8 transposes manually
quantizer.set_usage(rowwise=True, columnwise=False)
else:
......@@ -726,26 +759,36 @@ class _Linear(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
# overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output, grad_output_work = gather_along_first_dim(
# We use the send stream to copy into the userbuffers.
# This is the same stream that we will use to access the data in the AG,
# so we don't need to add any syncs yet.
with torch.cuda.stream(dgrad_send_stream):
grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_overlap_wgrad,
grad_output_arg,
ctx.grad_output_quantizer,
ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
)
# Synchronize with the main stream
grad_output_work.wait()
# Allgather grad_outputs[0] using the dgrad streams so we can overlap with the dgrad GEMM
tex.bulk_overlap_ag_with_external_gemm(
ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
......@@ -1067,9 +1110,6 @@ class Linear(TransformerEngineBaseModule):
self.save_original_input = save_original_input
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if device == "meta":
......@@ -1261,6 +1301,11 @@ class Linear(TransformerEngineBaseModule):
else:
self.gemm_bias_unfused_add = False
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
if name in self.weight_names or name in self.bias_names:
param.skip_backward_post_hook = True
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe)
......@@ -1326,9 +1371,7 @@ class Linear(TransformerEngineBaseModule):
if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output)
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
debug = self.is_debug_iter()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
......@@ -1344,7 +1387,9 @@ class Linear(TransformerEngineBaseModule):
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
fp8_grad = True
with self.prepare_forward(
with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
......@@ -1356,14 +1401,11 @@ class Linear(TransformerEngineBaseModule):
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if not any_feature_enabled(quantizers):
# If no feature is used, then run faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
if self.no_debug_features_active(quantizers):
debug = False
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
quantizers = self._get_quantizers(fp8_output, fp8_grad)
(
input_quantizer,
......
......@@ -5,11 +5,13 @@
"""Single tensor operations supported by the operation fuser."""
from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
from .add_in_place import AddInPlace
from .add_extra_input import AddExtraInput
from .all_gather import AllGather
from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
from .constant_scale import ConstantScale
from .dropout import Dropout
from .identity import Identity
from .l2normalization import L2Normalization
from .layer_norm import LayerNorm
......
......@@ -11,7 +11,6 @@ from typing import Optional
import torch
import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
......@@ -71,7 +70,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
......@@ -87,14 +86,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Check input tensor
x = maybe_dequantize(input_.contiguous(), dtype)
# Check if quantized compute is enabled
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
quantizer = None
if with_quantized_compute:
quantizer = next_op_input_quantizer
# Launch kernel
y = self._activation_forward_impl(x, quantizer)
y = self._activation_forward_impl(x, next_op_input_quantizer)
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
......@@ -103,10 +96,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
x = input_quantizer(x)
# Save state for backward pass
if ctx.requires_grad:
ctx.save_for_backward(x)
ctx.with_quantized_compute = with_quantized_compute
ctx.dtype = dtype
ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
return y
......@@ -125,13 +118,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Check grad output tensor
dy = maybe_dequantize(grad_output.contiguous(), x.dtype)
# Check if quantized compute is enabled
quantizer = None
if ctx.with_quantized_compute:
quantizer = ctx.prev_op_grad_input_quantizer
# Launch kernel
dx = self._activation_backward_impl(dy, x, quantizer)
dx = self._activation_backward_impl(dy, x, ctx.prev_op_grad_output_quantizer)
# Clear input tensor if possible
clear_tensor_data(x)
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Fusible operation for in-place add."""
"""Fusible operation for adding extra input tensor."""
from __future__ import annotations
from collections.abc import Iterable
......@@ -18,16 +18,17 @@ from transformer_engine.pytorch.ops.op import (
from transformer_engine.pytorch.tensor import Quantizer
class AddInPlace(BasicOperation):
"""Add in-place
class AddExtraInput(BasicOperation):
"""Add extra input tensor
This operation requires an extra tensor input to the operation
fuser. The main input is added in-place to the extra input, and a
view of the extra input is output.
fuser. It returns the sum of the main input and the extra input.
If in_place=True, the main input is added in-place to the extra
input, and a view of the extra input is output.
This operation is considered an advanced feature and most users
are discouraged from using it. In-place operations break some
autograd assumptions and they can result in subtle, esoteric bugs.
Using this operation with in_place=True is considered an advanced
feature and most users are discouraged from using it. In-place operations
break some autograd assumptions and can result in subtle, esoteric bugs.
Compare to `MakeExtraOutput`, which does a similar operation in
the backward pass.
......@@ -37,6 +38,10 @@ class AddInPlace(BasicOperation):
# Operation expects buffer for output tensor
num_extra_inputs: int = 1
def __init__(self, *, in_place: bool = False):
super().__init__()
self._in_place = in_place
def op_forward(self, *args, **kwargs) -> None:
raise RuntimeError(
"{self.__class__.__name__} operation has "
......@@ -59,12 +64,17 @@ class AddInPlace(BasicOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
output = basic_op_extra_inputs[0][0].detach()
output += input_
extra_input = basic_op_extra_inputs[0][0]
if self._in_place:
extra_input = extra_input.detach()
extra_input += input_
output = extra_input
else:
output = extra_input + input_
return output, [()]
def fuser_backward(
......
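A hedged usage sketch of the renamed operation (passing the extra input as an additional positional argument to `te.ops.Sequential` is an assumption about the fuser's extra-input plumbing; shapes and device are illustrative):

```python
# Sketch: AddExtraInput(in_place=False) returns main_input + extra_input as a new
# tensor, avoiding the autograd caveats of the in-place path.
import torch
import transformer_engine.pytorch as te

model = te.ops.Sequential(
    te.ops.BasicLinear(1024, 1024),
    te.ops.AddExtraInput(in_place=False),
)
x = torch.randn(32, 1024, device="cuda")
residual = torch.randn(32, 1024, device="cuda")
out = model(x, residual)  # out = linear(x) + residual
```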
......@@ -40,7 +40,7 @@ class AllGather(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
out: torch.Tensor
......