Commit 87e3e56e authored by yuguo


Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
@@ -15,11 +15,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
                             num_tensors, wd, momentum, dampening, lr, nesterov, first_run,
                             wd_after_momentum, scale, at::cuda::getCurrentCUDAStream());
}

}  // namespace transformer_engine::pytorch
@@ -108,10 +108,16 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
    }
  }
  TensorWrapper unquantized_out_cu;
  py::object unquantized_out;
  if (force_unfused_kernel) {
    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
      std::tie(unquantized_out_cu, unquantized_out) =
          my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
    } else {
      NoneQuantizer q{none};
      std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
    }
  }
  TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
@@ -139,45 +145,12 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
  // Quantize output if using unfused kernel
  if (force_unfused_kernel) {
    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
      my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
    } else {
      my_quantizer->quantize(unquantized_out_cu, out_cu);
    }
  }

  return {out, py::cast(mu), py::cast(rsigma)};
@@ -269,10 +242,16 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
    }
  }
  TensorWrapper unquantized_out_cu;
  py::object unquantized_out;
  if (force_unfused_kernel) {
    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
      std::tie(unquantized_out_cu, unquantized_out) =
          my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
    } else {
      NoneQuantizer q{none};
      std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
    }
  }
  TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
@@ -300,45 +279,12 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
  // Quantize output if using unfused kernel
  if (force_unfused_kernel) {
    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
      my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
    } else {
      my_quantizer->quantize(unquantized_out_cu, out_cu);
    }
  }

  return {out, py::none(), py::cast(rsigma)};
...
@@ -111,7 +111,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
        py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
        py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
        py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
        py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
  m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
@@ -228,6 +229,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
        py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
        py::call_guard<py::gil_scoped_release>());
  m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims,
        "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"),
        py::call_guard<py::gil_scoped_release>());
  m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend,
        "Get Fused Attention backend", py::call_guard<py::gil_scoped_release>());
  m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
@@ -394,6 +398,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
        "Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());

  // Comm+GEMM Overlap
  m.def("bulk_overlap_ag_with_external_gemm",
        &transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm,
        "Bulk overlap All-Gather with a GEMM operation launched by another communicator",
        py::call_guard<py::gil_scoped_release>(), py::arg("allgather_communicator"),
        py::arg("send_stream"), py::arg("recv_stream"));

  // Data structures
  py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta")
      .def(py::init<>())
...
@@ -18,31 +18,64 @@ namespace pytorch {
at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
  init_extension();

  // Tensor dimensions
  const auto shape = getTensorShape(input);
  std::vector<int64_t> transpose_shape_int64;
  if (shape.size() > 0) {
    transpose_shape_int64.push_back(shape.back());
    for (size_t i = 0; i < shape.size() - 1; ++i) {
      transpose_shape_int64.push_back(shape[i]);
    }
  }
  const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1;
  const size_t N = shape.size() > 0 ? shape.back() : 1;

  // Output tensor
  at::Tensor out;
  if (output.has_value()) {
    out = *output;
  } else {
    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
    out = at::empty(transpose_shape_int64, opts);
  }

  // Return immediately if tensor is empty
  if (M == 0 || N == 0) {
    return out;
  }

  // Compute transpose
  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, N}, otype);
  auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{N, M}, otype);
  nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
  return out;
}
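Illustrative sketch (not part of this commit) of the flattening convention above: all but the last dimension fold into M, the last dimension becomes N, and the output buffer takes the transposed shape.

#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // A hypothetical [2, 3, 8] FP8 tensor is handled as a 6x8 matrix; the
  // transposed output buffer is allocated with shape [8, 2, 3].
  const std::vector<size_t> shape{2, 3, 8};
  const size_t N = shape.back();
  const size_t M =
      std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>()) / N;
  assert(M == 6 && N == 8);
  return 0;
}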
at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out) {
init_extension();
// Make sure input is contiguous
const auto &input = tensor.contiguous();
// Allocate output tensor if needed
if (!out) {
auto in_shape = getTensorShape(input);
NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")");
std::vector<int64_t> out_shape_int64(in_shape.begin(), in_shape.end());
out_shape_int64[0] = static_cast<int64_t>(in_shape[1]);
out_shape_int64[1] = static_cast<int64_t>(in_shape[0]);
auto opts = at::TensorOptions().dtype(input.dtype()).device(input.device());
out = at::empty(out_shape_int64, opts);
}
// Launch kernel
const TensorWrapper te_input = makeTransformerEngineTensor(input);
TensorWrapper te_output = makeTransformerEngineTensor(*out);
nvte_swap_first_dims(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
return std::move(*out);
}
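For reference, the intended semantics of swap_first_dims can be checked against a plain ATen equivalent. This is a hedged sketch for validation only; swap_first_dims_reference is a hypothetical helper, not part of the extension.

#include <ATen/ATen.h>

// Reference behaviour: out[j][i][...] == in[i][j][...] for the first two dims.
at::Tensor swap_first_dims_reference(const at::Tensor& in) {
  return in.transpose(0, 1).contiguous();
}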
}  // namespace pytorch
}  // namespace transformer_engine

@@ -12,6 +12,27 @@

namespace transformer_engine::pytorch {
namespace {
/*! @brief Transposed tensor shape
*
* The tensor is interpreted as a 2D matrix by flattening all but the
* last dimension, and then transposed.
*/
template <typename T = size_t, typename S = T>
std::vector<T> make_transpose_shape(const std::vector<S>& shape) {
std::vector<T> ret;
if (shape.size() > 0) {
ret.push_back(shape.back());
for (size_t i = 0; i < shape.size() - 1; ++i) {
ret.push_back(shape[i]);
}
}
return ret;
}
} // namespace
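A minimal standalone sketch of the shape mapping this helper performs (illustrative only; transpose_shape_sketch is a hypothetical copy of the same flatten-then-transpose convention):

#include <cassert>
#include <cstddef>
#include <vector>

// Same convention as make_transpose_shape: the last dimension moves to the front.
std::vector<size_t> transpose_shape_sketch(const std::vector<size_t>& shape) {
  std::vector<size_t> ret;
  if (!shape.empty()) {
    ret.push_back(shape.back());
    for (size_t i = 0; i + 1 < shape.size(); ++i) ret.push_back(shape[i]);
  }
  return ret;
}

int main() {
  // A [4, 5, 6] tensor is flattened to a 20x6 matrix; its transpose is stored as [6, 4, 5].
  assert((transpose_shape_sketch({4, 5, 6}) == std::vector<size_t>{6, 4, 5}));
  return 0;
}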
constexpr size_t MXFP8_BLOCK_SIZE = 32;

Quantizer::Quantizer(const py::handle& quantizer) {
@@ -37,24 +58,36 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti
  this->dtype = type;
}
std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vector<size_t>& shape,
                                                                   DType dtype) const {
  const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
  const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA);
  return create_tensor(shape, dtype, at::empty(shape_int64, opts));
}

std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vector<size_t>& shape,
                                                                   DType dtype,
                                                                   at::Tensor data) const {
TensorWrapper out_cpp;
out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape);
set_quantization_params(&out_cpp);
return {std::move(out_cpp), py::cast(data)};
}
std::pair<TensorWrapper, py::object> NoneQuantizer::convert_and_update_tensor(
py::object tensor) const {
auto tensor_pyt = tensor.cast<at::Tensor>();
TensorWrapper out_cpp;
out_cpp.set_rowwise_data(tensor_pyt.data_ptr(),
GetTransformerEngineDType(tensor_pyt.scalar_type()),
getTensorShape(tensor_pyt));
set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
}
void NoneQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
NVTE_ERROR("NoneQuantizer does not support quantization");
} }
void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
@@ -76,68 +109,180 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
}
std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype) const {
  const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  at::Tensor scale_inv = at::empty(std::vector<int64_t>{1}, opts);
  return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv));
}

std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data,
    std::optional<at::Tensor> transpose, std::optional<at::Tensor> scale_inv) const {
  using namespace pybind11::literals;

  // Initialize data tensor
  const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
  if (with_data && !data) {
    const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
    data = at::empty(shape_int64, opts);
  } else if (!with_data && data) {
    data.reset();
  }
  py::object data_py = with_data ? py::cast(*data) : py::none();

  // Initialize transpose tensor
  const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
  if (with_transpose && !transpose) {
    const auto transpose_shape = make_transpose_shape<int64_t>(shape);
    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
    transpose = at::empty(transpose_shape, opts);
  } else if (!with_transpose && transpose) {
    transpose.reset();
  }
  py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

  // Initialize scale-inverse tensor
  if (!scale_inv) {
    scale_inv = at::reciprocal(scale);
  }

  // Construct Python FP8 tensor
  py::object out_py;
  if (internal) {
    py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
    out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
                               "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
                               "quantizer"_a = this->quantizer);
  } else {
    py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
    const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
    out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
                               "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
                               "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
                               "quantizer"_a = this->quantizer);
  }

  // Construct C++ FP8 tensor
  TensorWrapper out_cpp(this->get_scaling_mode());
  if (with_data) {
    out_cpp.set_rowwise_data(data->data_ptr(), this->dtype, shape);
    out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector<size_t>{1});
  }
  if (with_transpose) {
    const auto transpose_shape = make_transpose_shape(shape);
    out_cpp.set_columnwise_data(transpose->data_ptr(), this->dtype, transpose_shape);
    out_cpp.set_columnwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32,
                                     std::vector<size_t>{1});
  }
  this->set_quantization_params(&out_cpp);
  return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> Float8Quantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor.");
// Expected buffers
const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer.");
// Extract buffers from Python tensor
auto data_py = tensor.attr("_data");
auto transpose_py = tensor.attr("_transpose");
const bool has_data = !data_py.is_none();
const bool has_transpose = !transpose_py.is_none();
NVTE_CHECK(has_data || has_transpose, "Float8Tensor has no data.");
std::optional<at::Tensor> data_tensor, transpose_tensor;
if (has_data) {
data_tensor = data_py.cast<at::Tensor>();
} }
  if (has_transpose) {
    transpose_tensor = transpose_py.cast<at::Tensor>();
  }
  at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast<at::Tensor>();

  // Tensor dimensions
  std::vector<size_t> shape;
  if (has_transpose) {
    const auto transpose_shape = getTensorShape(*transpose_tensor);
    if (transpose_shape.size() > 0) {
      for (size_t i = 1; i < transpose_shape.size(); ++i) {
        shape.push_back(transpose_shape[i]);
      }
      shape.push_back(transpose_shape.front());
    }
    if (has_data) {
      auto expected_shape = getTensorShape(*data_tensor);
      NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                 ") and transpose (shape=", transpose_shape, ") do not match");
    }
  } else {  // Already checked has_data == true
    shape = getTensorShape(*data_tensor);
  }

  // Coerce data tensor
if (has_data && !need_data) {
data_tensor.reset();
data_py = py::none();
tensor.attr("_data") = data_py;
} else if (!has_data && need_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
data_tensor = at::empty(shape_int64, opts);
data_py = py::cast(data_tensor);
tensor.attr("_data") = data_py;
}
// Coerce transpose tensor
if (has_transpose && !need_transpose) {
transpose_tensor.reset();
transpose_py = py::none();
tensor.attr("_transpose") = transpose_py;
} else if (!has_transpose && need_transpose) {
const auto transpose_shape = make_transpose_shape<int64_t>(shape);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
transpose_tensor = at::empty(transpose_shape, opts);
transpose_py = py::cast(transpose_tensor);
tensor.attr("_transpose") = transpose_py;
}
tensor.attr("_transpose_invalid") = !need_transpose;
// Coerce other attrs
tensor.attr("_fp8_dtype") = dtype;
// Construct C++ FP8 tensor
TensorWrapper out_cpp;
if (data_tensor) {
out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape);
out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
if (transpose_tensor) {
const auto transpose_shape = make_transpose_shape(shape);
out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape);
out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
}
void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
if (input.numel() == 0) {
return;
}
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
} }
Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer)
@@ -187,71 +332,223 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso
}
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype) const {
  using namespace pybind11::literals;

  // Initialize data tensor
  at::Tensor data_tensor;
  const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
  if (with_data) {
    const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
    data_tensor = at::empty(shape_int64, opts);
  }

  // Initialize transpose tensor
  at::Tensor transpose_tensor;
  const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
  if (with_transpose) {
    const auto transpose_shape = make_transpose_shape<int64_t>(shape);
    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
    transpose_tensor = at::empty(transpose_shape, opts);
  }

  // Initialize scale-inverse tensor
  at::Tensor scale_inv_tensor;
  {
    const std::vector<int64_t> scale_inv_shape = {1};
    const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    scale_inv_tensor = at::empty(scale_inv_shape, opts);
  }

  // Construct Python FP8 tensor
  py::object out_py;
  py::object data_py = with_data ? py::cast(data_tensor) : py::none();
  py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none();
  if (internal) {
    py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
    out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
                               "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
                               "quantizer"_a = this->quantizer);
  } else {
    py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
    const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
    out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
                               "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
                               "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
                               "quantizer"_a = this->quantizer);
  }

  // Construct C++ FP8 tensor
  TensorWrapper out_cpp(this->get_scaling_mode());
  if (with_data) {
    out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape);
    out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
                                  std::vector<size_t>{1});
  }
  if (with_transpose) {
    const auto transpose_shape = make_transpose_shape(shape);
    out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape);
    out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
                                     std::vector<size_t>{1});
  }
  this->set_quantization_params(&out_cpp);
  return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_hp_tensor_with_amax(
const std::vector<size_t>& shape, DType dtype) {
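  // Zero the amax buffer, then hand out an unquantized high-precision tensor whose
  // TensorWrapper carries a pointer to this quantizer's amax, so the producing kernel
  // can fill in the amax as a fused side effect.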
amax.zero_();
auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);
out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
getTensorShape(amax));
return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()),
"Float8CurrentScalingQuantizer must output to Float8Tensor.");
// Expected buffers
const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages.");
// Extract buffers from Python tensor
auto data_py = tensor.attr("_data");
auto transpose_py = tensor.attr("_transpose");
const bool has_data = !data_py.is_none();
const bool has_transpose = !transpose_py.is_none();
NVTE_CHECK(has_data || has_transpose, "Tensor has no data.");
std::optional<at::Tensor> data_tensor, transpose_tensor;
if (has_data) {
data_tensor = data_py.cast<at::Tensor>();
}
if (has_transpose) {
transpose_tensor = transpose_py.cast<at::Tensor>();
}
at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast<at::Tensor>();
// Tensor dimensions
std::vector<size_t> shape;
if (has_transpose) {
const auto transpose_shape = getTensorShape(*transpose_tensor);
if (transpose_shape.size() > 0) {
for (size_t i = 1; i < transpose_shape.size(); ++i) {
shape.push_back(transpose_shape[i]);
}
shape.push_back(transpose_shape.front());
}
if (has_data) {
auto expected_shape = getTensorShape(*data_tensor);
NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
") and transpose (shape=", transpose_shape, ") do not match");
    }
  } else {  // Already checked has_data == true
    shape = getTensorShape(*data_tensor);
  }

  // Coerce data tensor in Python tensor
if (has_data && !need_data) {
data_tensor.reset();
data_py = py::none();
tensor.attr("_data") = data_py;
} else if (!has_data && need_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
data_tensor = at::empty(shape_int64, opts);
data_py = py::cast(data_tensor);
tensor.attr("_data") = data_py;
}
// Coerce transpose tensor
if (has_transpose && !need_transpose) {
transpose_tensor.reset();
transpose_py = py::none();
tensor.attr("_transpose") = transpose_py;
} else if (!has_transpose && need_transpose) {
const auto transpose_shape = make_transpose_shape<int64_t>(shape);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
transpose_tensor = at::empty(transpose_shape, opts);
transpose_py = py::cast(transpose_tensor);
tensor.attr("_transpose") = transpose_py;
}
tensor.attr("_transpose_invalid") = !need_transpose;
// Coerce other attrs
tensor.attr("_fp8_dtype") = dtype;
// Construct C++ FP8 tensor
TensorWrapper out_cpp;
if (data_tensor) {
out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape);
out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
if (transpose_tensor) {
const auto transpose_shape = make_transpose_shape(shape);
out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape);
out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
}
void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag,
bool compute_amax) {
auto stream = at::cuda::getCurrentCUDAStream();
// Nothing to be done if input is empty
if (input.numel() == 0) {
return;
}
// Quantization configs
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
// Compute amax
if (compute_amax) {
NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); });
}
// Perform amax reduction if needed
if (with_amax_reduction) {
// allreduce amax tensor
c10d::AllreduceOptions opts;
opts.reduceOp = c10d::ReduceOp::MAX;
std::vector<at::Tensor> tensors = {amax};
NVTE_SCOPED_GIL_RELEASE({ amax_reduction_group->allreduce(tensors, opts)->wait(); });
}
// Compute scaling factor
NVTE_SCOPED_GIL_RELEASE({ nvte_compute_scale_from_amax(out.data(), quant_config, stream); });
// Cast to FP8
out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates
NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
}
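For intuition, the scaling factor derived from the (possibly all-reduced) amax is conceptually of the following form. This is a hedged host-side sketch, not the nvte_compute_scale_from_amax kernel, which additionally handles zero or non-finite amax and the force_pow_2_scales option.

#include <algorithm>

// Conceptual analogue: map the observed absolute maximum onto the largest
// representable FP8 magnitude; its reciprocal is kept as scale_inv for dequantization.
float compute_scale_sketch(float amax, float fp8_max /* e.g. 448.f for E4M3 */,
                           float amax_epsilon) {
  const float clamped_amax = std::max(amax, amax_epsilon);
  return clamped_amax > 0.f ? fp8_max / clamped_amax : 1.f;
}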
void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
this->quantize_impl(input, out, noop_flag, true);
}
void Float8CurrentScalingQuantizer::quantize_with_amax(
TensorWrapper& input, TensorWrapper& out, const std::optional<TensorWrapper>& noop_flag) {
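  // The amax was already accumulated into this quantizer's amax buffer when the
  // high-precision tensor was created, so skip the amax pass, detach the amax pointer
  // from the input, and go straight to scale computation and casting.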
NVTE_CHECK(input.get_amax().data_ptr == amax.data_ptr(),
"Input does not use the appropriate amax tensor");
input.set_amax(nullptr, DType::kFloat32, input.defaultShape);
this->quantize_impl(input, out, noop_flag, false);
}
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
@@ -280,7 +577,7 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const
}
std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype) const {
  using namespace pybind11::literals;
  std::vector<int64_t> torch_shape;
  for (auto s : shape) {
@@ -299,11 +596,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
                                   : Float8BlockScaleTensorFormat::GEMM_READY);
  if (rowwise_usage) {
    data_rowwise = at::empty(torch_shape, opts);
    auto scale_shape = get_scale_shape(shape, false);
    size_t sinv0 = scale_shape[0];
    size_t sinv1 = scale_shape[1];
@@ -373,6 +666,177 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
  return {std::move(tensor), std::move(ret)};
}
std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_tensor(
py::object tensor) const {
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
// Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
auto attr_py = tensor.attr(name);
if (attr_py.is_none()) {
return std::nullopt;
}
return attr_py.cast<at::Tensor>();
};
auto rowwise_data = get_tensor("_rowwise_data");
auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
auto columnwise_data = get_tensor("_columnwise_data");
auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
NVTE_CHECK(rowwise_data || columnwise_data, "FP8BlockwiseTensor has no data.");
// Tensor options and dimensions
at::TensorOptions opts;
at::TensorOptions scale_opts;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector<size_t> {
if (!columnwise_data) {
return std::vector<size_t>();
}
if (all_gather_usage) {
return getTensorShape(*columnwise_data);
}
std::vector<size_t> shape = getTensorShape(*columnwise_data);
std::vector<size_t> shape_transposed(shape.size());
for (size_t i = 0; i + 1 < shape.size(); ++i) {
shape_transposed[i] = shape[i + 1];
}
if (shape.size() > 0) {
shape_transposed[shape.size() - 1] = shape[0];
}
return shape_transposed;
};
std::vector<size_t> shape;
if (rowwise_data) {
shape = getTensorShape(*rowwise_data);
if (columnwise_data) {
auto expected_shape = get_columnwise_shape(all_gather_usage);
NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
") and column-wise data (shape=", expected_shape, ") do not match");
}
} else {
shape = get_columnwise_shape(all_gather_usage);
}
std::vector<int64_t> torch_shape;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
}
// Coerce row-wise data
if (rowwise_usage) {
if (!rowwise_data) {
rowwise_data = at::empty(torch_shape, opts);
tensor.attr("_rowwise_data") = *rowwise_data;
}
if (!rowwise_scale_inv) {
auto scale_shape = get_scale_shape(shape, false);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
rowwise_scale_inv =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
}
} else { // rowwise_usage == false
if (rowwise_data) {
rowwise_data.reset();
tensor.attr("_rowwise_data") = py::none();
}
if (rowwise_scale_inv) {
rowwise_scale_inv.reset();
tensor.attr("_rowwise_scale_inv") = py::none();
}
}
// Coerce column-wise data
if (columnwise_usage) {
std::vector<size_t> columnwise_shape;
std::vector<int64_t> torch_columnwise_shape;
if (torch_shape.size() > 0) {
if (!all_gather_usage) {
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]);
for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
} else {
// assert we are doing 1D scaling
NVTE_CHECK(block_scaling_dim == 1,
"Compact columnwise format is not supported for 128x128 2D block scaling.");
torch_columnwise_shape = torch_shape;
columnwise_shape = shape;
}
}
if (!columnwise_data) {
columnwise_data = at::empty(torch_columnwise_shape, opts);
tensor.attr("_columnwise_data") = *columnwise_data;
}
if (!columnwise_scale_inv) {
auto scale_shape = get_scale_shape(shape, true);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
columnwise_scale_inv =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
}
} else { // columnwise_usage == false
if (columnwise_data) {
columnwise_data.reset();
tensor.attr("_columnwise_data") = py::none();
}
if (columnwise_scale_inv) {
columnwise_scale_inv.reset();
tensor.attr("_columnwise_scale_inv") = py::none();
}
}
auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
if (rowwise_usage) {
const at::Tensor& data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
const at::Tensor& scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
void* scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr();
const auto& rowwise_shape = getTensorShape(data_rowwise);
ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape);
const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape);
}
if (columnwise_usage) {
const at::Tensor& data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
const at::Tensor& scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
void* scale_inv_colwise_dptr = scale_inv_colwise.data_ptr();
const auto& shape = getTensorShape(data_colwise);
ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape);
const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
}
set_quantization_params(&ret);
return {std::move(ret), std::move(tensor)};
}
void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
if (input.numel() == 0) {
return;
}
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
if (all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape,
                                                          bool columnwise) const {
  size_t numel = 1;
@@ -465,71 +929,204 @@ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
                                  columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::vector<size_t>& shape,
                                                                   DType dtype) const {
  using namespace pybind11::literals;

  // Tensor dimensions
  const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
  size_t flat_first_dim = 1;
  if (shape.size() > 0) {
    for (size_t i = 0; i < shape.size() - 1; ++i) {
      flat_first_dim *= shape[i];
    }
  }
  const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
  NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0,
             "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
             " (got shape=", shape, ")");
  const auto rowwise_scale_inv_shape = get_scale_shape(shape, false);
  const auto columnwise_scale_inv_shape = get_scale_shape(shape, true);

  // Allocate tensors
  at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor;
  at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor;
  const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
  if (rowwise_usage) {
    const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
                                                     rowwise_scale_inv_shape.end());
    rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
    rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
  }
  if (columnwise_usage) {
    const std::vector<int64_t> scale_inv_shape_int64(columnwise_scale_inv_shape.begin(),
                                                     columnwise_scale_inv_shape.end());
    columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
    columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
  }

  // Convert tensors to Python
  auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object {
    return need_cast ? py::cast(tensor) : py::none();
  };
  auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage);
  auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage);
  auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage);
  auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage);

  // Construct Python MXFP8 tensor
  py::object out_py;
  if (internal) {
    py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorBasePythonClass));
    out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py,
                              "columnwise_data"_a = columnwise_data_py,
                              "rowwise_scale_inv"_a = rowwise_scale_inv_py,
                              "columnwise_scale_inv"_a = columnwise_scale_inv_py,
                              "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
  } else {
    py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
    out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
                              "rowwise_data"_a = rowwise_data_py,
                              "columnwise_data"_a = columnwise_data_py,
                              "rowwise_scale_inv"_a = rowwise_scale_inv_py,
                              "columnwise_scale_inv"_a = columnwise_scale_inv_py,
                              "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
  }

  // Construct C++ MXFP8 tensor
  TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING);
  if (rowwise_usage) {
    out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), this->dtype, shape);
    out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0,
                                  rowwise_scale_inv_shape);
  }
  if (columnwise_usage) {
    out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), this->dtype, shape);
    out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0,
                                     columnwise_scale_inv_shape);
  }
  this->set_quantization_params(&out_cpp);
  return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor.");
// Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
auto attr_py = tensor.attr(name);
if (attr_py.is_none()) {
return std::nullopt;
}
return attr_py.cast<at::Tensor>();
};
auto rowwise_data = get_tensor("_rowwise_data");
auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
auto columnwise_data = get_tensor("_columnwise_data");
auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data.");
// Tensor dimensions
std::vector<size_t> shape;
if (columnwise_data) {
shape = getTensorShape(*columnwise_data);
if (rowwise_data) {
auto expected_shape = getTensorShape(*rowwise_data);
NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape,
") and column-wise data (shape=", shape, ") do not match");
}
} else { // Already checked columnwise_data_tensor == true
shape = getTensorShape(*rowwise_data);
}
// Coerce row-wise data
if (rowwise_usage) {
if (!rowwise_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
rowwise_data = at::empty(shape_int64, opts);
tensor.attr("_rowwise_data") = *rowwise_data;
}
if (!rowwise_scale_inv) {
const auto scale_inv_shape = get_scale_shape(shape, false);
const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
scale_inv_shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
}
} else { // rowwise_usage == false
if (rowwise_data) {
rowwise_data.reset();
tensor.attr("_rowwise_data") = py::none();
}
if (rowwise_scale_inv) {
rowwise_scale_inv.reset();
tensor.attr("_rowwise_scale_inv") = py::none();
}
}
// Coerce column-wise data
if (columnwise_usage) {
if (!columnwise_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
columnwise_data = at::empty(shape_int64, opts);
tensor.attr("_columnwise_data") = *columnwise_data;
}
if (!columnwise_scale_inv) {
const auto scale_inv_shape = get_scale_shape(shape, true);
const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
scale_inv_shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
}
} else { // columnwise_usage == false
if (columnwise_data) {
columnwise_data.reset();
tensor.attr("_columnwise_data") = py::none();
}
if (columnwise_scale_inv) {
columnwise_scale_inv.reset();
tensor.attr("_columnwise_scale_inv") = py::none();
}
}
// Coerce other attrs
tensor.attr("_fp8_dtype") = dtype;
// Construct C++ MXFP8 tensor
TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING);
if (rowwise_usage) {
out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, shape);
out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
getTensorShape(*rowwise_scale_inv));
}
if (columnwise_usage) {
out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, shape);
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
getTensorShape(*columnwise_scale_inv));
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
}
void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
if (input.numel() == 0) {
return;
}
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
} }
std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& shape, std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& shape,
......
...@@ -75,3 +75,98 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap ...@@ -75,3 +75,98 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
return swizzled_scale_inv; return swizzled_scale_inv;
} }
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper>& tensors, bool rowwise) {
using namespace transformer_engine::pytorch;
if (tensors.empty()) {
return std::nullopt;
}
bool all_same_scaling_mode = std::all_of(
tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) {
return val.scaling_mode() == tensors.front().scaling_mode();
});
NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same.");
if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) {
return std::nullopt;
}
std::vector<transformer_engine::TensorWrapper> wrappers;
std::vector<NVTETensor> input_tensors, output_tensors;
// Collect scale_inv shapes and calculate buffer size and offsets for scale_invs
std::vector<std::vector<size_t>> scale_inv_shapes;
std::vector<void*> scale_inv_dptrs;
size_t buffer_size = 0;
std::vector<size_t> scale_inv_offsets;
constexpr size_t scale_elem_size = 1;
for (auto& tensor : tensors) {
NVTEBasicTensor scale_inv;
if (rowwise) {
scale_inv = tensor.get_rowwise_scale_inv();
} else {
scale_inv = tensor.get_columnwise_scale_inv();
}
auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_inv_offsets.push_back(buffer_size);
buffer_size += product(scale_inv_shape) * scale_elem_size;
scale_inv_shapes.emplace_back(scale_inv_shape);
scale_inv_dptrs.push_back(scale_inv.data_ptr);
}
// Allocate full buffer
auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8));
for (size_t i = 0; i < tensors.size(); ++i) {
auto& tensor = tensors[i];
void* scale_inv_dptr = scale_inv_dptrs[i];
void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
auto input_shape = nvte_shape_to_vector(tensor.shape());
// Reconstruct the input so that only the required direction is swizzled.
// Any 8-bit dtype works here; the element type is irrelevant for the swizzle.
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
if (rowwise) {
input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shapes[i]);
output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor.
tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shapes[i]);
} else {
input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shapes[i]);
output_cu.set_columnwise_data(tensor.columnwise_dptr(),
transformer_engine::DType::kFloat8E4M3, input_shape);
output_cu.set_columnwise_scale_inv(
swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor.
tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
}
input_tensors.emplace_back(input_cu.data());
output_tensors.emplace_back(output_cu.data());
wrappers.emplace_back(std::move(input_cu));
wrappers.emplace_back(std::move(output_cu));
}
// Launch kernel
nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(),
input_tensors.size(), at::cuda::getCurrentCUDAStream());
return buffer;
}
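As an aside, here is a minimal Python sketch (helper names are illustrative, not the library's) of the packing scheme used in the hunk above: each tensor's swizzled scale-inverse region is placed at a 16-byte-aligned offset inside one shared uint8 buffer.

def plan_scale_inv_offsets(scale_inv_shapes, elem_size=1, alignment=16):
    """Return (total_buffer_size, per-tensor offsets) for aligned packing."""
    def roundup(value, multiple):
        return ((value + multiple - 1) // multiple) * multiple

    offsets = []
    buffer_size = 0
    for shape in scale_inv_shapes:
        buffer_size = roundup(buffer_size, alignment)  # align each region to 16 B
        offsets.append(buffer_size)
        numel = 1
        for dim in shape:
            numel *= dim
        buffer_size += numel * elem_size  # E8M0 scales are 1 byte each
    return buffer_size, offsets

# Example with three scale-inv tensors of 512, 512 and 256 bytes:
size, offsets = plan_scale_inv_offsets([(128, 4), (128, 4), (64, 4)])
assert offsets == [0, 512, 1024] and size == 1280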
...@@ -13,11 +13,18 @@ ...@@ -13,11 +13,18 @@
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
/* Swizzle the scaling factor of the input tensor. /*! \brief Swizzle the scaling factor of the input tensor.
* *
* The returned swizzled scaling factor tensor should be kept alive during the GEMM. * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
*/ */
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input, std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
bool trans); bool rowwise);
/*! \brief Swizzle the scaling factor of the input tensors.
*
* The returned swizzled scaling factor tensors should be kept alive during the GEMMs.
*/
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper> &inputs, bool rowwise);
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
...@@ -981,6 +981,15 @@ def _all_gather_fp8( ...@@ -981,6 +981,15 @@ def _all_gather_fp8(
return out, handle return out, handle
def _get_quantizer_format(quantizer: Quantizer) -> Optional[bool]:
"""Get quantizer format."""
if isinstance(quantizer, DebugQuantizer):
quantizer = quantizer.parent_quantizer
if isinstance(quantizer, Float8BlockQuantizer):
return quantizer.all_gather_usage
return None
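For illustration, a hedged sketch of the save/force/restore idiom that _get_quantizer_format and _set_quantizer_format enable around the high-precision all-gather fallback; the Toy* names below are stand-ins, not Transformer Engine APIs.

from dataclasses import dataclass

@dataclass
class ToyBlockQuantizer:
    all_gather_usage: bool = True  # compact-format flag, mirroring Float8BlockQuantizer

def toy_get_format(quantizer):
    # Mirrors _get_quantizer_format: only block quantizers expose the flag.
    return getattr(quantizer, "all_gather_usage", None)

def toy_set_format(quantizer, compact):
    if compact is not None and hasattr(quantizer, "all_gather_usage"):
        quantizer.all_gather_usage = compact

q = ToyBlockQuantizer(all_gather_usage=True)
saved = toy_get_format(q)   # remember the caller's format
toy_set_format(q, False)    # force GEMM_READY output while requantizing gathered data
# ... all-gather in high precision and requantize with `q` would happen here ...
toy_set_format(q, saved)    # restore so later calls see the original setting
assert q.all_gather_usage is True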
def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None: def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
"""Make quantizer compact""" """Make quantizer compact"""
_quantizer = quantizer _quantizer = quantizer
...@@ -1129,6 +1138,10 @@ def _all_gather_fp8_blockwise( ...@@ -1129,6 +1138,10 @@ def _all_gather_fp8_blockwise(
"Dequantizing and requantizing to Float8BlockwiseQTensor." "Dequantizing and requantizing to Float8BlockwiseQTensor."
) )
inp = quantizer(inp.dequantize()) inp = quantizer(inp.dequantize())
# Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
quantizer.all_gather_usage = orig_all_gather_usage quantizer.all_gather_usage = orig_all_gather_usage
# Begin to do network communication, need to make sure compact format # Begin to do network communication, need to make sure compact format
...@@ -1138,9 +1151,6 @@ def _all_gather_fp8_blockwise( ...@@ -1138,9 +1151,6 @@ def _all_gather_fp8_blockwise(
f"but found data_format={inp._data_format}" f"but found data_format={inp._data_format}"
) )
# Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
# Coalesce NCCL collectives # Coalesce NCCL collectives
with torch.distributed._coalescing_manager( with torch.distributed._coalescing_manager(
group=process_group, group=process_group,
...@@ -1216,14 +1226,12 @@ def _all_gather_mxfp8( ...@@ -1216,14 +1226,12 @@ def _all_gather_mxfp8(
if inp._rowwise_data is not None: if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.size() in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device device = inp._rowwise_data.device
dtype = inp._rowwise_data.dtype
elif inp._columnwise_data is not None: elif inp._columnwise_data is not None:
in_shape = inp._columnwise_data.size() in_shape = inp._columnwise_data.size()
device = inp._columnwise_data.device device = inp._columnwise_data.device
dtype = inp._columnwise_data.dtype
else: else:
raise ValueError("Got MXFP8 input tensor without any data") raise ValueError("Got MXFP8 input tensor without any data")
dtype = torch.bfloat16 dtype = torch.bfloat16 # Guess high-precision dtype.
else: else:
raise ValueError( raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, " "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
...@@ -1343,6 +1351,44 @@ def gather_along_first_dim( ...@@ -1343,6 +1351,44 @@ def gather_along_first_dim(
inp = quantizer(inp) inp = quantizer(inp)
return inp, None return inp, None
# Debug case - call gather_along_first_dim on each tensor
if isinstance(inp, DebugQuantizedTensor):
out_obj = DebugQuantizedTensor(
rowwise_gemm_tensor=inp.rowwise_gemm_tensor,
columnwise_gemm_tensor=inp.columnwise_gemm_tensor,
quantizer=inp.quantizer,
layer_name=inp._layer_name,
tensor_name=inp._tensor_name,
)
rowwise = inp.get_tensor(False)
columnwise = inp.get_tensor(True)
# shapes
final_quantizer = (
None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
)
rowwise_total = None
if rowwise is not None:
rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[
0
]
out_obj.rowwise_gemm_tensor = rowwise_total
if rowwise is not columnwise:
final_quantizer_columnwise = (
None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
)
columnwise_total = None
if columnwise is not None:
columnwise_total, _ = gather_along_first_dim(
columnwise, process_group, False, final_quantizer_columnwise
)
out_obj.columnwise_gemm_tensor = columnwise_total
else:
# Sometimes the same object is used both for rowwise and columnwise gemms,
# and we want to avoid double all-gathers.
out_obj.columnwise_gemm_tensor = out_obj.rowwise_gemm_tensor
return out_obj, None
# Output tensor dims # Output tensor dims
out_shape = list(inp.size()) out_shape = list(inp.size())
out_shape[0] *= world_size out_shape[0] *= world_size
...@@ -1380,34 +1426,6 @@ def gather_along_first_dim( ...@@ -1380,34 +1426,6 @@ def gather_along_first_dim(
out_shape=out_shape, out_shape=out_shape,
) )
# Debug case - call gather_along_first_dim on each tensor
if isinstance(inp, DebugQuantizedTensor):
out_obj = inp
rowwise = inp.get_tensor(False)
columnwise = inp.get_tensor(True)
final_quantizer = (
None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
)
# Temporary fix for TP communication of Float8BlockwiseQTensorBase
if isinstance(rowwise, Float8BlockwiseQTensorBase):
rowwise = inp._original_tensor
rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
out_obj.rowwise_gemm_tensor = rowwise_total
if rowwise is not columnwise:
final_quantizer_columnwise = (
None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
)
# Temporary fix for TP communication of Float8BlockwiseQTensorBase
if isinstance(columnwise, Float8BlockwiseQTensorBase):
columnwise = inp._original_tensor
columnwise_total, _ = gather_along_first_dim(
columnwise, process_group, False, final_quantizer_columnwise
)
out_obj.columnwise_gemm_tensor = columnwise_total
else:
out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor
return out_obj, None
# High-precision communication for quantized tensors # High-precision communication for quantized tensors
if quantizer is not None: if quantizer is not None:
warnings.warn( warnings.warn(
...@@ -1418,6 +1436,7 @@ def gather_along_first_dim( ...@@ -1418,6 +1436,7 @@ def gather_along_first_dim(
inp = inp.dequantize() inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer # Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format # means that it should directly output GEMM_READY format
compact = _get_quantizer_format(quantizer)
_set_quantizer_format(quantizer, compact=False) _set_quantizer_format(quantizer, compact=False)
out = torch.empty( out = torch.empty(
out_shape, out_shape,
...@@ -1427,6 +1446,7 @@ def gather_along_first_dim( ...@@ -1427,6 +1446,7 @@ def gather_along_first_dim(
) )
torch.distributed.all_gather_into_tensor(out, inp, group=process_group) torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
out = quantizer(out) out = quantizer(out)
_set_quantizer_format(quantizer, compact=compact)
return out, None return out, None
# Dequantize quantized tensor if not supported # Dequantize quantized tensor if not supported
......
...@@ -80,14 +80,26 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ...@@ -80,14 +80,26 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
def check_recipe_support(recipe: Recipe) -> None:
"""Check if the given recipe is supported."""
recipe_supported = True
unsupported_reason = ""
if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)):
recipe_supported, unsupported_reason = check_fp8_support()
elif isinstance(recipe, Float8BlockScaling):
recipe_supported, unsupported_reason = check_fp8_block_scaling_support()
elif isinstance(recipe, MXFP8BlockScaling):
recipe_supported, unsupported_reason = check_mxfp8_support()
assert recipe_supported, unsupported_reason
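A brief, hedged usage sketch: because the check is a hard assert, callers can fail fast with the reported reason before entering autocast.

try:
    check_recipe_support(MXFP8BlockScaling())
except AssertionError as reason:
    print(f"MXFP8 recipe not usable on this device: {reason}")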
def get_default_fp8_recipe() -> Recipe: def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args.""" """FP8 recipe with default args."""
if check_mxfp8_support()[0]: if check_mxfp8_support()[0]:
# This is a temporary restriction until MXFP8 is supported for all
# gemm layouts.
if get_device_compute_capability() >= (12, 0):
return Float8BlockScaling()
return MXFP8BlockScaling() return MXFP8BlockScaling()
if get_device_compute_capability() >= (12, 0):
# This is a temporary restriction until MXFP8 is supported for all gemm layouts.
return Float8CurrentScaling()
return DelayedScaling() return DelayedScaling()
...@@ -664,6 +676,8 @@ def fp8_autocast( ...@@ -664,6 +676,8 @@ def fp8_autocast(
distributed group over which amaxes for the fp8 tensors distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step. are reduced at the end of each training step.
""" """
if enabled:
check_recipe_support(fp8_recipe)
fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
FP8GlobalStateManager.fp8_autocast_enter( FP8GlobalStateManager.fp8_autocast_enter(
enabled=enabled, enabled=enabled,
......
...@@ -21,6 +21,8 @@ from .fp8 import ( ...@@ -21,6 +21,8 @@ from .fp8 import (
from .distributed import get_all_rng_states, graph_safe_rng_available from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule from .module.base import TransformerEngineBaseModule
from .ops.op import BasicOperation from .ops.op import BasicOperation
from .ops import Sequential
from .ops.fuser import OperationFuser
from .utils import make_weak_ref from .utils import make_weak_ref
__all__ = ["make_graphed_callables"] __all__ = ["make_graphed_callables"]
...@@ -44,7 +46,7 @@ def set_capture_end() -> None: ...@@ -44,7 +46,7 @@ def set_capture_end() -> None:
_IS_GRAPH_CAPTURING = False _IS_GRAPH_CAPTURING = False
def is_graph_capturing() -> None: def is_graph_capturing() -> bool:
"""Return whether within `make_graphed_callables`.""" """Return whether within `make_graphed_callables`."""
return _IS_GRAPH_CAPTURING return _IS_GRAPH_CAPTURING
...@@ -177,24 +179,17 @@ def _make_graphed_callables( ...@@ -177,24 +179,17 @@ def _make_graphed_callables(
assert isinstance( assert isinstance(
sample_args, list sample_args, list
), "sample_args must be a list for _reuse_graph_input_output_buffers." ), "sample_args must be a list for _reuse_graph_input_output_buffers."
len_args = len(sample_args[0])
for i, arg in enumerate(sample_args):
assert len_args == len(
arg
), "Arguments must have same length and shape for `_reuse_graph_input_output_buffers`."
len_kwargs = len(sample_kwargs[0])
assert isinstance(
sample_kwargs, list
), "sample_kwargs must be a list for _reuse_graph_input_output_buffers."
for i, kwarg in enumerate(sample_kwargs):
assert len_kwargs == len(kwarg), (
"Keyword arguments must have same length and shape for"
" `_reuse_graph_input_output_buffers`."
)
# Reorganize args and kwargs for input tensor reuse. # Reorganize args and kwargs for input tensor reuse.
# fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples.
# Each tuple contains the sample key signature and its fwd_idx. When we finish a backward
# chunk, we pop the corresponding fwd_idx and push to the consumed_sample_q.
# consumed_sample_q is keyed by the sample key signature. The value is a queue of the
# fwd_idx whose backward has been called so that we can reuse the same static buffers.
# In this way, we can reuse the same static input buffers for the non-overlapping samples
# with the same input signature.
fwd_sample_qs = {} fwd_sample_qs = {}
consumed_sample_q = [] consumed_sample_q = {}
fwd_idx = [0] * num_model_chunks fwd_idx = [0] * num_model_chunks
for c_id in _order: for c_id in _order:
m_chunk = abs(c_id) - 1 m_chunk = abs(c_id) - 1
...@@ -206,10 +201,21 @@ def _make_graphed_callables( ...@@ -206,10 +201,21 @@ def _make_graphed_callables(
fwd_sample_idx = [ fwd_sample_idx = [
sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk]) sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk])
] ]
fwd_sample_qs[m_chunk] = fwd_sample_qs.get(m_chunk, []) + fwd_sample_idx if m_chunk not in fwd_sample_qs:
fwd_sample_qs[m_chunk] = []
for per_callable_fwd_idx in fwd_sample_idx: for per_callable_fwd_idx in fwd_sample_idx:
if consumed_sample_q: sample_args_keys = tuple(
reuse_fwd_idx = consumed_sample_q.pop(0) (t.shape, t.dtype, t.layout) for t in sample_args[per_callable_fwd_idx]
)
sample_kwargs_keys = tuple(
(k, v.shape, v.dtype, v.layout)
for k, v in sorted(sample_kwargs[per_callable_fwd_idx].items())
)
sample_keys = sample_args_keys + sample_kwargs_keys
fwd_sample_qs[m_chunk].append((sample_keys, per_callable_fwd_idx))
if consumed_sample_q.get(sample_keys, []):
reuse_fwd_idx = consumed_sample_q[sample_keys].pop(0)
sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx] sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx]
sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx] sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx]
fwd_idx[m_chunk] += 1 fwd_idx[m_chunk] += 1
...@@ -217,7 +223,12 @@ def _make_graphed_callables( ...@@ -217,7 +223,12 @@ def _make_graphed_callables(
num_consumed_samples = min( num_consumed_samples = min(
len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk] len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk]
) )
consumed_sample_q += fwd_sample_qs[m_chunk][:num_consumed_samples] for sample_keys, per_callable_fwd_idx in fwd_sample_qs[m_chunk][
:num_consumed_samples
]:
if sample_keys not in consumed_sample_q:
consumed_sample_q[sample_keys] = []
consumed_sample_q[sample_keys].append(per_callable_fwd_idx)
fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:] fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:]
if fp8_weight_caching: if fp8_weight_caching:
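Before the next hunk, a hedged, self-contained sketch of the signature-keyed buffer reuse implemented above (toy helpers, not the graph-capture code itself): a sample's key is the per-tensor (shape, dtype, layout) tuple, and static buffers are only shared between samples whose backward has already completed and whose keys match.

import torch

def sample_signature(args, kwargs):
    args_keys = tuple((t.shape, t.dtype, t.layout) for t in args)
    kwargs_keys = tuple((k, v.shape, v.dtype, v.layout) for k, v in sorted(kwargs.items()))
    return args_keys + kwargs_keys

consumed_sample_q = {}  # signature -> queue of fwd indices whose backward already ran

def maybe_reuse_buffers(idx, all_args, all_kwargs):
    key = sample_signature(all_args[idx], all_kwargs[idx])
    if consumed_sample_q.get(key):
        reuse_idx = consumed_sample_q[key].pop(0)
        all_args[idx] = all_args[reuse_idx]      # alias the earlier static input buffers
        all_kwargs[idx] = all_kwargs[reuse_idx]
    return key

def mark_backward_done(key, idx):
    consumed_sample_q.setdefault(key, []).append(idx)  # idx's buffers are free to reuse

# Two microbatches with identical signatures can share one static buffer.
all_args = [(torch.zeros(2, 4),), (torch.zeros(2, 4),)]
all_kwargs = [{}, {}]
key = maybe_reuse_buffers(0, all_args, all_kwargs)   # nothing to reuse yet
mark_backward_done(key, 0)                           # backward for sample 0 finished
maybe_reuse_buffers(1, all_args, all_kwargs)
assert all_args[1][0] is all_args[0][0]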
...@@ -338,6 +349,16 @@ def _make_graphed_callables( ...@@ -338,6 +349,16 @@ def _make_graphed_callables(
def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
if isinstance(module, TransformerEngineBaseModule): if isinstance(module, TransformerEngineBaseModule):
visited_te_modules.add(module) visited_te_modules.add(module)
# If forward is called on a BasicOperation directly the hook will run
elif isinstance(module, BasicOperation):
visited_te_modules.add(module)
# If forward is called on a te.ops.Sequential it is not called on its constituent ops
elif isinstance(module, Sequential):
assert module._module_groups is not None, "Should have been initialized by warmup"
for module_group in module._module_groups:
if isinstance(module_group, OperationFuser):
for basic_op in module_group._basic_ops:
visited_te_modules.add(basic_op)
# Run warmup and do the above filtering. # Run warmup and do the above filtering.
with torch.cuda.stream(torch.cuda.Stream()): with torch.cuda.stream(torch.cuda.Stream()):
...@@ -410,8 +431,8 @@ def _make_graphed_callables( ...@@ -410,8 +431,8 @@ def _make_graphed_callables(
per_callable_static_grad_inputs = [None] * len(flatten_sample_args) per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
fwd_idx = [0] * num_model_chunks fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks
static_grad_outputs = None static_grad_outputs_dict = {}
previous_per_callable_bwd_idx = None previous_chunk_last_callable_bwd_idx = None
for c_id in _order: for c_id in _order:
if c_id > 0: if c_id > 0:
# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
...@@ -434,6 +455,7 @@ def _make_graphed_callables( ...@@ -434,6 +455,7 @@ def _make_graphed_callables(
else: else:
# Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
m_chunk = -c_id - 1 m_chunk = -c_id - 1
previous_per_callable_bwd_idx = None
for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))): for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
...@@ -442,9 +464,21 @@ def _make_graphed_callables( ...@@ -442,9 +464,21 @@ def _make_graphed_callables(
static_outputs = per_callable_static_outputs[per_callable_bwd_idx] static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx] bwd_graph = bwd_graphs[per_callable_bwd_idx]
# For now, assumes all static_outputs require grad # For now, assumes all static_outputs require grad
if not _reuse_graph_input_output_buffers or static_grad_outputs is None: if _reuse_graph_input_output_buffers:
# Note for _reuse_graph_input_output_buffers: grad output is only used # Note for _reuse_graph_input_output_buffers: grad output is only used
# within backward, so we can reuse the same static buffers every time. # within backward, so we can reuse the same static buffers every time.
static_grad_outputs_keys = tuple(
(o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
)
if static_grad_outputs_keys in static_grad_outputs_dict:
static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None
for o in static_outputs
)
static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
else:
static_grad_outputs = tuple( static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs torch.empty_like(o) if o.requires_grad else None for o in static_outputs
) )
...@@ -484,19 +518,29 @@ def _make_graphed_callables( ...@@ -484,19 +518,29 @@ def _make_graphed_callables(
per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref( per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref(
static_outputs static_outputs
) )
# Weak ref the static grad inputs of the previous backward pass.
# Note: After a backward pass, we assume Mcore will send the # Weak ref the static grad inputs of the previous backward pass within the
# grad input to another pipeline parallel rank and that the # same chunk.
# communication is finished before the end of the next backward
# pass.
if previous_per_callable_bwd_idx is not None: if previous_per_callable_bwd_idx is not None:
per_callable_static_grad_inputs[previous_per_callable_bwd_idx] = ( idx = previous_per_callable_bwd_idx
make_weak_ref( per_callable_static_grad_inputs[idx] = make_weak_ref(
per_callable_static_grad_inputs[previous_per_callable_bwd_idx] per_callable_static_grad_inputs[idx]
)
) )
previous_per_callable_bwd_idx = per_callable_bwd_idx previous_per_callable_bwd_idx = per_callable_bwd_idx
# Weak ref the static grad inputs of the previous chunk's last backward
# pass.
# Note: After a chunk's backward pass, we assume Mcore will send the grad
# input to another pipeline parallel rank and that the communication is
# finished before the end of the next chunk's backward pass.
if l_no == 0:
if previous_chunk_last_callable_bwd_idx is not None:
idx = previous_chunk_last_callable_bwd_idx
per_callable_static_grad_inputs[idx] = make_weak_ref(
per_callable_static_grad_inputs[idx]
)
previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx
bwd_idx[m_chunk] += 1 bwd_idx[m_chunk] += 1
else: else:
# Capture forward graphs # Capture forward graphs
...@@ -674,31 +718,41 @@ def _make_graphed_callables( ...@@ -674,31 +718,41 @@ def _make_graphed_callables(
# run the graph, otherwise run the original forward method # run the graph, otherwise run the original forward method
if func.training == graph_training_state: if func.training == graph_training_state:
# Set the FP8 group from global amax reduction. # Set the FP8 group from global amax reduction.
for m in func.modules(): if FP8GlobalStateManager.is_fp8_enabled():
if ( fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
isinstance(m, TransformerEngineBaseModule) for m in func.modules():
and FP8GlobalStateManager.is_fp8_enabled()
):
if m not in visited_te_modules: if m not in visited_te_modules:
# Only Set the FP8 meta for the modules included by forward # Only Set the FP8 meta for the modules included by forward
continue continue
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if isinstance(m, TransformerEngineBaseModule):
from transformer_engine.pytorch.attention.dot_product_attention import ( from transformer_engine.pytorch.attention.dot_product_attention import (
DotProductAttention, DotProductAttention,
) )
if ( if (
isinstance(m, DotProductAttention) isinstance(m, DotProductAttention)
and not fp8_recipe.fp8_mha and not fp8_recipe.fp8_mha
and not fp8_recipe.fp8_dpa and not fp8_recipe.fp8_dpa
): ):
# Don't need to update FP8 meta for non-FP8 DPA # Don't need to update FP8 meta for non-FP8 DPA
continue continue
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m.fp8_meta, m.fp8_meta,
) )
elif isinstance(m, BasicOperation):
for mode in ("forward", "backward"):
if m.num_quantizers(mode):
m._fp8_metas[mode][
"fp8_group"
] = FP8GlobalStateManager.get_fp8_group()
m._fp8_metas[mode][
"recipe"
] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m._fp8_metas[mode],
)
return graphed(*user_args, **user_kwargs) return graphed(*user_args, **user_kwargs)
return orig_fwd(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs)
...@@ -721,7 +775,7 @@ def _make_graphed_callables( ...@@ -721,7 +775,7 @@ def _make_graphed_callables(
def save_fp8_tensors( def save_fp8_tensors(
modules: Iterable[torch.nn.Module], modules: Iterable[torch.nn.Module],
fp8_recipe: Recipe, fp8_recipe: Optional[Recipe],
) -> Optional[List[Any]]: ) -> Optional[List[Any]]:
""" """
Returns the FP8 tensors for all modules Returns the FP8 tensors for all modules
...@@ -740,7 +794,7 @@ def save_fp8_tensors( ...@@ -740,7 +794,7 @@ def save_fp8_tensors(
m.adjust_amax_history_length(fp8_recipe.amax_history_len) m.adjust_amax_history_length(fp8_recipe.amax_history_len)
module_tensors = m.get_fp8_meta_tensors() module_tensors = m.get_fp8_meta_tensors()
elif isinstance(m, BasicOperation): elif isinstance(m, BasicOperation):
m.pre_first_forward(recipe=fp8_recipe) m.reset_recipe_state(recipe=fp8_recipe)
module_tensors = m._save_fp8_metas() module_tensors = m._save_fp8_metas()
fp8_tensors.append(module_tensors) fp8_tensors.append(module_tensors)
return fp8_tensors return fp8_tensors
...@@ -777,7 +831,7 @@ def make_graphed_callables( ...@@ -777,7 +831,7 @@ def make_graphed_callables(
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
fp8_enabled: bool = False, fp8_enabled: bool = False,
fp8_calibrating: bool = False, fp8_calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None, fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None, fp8_group: Optional[dist_group_type] = None,
fp8_weight_caching: bool = False, fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None, _order: Optional[List[int]] = None,
...@@ -828,7 +882,7 @@ def make_graphed_callables( ...@@ -828,7 +882,7 @@ def make_graphed_callables(
data of fp8 tensors even when executing without fp8 enabled. This is data of fp8 tensors even when executing without fp8 enabled. This is
useful for saving an inference ready fp8 checkpoint while training useful for saving an inference ready fp8 checkpoint while training
using a higher precision. using a higher precision.
fp8_recipe: recipe.DelayedScaling, default = `None` fp8_recipe: Recipe, default = `None`
recipe used for FP8 training. recipe used for FP8 training.
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the fp8 tensors distributed group over which amaxes for the fp8 tensors
...@@ -844,7 +898,10 @@ def make_graphed_callables( ...@@ -844,7 +898,10 @@ def make_graphed_callables(
""" """
set_capture_start() set_capture_start()
fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe if fp8_enabled and fp8_recipe is None:
fp8_recipe = get_default_fp8_recipe()
elif not fp8_enabled:
fp8_recipe = None
# Handle single module. # Handle single module.
just_one_callable = False just_one_callable = False
......
...@@ -136,30 +136,43 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: ...@@ -136,30 +136,43 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
@jit_fuser @jit_fuser
def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor: def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor:
"""L2 normalization fused - inference version""" """L2 normalization fused - inference version"""
x_squared = x.pow(2) x_fp32 = x.float()
x_squared = x_fp32.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
return x * rsqrt_norm y_fp32 = x_fp32 * rsqrt_norm
return y_fp32.to(x.dtype)
@jit_fuser @jit_fuser
def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]: def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
"""L2 normalization fused - training version that returns intermediate values""" """L2 normalization fused - training version that returns intermediate values"""
x_squared = x.pow(2) x_fp32 = x.float()
x_squared = x_fp32.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) l2_norm_squared_eps = l2_norm_squared + eps
y = x * rsqrt_norm rsqrt_norm = torch.rsqrt(l2_norm_squared_eps)
y_fp32 = x_fp32 * rsqrt_norm
y = y_fp32.to(x.dtype)
return y, rsqrt_norm return y, rsqrt_norm
@jit_fuser @jit_fuser
def l2normalization_backward_fused_( def l2normalization_backward_fused_(
grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float grad_output: torch.Tensor,
x: torch.Tensor,
rsqrt_norm: torch.Tensor,
eps: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""L2 normalization backward fused""" """L2 normalization backward fused"""
x_dy_sum = (x * grad_output).sum(dim=-1, keepdim=True) x_fp32 = x.float()
x_norm_squared = x.pow(2).sum(dim=-1, keepdim=True) + eps grad_output_fp32 = grad_output.float()
return rsqrt_norm * (grad_output - x * x_dy_sum / x_norm_squared) x_dy_sum = (x_fp32 * grad_output_fp32).sum(dim=-1, keepdim=True)
x_squared = x_fp32.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
x_norm_squared = l2_norm_squared + eps
dx_fp32 = rsqrt_norm * (grad_output_fp32 - x_fp32 * x_dy_sum / x_norm_squared)
return dx_fp32.to(x.dtype)
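As a hedged sanity check (not part of the library), the backward formula above can be compared against autograd for the same L2 normalization; the fused kernels additionally upcast to float32 before the reductions.

import torch

def l2norm_ref(x, eps):
    return x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)

def l2norm_bwd_manual(dy, x, eps):
    norm_sq = x.pow(2).sum(dim=-1, keepdim=True) + eps
    rsqrt_norm = torch.rsqrt(norm_sq)
    x_dy_sum = (x * dy).sum(dim=-1, keepdim=True)
    return rsqrt_norm * (dy - x * x_dy_sum / norm_sq)

eps = 1e-6
x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
dy = torch.randn(4, 8, dtype=torch.float64)
l2norm_ref(x, eps).backward(dy)
assert torch.allclose(x.grad, l2norm_bwd_manual(dy, x.detach(), eps))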
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
...@@ -193,7 +206,10 @@ def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor ...@@ -193,7 +206,10 @@ def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor
def l2normalization_backward_fused( def l2normalization_backward_fused(
grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float grad_output: torch.Tensor,
x: torch.Tensor,
rsqrt_norm: torch.Tensor,
eps: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Disable native AMP for l2normalization_backward_fused_""" """Disable native AMP for l2normalization_backward_fused_"""
with gpu_autocast_ctx(enabled=False): with gpu_autocast_ctx(enabled=False):
......
...@@ -42,11 +42,12 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer ...@@ -42,11 +42,12 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import torch_get_autocast_gpu_dtype from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import DelayedScaling, Recipe from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
__all__ = ["initialize_ub", "destroy_ub"] __all__ = ["initialize_ub", "destroy_ub"]
...@@ -173,7 +174,7 @@ def initialize_ub( ...@@ -173,7 +174,7 @@ def initialize_ub(
``` ```
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
"fc2_fprop", "fc2_dgrad"]`. "fc2_fprop", "fc2_wgrad"]`.
bootstrap_backend : str = None bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and `torch.distributed` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are barrier collectives during Userbuffers initialization. Not all backends are
...@@ -272,9 +273,11 @@ def initialize_ub( ...@@ -272,9 +273,11 @@ def initialize_ub(
"qkv_fprop", "qkv_fprop",
"qkv_dgrad", "qkv_dgrad",
"proj_dgrad", "proj_dgrad",
"proj_wgrad",
"fc1_fprop", "fc1_fprop",
"fc1_dgrad", "fc1_dgrad",
"fc2_dgrad", "fc2_dgrad",
"fc2_wgrad",
] ]
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
...@@ -284,29 +287,34 @@ def initialize_ub( ...@@ -284,29 +287,34 @@ def initialize_ub(
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop", "fc2_fprop"], "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop", "fc2_fprop"],
"pipeline": [], "pipeline": [],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
"external": ["proj_wgrad", "fc2_wgrad"],
} }
elif bool(int(os.getenv("NVTE_PROJ_NO_PIPELINE_OVERLAP", "0"))): elif bool(int(os.getenv("NVTE_PROJ_NO_PIPELINE_OVERLAP", "0"))):
methods = { methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop"], "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop"],
"pipeline": ["fc2_fprop"], "pipeline": ["fc2_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
"external": ["proj_wgrad", "fc2_wgrad"],
} }
elif bool(int(os.getenv("NVTE_FC2_NO_PIPELINE_OVERLAP", "0"))): elif bool(int(os.getenv("NVTE_FC2_NO_PIPELINE_OVERLAP", "0"))):
methods = { methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "fc2_fprop"], "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "fc2_fprop"],
"pipeline": ["proj_fprop"], "pipeline": ["proj_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
"external": ["proj_wgrad", "fc2_wgrad"],
} }
else: else:
methods = { methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
"pipeline": ["proj_fprop", "fc2_fprop"], "pipeline": ["proj_fprop", "fc2_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
"external": ["proj_wgrad", "fc2_wgrad"],
} }
# AG-RS overlap pairs of layers forming a tensor-parallel block # AG-RS overlap pairs of layers forming a tensor-parallel block
ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"}
global layers_atomic_ring_exchange global layers_atomic_ring_exchange
layers_atomic_ring_exchange = [] layers_atomic_ring_exchange = []
...@@ -360,7 +368,7 @@ def initialize_ub( ...@@ -360,7 +368,7 @@ def initialize_ub(
"Atomic GEMM uses a beta API from cublas and is not tested for all use cases." "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
) )
assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
if method == "bulk": if method in ("bulk", "external"):
warnings.warn( warnings.warn(
f"At {name}, atoimic GEMM not is supported for a bulk overlap." f"At {name}, atoimic GEMM not is supported for a bulk overlap."
"Defaulting to `atomic_gemm=False`." "Defaulting to `atomic_gemm=False`."
...@@ -389,6 +397,16 @@ def initialize_ub( ...@@ -389,6 +397,16 @@ def initialize_ub(
if atomic_gemm and method == "ring_exchange": if atomic_gemm and method == "ring_exchange":
assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message
if name in external_gemm_to_overlap:
assert method == "external", (
f"At {name}, `external` overlap method is specified, but the selected method is"
f" {method}"
)
assert external_gemm_to_overlap[name] in methods["ring_exchange"], (
f"At {name}, `external` overlap method is specified, but the external gemm"
f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
)
buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
if method == "ring_exchange": if method == "ring_exchange":
ub_obj = tex.CommOverlapP2P( ub_obj = tex.CommOverlapP2P(
...@@ -437,7 +455,9 @@ def initialize_ub( ...@@ -437,7 +455,9 @@ def initialize_ub(
new_method = ub_cfgs[name]["method"] new_method = ub_cfgs[name]["method"]
methods[new_method].append(name) methods[new_method].append(name)
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: for name in (
methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
):
if name in remove_ag_gemm_dgrad: if name in remove_ag_gemm_dgrad:
continue continue
ub_cfg = get_default_config(name) ub_cfg = get_default_config(name)
...@@ -610,6 +630,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -610,6 +630,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
super().__init__() super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA." assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.name = None self.name = None
self.next_iter_when_debug_should_be_run = 0
self.fp8_initialized = False self.fp8_initialized = False
self.fp8 = False self.fp8 = False
self.fp8_calibration = False self.fp8_calibration = False
...@@ -628,6 +649,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -628,6 +649,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fsdp_group = None self.fsdp_group = None
self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
self.activation_dtype: Optional[torch.dtype] = None self.activation_dtype: Optional[torch.dtype] = None
self.wgrad_accumulation_and_reduce_hooks = []
if not TEDebugState.debug_enabled: if not TEDebugState.debug_enabled:
TEDebugState.initialize() TEDebugState.initialize()
...@@ -1339,21 +1361,29 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1339,21 +1361,29 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Try getting workspace from cache # Try getting workspace from cache
out = None out = None
if cache_name is not None: if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None) out = self._fp8_workspaces.get(cache_name, None)
if quantizer is not None and isinstance(out, MXFP8TensorBase):
# Reset cache if workspace is invalid
if out is not None and quantizer is not None:
reset_cache = False
if isinstance(out, Float8TensorBase):
if (
not is_non_tn_fp8_gemm_supported()
and quantizer.columnwise_usage
and out._transpose is None
):
reset_cache = True
elif isinstance(out, MXFP8TensorBase):
if quantizer.rowwise_usage and out._rowwise_data is None: if quantizer.rowwise_usage and out._rowwise_data is None:
out = None reset_cache = True
del self._fp8_workspaces[cache_name]
elif quantizer.columnwise_usage and out._columnwise_data is None: elif quantizer.columnwise_usage and out._columnwise_data is None:
out = None reset_cache = True
del self._fp8_workspaces[cache_name] if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
reset_cache = True
is_debug = isinstance(quantizer, DebugQuantizer) if reset_cache:
is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
if is_debug != is_out_debug_tensor:
out = None out = None
del self._fp8_workspaces[cache_name]
# Gather cached Fp8 workspace if it's distributed # Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
...@@ -1421,6 +1451,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1421,6 +1451,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
) )
def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook):
"""
This method is used to manually control weight gradient accumulation and reduction.
It should be called before backward().
Set skip_wgrad_accumulation_and_reduce to True to skip weight gradient accumulation
and reduction in backward(); the registered hook is then called in backward_dw().
"""
self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook)
def backward_dw(self): def backward_dw(self):
""" """
Execute the delayed weight gradient computation. Execute the delayed weight gradient computation.
...@@ -1431,14 +1471,58 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1431,14 +1471,58 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
(wgrad, bgrad), _ = self.wgrad_store.pop() (wgrad, bgrad), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation: if not self.fuse_wgrad_accumulation:
unfused_weights = [getattr(self, name) for name in self.weight_names] weight_tensor = noop_cat(self._get_weight_tensors())
weight_tensor = noop_cat(unfused_weights)
if weight_tensor.grad is None: if weight_tensor.grad is None:
weight_tensor.grad = wgrad.to(weight_tensor.dtype) weight_tensor.grad = wgrad.to(weight_tensor.dtype)
if self.use_bias: if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
if bias_tensor.grad is None: if bias_tensor.grad is None:
bias_tensor.grad = bgrad.to(bias_tensor.dtype) bias_tensor.grad = bgrad.to(bias_tensor.dtype)
del wgrad
del bgrad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
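A hedged usage sketch of the hook mechanism (toy module, not a TE class): the hook is registered before backward(), and backward_dw() invokes it once the delayed weight gradient is available.

import torch

class ToyDelayedWgradModule:
    def __init__(self):
        self.wgrad_accumulation_and_reduce_hooks = []
        self.main_grad = torch.zeros(4)
        self._pending_wgrad = None

    def register_wgrad_accumulation_and_reduce_hooks(self, hook):
        self.wgrad_accumulation_and_reduce_hooks.append(hook)

    def backward(self, wgrad):
        # With delayed wgrad compute, accumulation/reduction is skipped here.
        self._pending_wgrad = wgrad

    def backward_dw(self):
        # Delayed weight-gradient step: hooks run after wgrad is materialized.
        for hook in self.wgrad_accumulation_and_reduce_hooks:
            hook()

module = ToyDelayedWgradModule()
module.register_wgrad_accumulation_and_reduce_hooks(
    lambda: module.main_grad.add_(module._pending_wgrad)
)
module.backward(torch.ones(4))
module.backward_dw()
assert torch.equal(module.main_grad, torch.ones(4))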
def is_debug_iter(self) -> bool:
"""
Check whether debug should be enabled for this layer in the current iteration.
"""
debug = TEDebugState.debug_enabled
if not debug:
return False
self._validate_name()
# If the layer runs for the first time in a new iteration, re-check whether
# debug should be enabled for it - in earlier iterations the debug features
# may have reported that nothing will be active for this layer for several
# upcoming iterations.
started_new_iteration = TEDebugState.get_iteration() != getattr(
self, "debug_last_iteration", None
)
if started_new_iteration:
if self.next_iter_when_debug_should_be_run is None:
debug = False
else:
debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
self.debug_last_iteration = TEDebugState.get_iteration()
return debug
def no_debug_features_active(self, quantizers):
"""
Return True if no debug feature is active for this layer.
"""
run_current = any_feature_enabled(quantizers)
# Features may report that they will not be enabled for this layer
# for several upcoming iterations.
self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)
if not run_current:
return True
if self.primary_weights_in_fp8:
raise RuntimeError("FP8 weights are not supported in debug mode.")
return False
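A hedged sketch of the iteration gating above (toy class, not the module itself): the layer records the next iteration at which any debug feature may run and skips debug work until then.

class ToyDebugGate:
    def __init__(self):
        self.next_iter_when_debug_should_be_run = 0

    def should_run_debug(self, iteration):
        if self.next_iter_when_debug_should_be_run is None:
            return False  # features reported they will never run for this layer
        return iteration >= self.next_iter_when_debug_should_be_run

gate = ToyDebugGate()
gate.next_iter_when_debug_should_be_run = 5  # features idle until iteration 5
assert not gate.should_run_debug(3)
assert gate.should_run_debug(7)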
def _validate_name(self): def _validate_name(self):
""" """
...@@ -1446,6 +1530,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1446,6 +1530,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
If no name is assigned, it creates a default name with layer count as the variable. If no name is assigned, it creates a default name with layer count as the variable.
""" """
if self.name is not None:
return
assert TEDebugState.debug_enabled assert TEDebugState.debug_enabled
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
...@@ -1494,29 +1580,3 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1494,29 +1580,3 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
" Please check the recipes assigned during fp8_model_init() and" " Please check the recipes assigned during fp8_model_init() and"
" fp8_autocast() calls." " fp8_autocast() calls."
) )
def _turn_off_unsupported_features_in_debug(self):
if (
getattr(self, "ub_bulk_wgrad", False)
or getattr(self, "ub_bulk_dgrad", False)
or getattr(self, "ub_overlap_ag", False)
or getattr(self, "ub_overlap_rs_dgrad", False)
or getattr(self, "ub_overlap_rs", False)
):
import nvdlfw_inspect.api as debug_api
debug_api.log_message(
"UserBuffers are not supported in debug module. "
"Using UB optimization will not affect the debug module. ",
level=logging.WARNING,
)
if hasattr(self, "ub_bulk_wgrad"):
self.ub_bulk_wgrad = None
if hasattr(self, "ub_bulk_dgrad"):
self.ub_bulk_dgrad = None
if hasattr(self, "ub_overlap_ag"):
self.ub_overlap_ag = None
if hasattr(self, "ub_overlap_rs_dgrad"):
self.ub_overlap_rs_dgrad = None
if hasattr(self, "ub_overlap_rs"):
self.ub_overlap_rs = None
...@@ -667,6 +667,12 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -667,6 +667,12 @@ class GroupedLinear(TransformerEngineBaseModule):
self.reset_parameters(defer_init=device == "meta") self.reset_parameters(defer_init=device == "meta")
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
for i in range(self.num_gemms):
if name in (f"weight{i}", f"bias{i}"):
param.skip_backward_post_hook = True
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe) super().set_meta_tensor(fwd, recipe)
...@@ -747,7 +753,9 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -747,7 +753,9 @@ class GroupedLinear(TransformerEngineBaseModule):
if skip_fp8_weight_update is not None: if skip_fp8_weight_update is not None:
is_first_microbatch = False is_first_microbatch = False
with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = self._get_weight_tensors() weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
...@@ -822,19 +830,21 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -822,19 +830,21 @@ class GroupedLinear(TransformerEngineBaseModule):
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop() (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
wgrad_list = tensor_list[2] wgrad_list = tensor_list[2]
weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fuse_wgrad_accumulation: if not self.fuse_wgrad_accumulation:
for i in range(self.num_gemms): for i in range(self.num_gemms):
weight_param = getattr(self, f"weight{i}") if weight_params[i].grad is None:
if weight_param.grad is None: weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype)
weight_param.grad = wgrad_list[i].to(weight_param.dtype)
if self.use_bias: if self.use_bias:
for i in range(self.num_gemms): for i in range(self.num_gemms):
bias_param = getattr(self, f"bias{i}") if bias_params[i].grad is None:
if bias_param.grad is None: bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype)
bias_param.grad = grad_biases_[i].to(bias_param.dtype)
del grad_biases_ del grad_biases_
del wgrad_list del wgrad_list
del tensor_list del tensor_list
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear.""" """Customize quantizers based on current scaling recipe + linear."""
......
...@@ -62,9 +62,7 @@ from ..tensor.quantized_tensor import ( ...@@ -62,9 +62,7 @@ from ..tensor.quantized_tensor import (
restore_from_saved, restore_from_saved,
) )
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
...@@ -169,6 +167,13 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -169,6 +167,13 @@ class _LayerNormLinear(torch.autograd.Function):
with_input_all_gather = parallel_mode == "column" and sequence_parallel with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Configure Userbuffers communication (comm+GEMM overlap) # Configure Userbuffers communication (comm+GEMM overlap)
if debug: # turn off userbuffers in debug mode
ub_overlap_ag_fprop = False
ub_overlap_rs_fprop = False
ub_overlap_ag_dgrad = False
ub_overlap_rs_dgrad = False
ub_bulk_wgrad = False
ub_bulk_dgrad = False
ub_obj = None ub_obj = None
ub_type = None ub_type = None
ub_overlap_ag_fprop = ( ub_overlap_ag_fprop = (
...@@ -186,9 +191,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -186,9 +191,7 @@ class _LayerNormLinear(torch.autograd.Function):
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if with_input_all_gather and isinstance( if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather():
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
# All-gather is not supported with FP8 column-wise data # All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False) input_quantizer.set_usage(columnwise=False)
...@@ -645,7 +648,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -645,7 +648,7 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer = None quantizer = None
if ctx.input_quantizer is not None: if ctx.input_quantizer is not None:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if quantizer.supports_only_rowwise_all_gather():
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
quantizer.set_usage(rowwise=True, columnwise=False) quantizer.set_usage(rowwise=True, columnwise=False)
else: else:
...@@ -762,27 +765,36 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -762,27 +765,36 @@ class _LayerNormLinear(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
# make sure required data is available # make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output # UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't # all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we # convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered # can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly # for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM. # overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream # We use the send stream to copy into the userbuffers.
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() # This is the same stream that we will use to access the data in the AG,
with torch.cuda.stream(dgrad_comm_stream): # so we don't need to add any syncs yet.
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather with torch.cuda.stream(dgrad_send_stream):
# This ensures that we don't start until all communication for the dgrad GEMM is complete grad_output, _ = fill_userbuffers_buffer_for_all_gather(
grad_output, mxfp8_grad_output_work = gather_along_first_dim( ub_obj_overlap_wgrad,
grad_outputs[0], grad_outputs[0],
ctx.grad_output_quantizer,
ctx.tp_group, ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
) )
# Synchronize with the main stream
mxfp8_grad_output_work.wait() # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the dgrad GEMM
tex.bulk_overlap_ag_with_external_gemm(
ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
)
# Prepare input tensor # Prepare input tensor
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
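The workaround above drops the asynchronous gather_along_first_dim call in favor of filling the Userbuffers buffer on the dgrad GEMM's send stream and launching a bulk all-gather that overlaps with the dgrad GEMM. A rough sketch of the same overlap idea written with plain PyTorch/NCCL primitives, assuming a [sequence, hidden] grad_output and a [out_features, in_features] weight (function and variable names below are illustrative, not the Userbuffers path):

```python
# Minimal sketch of overlapping an all-gather with a GEMM via a side stream.
import torch
import torch.distributed as dist


def overlapped_ag_and_gemm(grad_output, weight, tp_group):
    """All-gather grad_output on a side stream while the dgrad GEMM runs."""
    comm_stream = torch.cuda.Stream()
    world = dist.get_world_size(tp_group)
    gathered = torch.empty(
        (grad_output.shape[0] * world, *grad_output.shape[1:]),
        dtype=grad_output.dtype,
        device=grad_output.device,
    )
    # The side stream must wait for the producers of grad_output.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        handle = dist.all_gather_into_tensor(
            gathered, grad_output, group=tp_group, async_op=True
        )
    # dgrad GEMM runs on the current stream, overlapping with the gather.
    dgrad = grad_output @ weight
    # Before the wgrad GEMM consumes `gathered`, wait for the communication.
    handle.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)
    return dgrad, gathered
```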
...@@ -1177,8 +1189,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1177,8 +1189,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.name = name self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
...@@ -1396,6 +1406,11 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1396,6 +1406,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
if name in self.weight_names or name in self.bias_names:
param.skip_backward_post_hook = True
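The diff only sets skip_backward_post_hook; a plausible consumer, sketched below under that assumption (not the actual TE hook), is a post-accumulate-grad hook that skips its work while wgrad computation is deferred to the WeightGradStore and is re-triggered later via wgrad_accumulation_and_reduce_hooks:

```python
# Hedged sketch: gate a gradient post-hook on the flag set in __init__ above.
import torch


def register_wgrad_reduce_hook(param: torch.nn.Parameter, reduce_fn):
    def _hook(p):
        # With delayed wgrad compute, p.grad is not ready when backward
        # finishes, so the post-accumulate work must be skipped here and
        # run later by the stored wgrad hooks.
        if getattr(p, "skip_backward_post_hook", False):
            return
        reduce_fn(p.grad)

    param.register_post_accumulate_grad_hook(_hook)
```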
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe) super().set_meta_tensor(fwd, recipe)
...@@ -1480,9 +1495,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1480,9 +1495,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
""" """
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output) return self.onnx_forward(inp, fp8_output)
debug = TEDebugState.debug_enabled
if debug: debug = self.is_debug_iter()
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing(): if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
...@@ -1498,7 +1512,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1498,7 +1512,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
fp8_grad = True fp8_grad = True
with self.prepare_forward( with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp: ) as inp:
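The forward now also enters a torch.cuda.device context keyed off the first parameter's device, so temporaries and workspaces created inside prepare_forward land on the module's GPU even if the caller's current device differs. A minimal illustration of the guard (the helper name is illustrative):

```python
# Sketch of the device-guard pattern used above.
import torch


def forward_on_module_device(module: torch.nn.Module, inp: torch.Tensor):
    # Pin the current CUDA device to the module's parameters for the call,
    # so tensors created without an explicit device land on the right GPU.
    first_param = next(module.parameters())
    with torch.cuda.device(first_param.device):
        return module(inp)
```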
...@@ -1511,13 +1527,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1511,13 +1527,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
else self._get_debug_quantizers(fp8_output, fp8_grad) else self._get_debug_quantizers(fp8_output, fp8_grad)
) )
if debug: if debug:
if not any_feature_enabled(quantizers): if self.no_debug_features_active(quantizers):
# If no feature is used, then run faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
debug = False debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad)
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
( (
input_quantizer, input_quantizer,
......
...@@ -69,7 +69,6 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer ...@@ -69,7 +69,6 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase, QuantizedTensorBase,
Quantizer, Quantizer,
prepare_for_saving, prepare_for_saving,
...@@ -79,7 +78,6 @@ from ..cpp_extensions import ( ...@@ -79,7 +78,6 @@ from ..cpp_extensions import (
general_gemm, general_gemm,
) )
from ..export import is_in_onnx_export_mode, assert_warmed_up from ..export import is_in_onnx_export_mode, assert_warmed_up
from ...debug.pytorch.utils import any_feature_enabled
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["LayerNormMLP"] __all__ = ["LayerNormMLP"]
...@@ -224,6 +222,12 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -224,6 +222,12 @@ class _LayerNormMLP(torch.autograd.Function):
device = inp.device device = inp.device
# Configure Userbuffers communication (comm+GEMM overlap) # Configure Userbuffers communication (comm+GEMM overlap)
if debug: # turn off userbuffers in debug mode
ub_overlap_ag = False
ub_overlap_rs = False
ub_overlap_rs_dgrad = False
ub_bulk_wgrad = False
ub_bulk_dgrad = False
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled ub_overlap_rs = ub_overlap_rs and is_grad_enabled
...@@ -239,9 +243,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -239,9 +243,7 @@ class _LayerNormMLP(torch.autograd.Function):
if fc1_input_quantizer is None: if fc1_input_quantizer is None:
raise ValueError("Missing quantizer for FC1 input tensor") raise ValueError("Missing quantizer for FC1 input tensor")
fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input) fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input)
if sequence_parallel and isinstance( if sequence_parallel and fc1_input_quantizer.supports_only_rowwise_all_gather():
fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
# All-gather is not supported with FP8 column-wise data # All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False) fc1_input_quantizer.set_usage(columnwise=False)
...@@ -850,26 +852,37 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -850,26 +852,37 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
# make sure required data is available # make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output # UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't # all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we # convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered # can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly # for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM. # overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = (
ub_obj_fc2_dgrad.get_communication_stream()
)
ub_obj_fc2_wgrad = get_ub("fc2_wgrad")
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_fc2_dgrad.get_communication_stream() # We use the send stream to copy into the userbuffers.
with torch.cuda.stream(dgrad_comm_stream): # This is the same stream that we will use to access the data in the AG,
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather # so we don't need to add any syncs yet.
# This ensures that we don't start until all communication for the dgrad GEMM is complete with torch.cuda.stream(dgrad_send_stream):
grad_output, mxfp8_fc2_grad_output_work = gather_along_first_dim( grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_fc2_wgrad,
grad_outputs[0], grad_outputs[0],
ctx.fc2_grad_output_quantizer,
ctx.tp_group, ctx.tp_group,
async_op=True,
quantizer=ctx.fc2_grad_output_quantizer,
) )
# Synchronize with the main stream
mxfp8_fc2_grad_output_work.wait() # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm
tex.bulk_overlap_ag_with_external_gemm(
ub_obj_fc2_wgrad, dgrad_send_stream, dgrad_recv_stream
)
# Prepare input tensor # Prepare input tensor
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
...@@ -1541,9 +1554,6 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1541,9 +1554,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
) )
self.name = name self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if tp_group is None: if tp_group is None:
...@@ -1660,6 +1670,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1660,6 +1670,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
warmup_jit_bias_gelu_all_dtypes( warmup_jit_bias_gelu_all_dtypes(
self.size_per_partition, seq_length, micro_batch_size self.size_per_partition, seq_length, micro_batch_size
) )
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
if name in ["fc1_weight", "fc2_weight", "fc1_bias", "fc2_bias"]:
param.skip_backward_post_hook = True
# These many SMs are subtracted from the total SM count when calling forward # These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
...@@ -1742,9 +1756,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1742,9 +1756,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
""" """
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
return self.onnx_forward(inp) return self.onnx_forward(inp)
debug = TEDebugState.debug_enabled
if debug: debug = self.is_debug_iter()
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing(): if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
...@@ -1758,7 +1771,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1758,7 +1771,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
if get_ub("fc2_fprop").is_fp8_ubuf(): if get_ub("fc2_fprop").is_fp8_ubuf():
fp8_output = True fp8_output = True
with self.prepare_forward(inp, num_gemms=2) as inp: with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=2) as inp:
quantizers = ( quantizers = (
self._get_quantizers(fp8_output) self._get_quantizers(fp8_output)
...@@ -1766,12 +1781,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1766,12 +1781,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
else self._get_debug_quantizers(fp8_output) else self._get_debug_quantizers(fp8_output)
) )
if debug: if debug:
if not any_feature_enabled(quantizers): if self.no_debug_features_active(quantizers):
quantizers = self._get_quantizers(fp8_output)
debug = False debug = False
quantizers = self._get_quantizers(fp8_output)
if isinstance(self.fc1_weight, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
# Get quantizers # Get quantizers
( (
...@@ -2169,3 +2181,5 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2169,3 +2181,5 @@ class LayerNormMLP(TransformerEngineBaseModule):
del fc2_wgrad del fc2_wgrad
del fc1_wgrad del fc1_wgrad
del fc1_bias_grad del fc1_bias_grad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
...@@ -66,12 +66,9 @@ from ..tensor.quantized_tensor import ( ...@@ -66,12 +66,9 @@ from ..tensor.quantized_tensor import (
) )
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..export import is_in_onnx_export_mode, assert_warmed_up from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
__all__ = ["Linear"] __all__ = ["Linear"]
...@@ -140,6 +137,12 @@ class _Linear(torch.autograd.Function): ...@@ -140,6 +137,12 @@ class _Linear(torch.autograd.Function):
) )
# Configure Userbuffers communication (comm+GEMM overlap) # Configure Userbuffers communication (comm+GEMM overlap)
if debug: # turn off userbuffers in debug mode
ub_overlap_rs_fprop = False
ub_overlap_ag_fprop = False
ub_overlap_rs_dgrad = False
ub_bulk_wgrad = False
ub_bulk_dgrad = False
ub_obj = None ub_obj = None
ub_type = None ub_type = None
if ub_overlap_rs_fprop: if ub_overlap_rs_fprop:
...@@ -171,16 +174,19 @@ class _Linear(torch.autograd.Function): ...@@ -171,16 +174,19 @@ class _Linear(torch.autograd.Function):
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
if not isinstance(inputmat, QuantizedTensorBase): if not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage( own_quantized_input = True
rowwise=True, columnwise=backward_needs_input and not save_original_input input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
)
if isinstance( if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
): ):
# All-gather is not supported with FP8 column-wise data # All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False) input_quantizer.set_usage(columnwise=False)
if save_original_input:
# No need for column-wise data since this
# tensor will not be cached for backward pass
input_quantizer.set_usage(columnwise=False)
own_quantized_input = False
inputmat = input_quantizer(inputmat) inputmat = input_quantizer(inputmat)
own_quantized_input = True
else: else:
inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
...@@ -345,23 +351,30 @@ class _Linear(torch.autograd.Function): ...@@ -345,23 +351,30 @@ class _Linear(torch.autograd.Function):
inputmat = inp inputmat = inp
ctx.weight_quantizer = weight_quantizer ctx.weight_quantizer = weight_quantizer
saved_inputmat = None
ctx.backward_input_needs_gather = ( ctx.backward_input_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel weight.requires_grad and parallel_mode == "column" and sequence_parallel
) )
# Discard unneeded data in input tensor
if (
backward_needs_input
and own_quantized_input
and isinstance(inputmat, QuantizedTensorBase)
):
if (
ctx.backward_input_needs_gather
and weight_quantizer.supports_only_rowwise_all_gather()
):
# All-gather is not supported with FP8 column-wise data
inputmat.update_usage(rowwise_usage=True, columnwise_usage=False)
else:
# Discard row-wise data since it is not needed in backward pass
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
# Cached input tensor
saved_inputmat = None
if backward_needs_input: if backward_needs_input:
if not save_original_input:
if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
# For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data
# can be allgathered.
if (
isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
or not ctx.backward_input_needs_gather
):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
saved_inputmat = inputmat saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM. # Weight with column-wise usage is needed for dgrad GEMM.
...@@ -547,6 +560,19 @@ class _Linear(torch.autograd.Function): ...@@ -547,6 +560,19 @@ class _Linear(torch.autograd.Function):
# usage for only dgrad GEMM. # usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False) quantizer.set_usage(columnwise=False)
# Adjust the quantization direction approach depending
# on whether wgrad calculations will be performed.
# NOTE: If requires_dgrad is False, disabling `rowwise` quantization and keeping `columnwise` quantization
# results in `Assertion failed: output_tensor->has_data(). Quantizing in only the columnwise direction not supported yet!`
# NOTE: For `ctx.bias is True`, selected quantize kernel errors with
# `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.`
if (
not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
):
ctx.grad_output_quantizer.set_usage(columnwise=False)
# Prepare grad output tensor # Prepare grad output tensor
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
( (
...@@ -573,11 +599,18 @@ class _Linear(torch.autograd.Function): ...@@ -573,11 +599,18 @@ class _Linear(torch.autograd.Function):
inputmat_total = None inputmat_total = None
inputmat_total_work = None inputmat_total_work = None
if ctx.requires_wgrad: if ctx.requires_wgrad:
input_is_quantized = isinstance(inputmat, QuantizedTensorBase)
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if not input_is_quantized: if isinstance(inputmat, QuantizedTensorBase):
# Input tensor is already quantized
pass
elif ctx.debug:
# Debug quantizer will be applied immediately before wgrad GEMM
pass
else:
# Quantize input tensor
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
quantizer.set_usage( quantizer.set_usage(
rowwise=True, rowwise=True,
columnwise=not ctx.backward_input_needs_gather, columnwise=not ctx.backward_input_needs_gather,
...@@ -586,7 +619,7 @@ class _Linear(torch.autograd.Function): ...@@ -586,7 +619,7 @@ class _Linear(torch.autograd.Function):
quantizer.set_usage(rowwise=False, columnwise=True) quantizer.set_usage(rowwise=False, columnwise=True)
inputmat = quantizer(inputmat) inputmat = quantizer(inputmat)
else: else:
if input_is_quantized: if isinstance(inputmat, QuantizedTensorBase):
inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
else: else:
inputmat = cast_if_needed(inputmat, ctx.activation_dtype) inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
...@@ -594,7 +627,7 @@ class _Linear(torch.autograd.Function): ...@@ -594,7 +627,7 @@ class _Linear(torch.autograd.Function):
quantizer = None quantizer = None
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if quantizer.supports_only_rowwise_all_gather():
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
quantizer.set_usage(rowwise=True, columnwise=False) quantizer.set_usage(rowwise=True, columnwise=False)
else: else:
...@@ -726,26 +759,36 @@ class _Linear(torch.autograd.Function): ...@@ -726,26 +759,36 @@ class _Linear(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
# make sure required data is available # make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output # UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't # all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we # convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered # can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly # for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM. # overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() # We use the send stream to copy into the userbuffers.
with torch.cuda.stream(dgrad_comm_stream): # This is the same stream that we will use to access the data in the AG,
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather # so we don't need to add any syncs yet.
# This ensures that we don't start until all communication for the dgrad GEMM is complete with torch.cuda.stream(dgrad_send_stream):
grad_output, grad_output_work = gather_along_first_dim( grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_overlap_wgrad,
grad_output_arg, grad_output_arg,
ctx.grad_output_quantizer,
ctx.tp_group, ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
) )
# Synchronize with the main stream
grad_output_work.wait() # Allgather grad_output_arg using the dgrad streams so we can overlap with the dgrad GEMM
tex.bulk_overlap_ag_with_external_gemm(
ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
)
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorBase):
...@@ -1067,9 +1110,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -1067,9 +1110,6 @@ class Linear(TransformerEngineBaseModule):
self.save_original_input = save_original_input self.save_original_input = save_original_input
self.name = name self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if device == "meta": if device == "meta":
...@@ -1261,6 +1301,11 @@ class Linear(TransformerEngineBaseModule): ...@@ -1261,6 +1301,11 @@ class Linear(TransformerEngineBaseModule):
else: else:
self.gemm_bias_unfused_add = False self.gemm_bias_unfused_add = False
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
if name in self.weight_names or name in self.bias_names:
param.skip_backward_post_hook = True
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe) super().set_meta_tensor(fwd, recipe)
...@@ -1326,9 +1371,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1326,9 +1371,7 @@ class Linear(TransformerEngineBaseModule):
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output) return self.onnx_forward(inp, fp8_output)
debug = TEDebugState.debug_enabled debug = self.is_debug_iter()
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing(): if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
...@@ -1344,7 +1387,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -1344,7 +1387,9 @@ class Linear(TransformerEngineBaseModule):
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
fp8_grad = True fp8_grad = True
with self.prepare_forward( with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
inp, inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor), allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp: ) as inp:
...@@ -1356,14 +1401,11 @@ class Linear(TransformerEngineBaseModule): ...@@ -1356,14 +1401,11 @@ class Linear(TransformerEngineBaseModule):
if not debug if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad) else self._get_debug_quantizers(fp8_output, fp8_grad)
) )
if debug: if debug:
if not any_feature_enabled(quantizers): if self.no_debug_features_active(quantizers):
# If no feature is used, then run faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
debug = False debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad)
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
( (
input_quantizer, input_quantizer,
......
...@@ -5,11 +5,13 @@ ...@@ -5,11 +5,13 @@
"""Single tensor operations supported by the operation fuser.""" """Single tensor operations supported by the operation fuser."""
from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
from .add_in_place import AddInPlace from .add_extra_input import AddExtraInput
from .all_gather import AllGather from .all_gather import AllGather
from .all_reduce import AllReduce from .all_reduce import AllReduce
from .basic_linear import BasicLinear from .basic_linear import BasicLinear
from .bias import Bias from .bias import Bias
from .constant_scale import ConstantScale
from .dropout import Dropout
from .identity import Identity from .identity import Identity
from .l2normalization import L2Normalization from .l2normalization import L2Normalization
from .layer_norm import LayerNorm from .layer_norm import LayerNorm
......
...@@ -11,7 +11,6 @@ from typing import Optional ...@@ -11,7 +11,6 @@ from typing import Optional
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
...@@ -71,7 +70,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -71,7 +70,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -87,14 +86,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -87,14 +86,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Check input tensor # Check input tensor
x = maybe_dequantize(input_.contiguous(), dtype) x = maybe_dequantize(input_.contiguous(), dtype)
# Check if quantized compute is enabled
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
quantizer = None
if with_quantized_compute:
quantizer = next_op_input_quantizer
# Launch kernel # Launch kernel
y = self._activation_forward_impl(x, quantizer) y = self._activation_forward_impl(x, next_op_input_quantizer)
# Quantize input to FP8 before caching if needed # Quantize input to FP8 before caching if needed
if self.cache_quantized_input: if self.cache_quantized_input:
...@@ -103,10 +96,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -103,10 +96,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
x = input_quantizer(x) x = input_quantizer(x)
# Save state for backward pass # Save state for backward pass
ctx.save_for_backward(x) if ctx.requires_grad:
ctx.with_quantized_compute = with_quantized_compute ctx.save_for_backward(x)
ctx.dtype = dtype ctx.dtype = dtype
ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
return y return y
...@@ -125,13 +118,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -125,13 +118,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Check grad output tensor # Check grad output tensor
dy = maybe_dequantize(grad_output.contiguous(), x.dtype) dy = maybe_dequantize(grad_output.contiguous(), x.dtype)
# Check if quantized compute is enabled
quantizer = None
if ctx.with_quantized_compute:
quantizer = ctx.prev_op_grad_input_quantizer
# Launch kernel # Launch kernel
dx = self._activation_backward_impl(dy, x, quantizer) dx = self._activation_backward_impl(dy, x, ctx.prev_op_grad_output_quantizer)
# Clear input tensor if possible # Clear input tensor if possible
clear_tensor_data(x) clear_tensor_data(x)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""Fusible operation for in-place add.""" """Fusible operation for adding extra input tensor."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
...@@ -18,16 +18,17 @@ from transformer_engine.pytorch.ops.op import ( ...@@ -18,16 +18,17 @@ from transformer_engine.pytorch.ops.op import (
from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.tensor import Quantizer
class AddInPlace(BasicOperation): class AddExtraInput(BasicOperation):
"""Add in-place """Add extra input tensor
This operation requires an extra tensor input to the operation This operation requires an extra tensor input to the operation
fuser. The main input is added in-place to the extra input, and a fuser. It returns the sum of the main input and the extra input.
view of the extra input is output. If in_place=True, the main input is added in-place to the extra
input, and a view of the extra input is output.
This operation is considered an advanced feature and most users Using this operation with in_place=True is considered an advanced
are discouraged from using it. In-place operations break some feature and most users are discouraged from it. In-place operations
autograd assumptions and they can result in subtle, esoteric bugs. break some autograd assumptions and they can result in subtle, esoteric bugs.
Compare to `MakeExtraOutput`, which does a similar operation in Compare to `MakeExtraOutput`, which does a similar operation in
the backward pass. the backward pass.
...@@ -37,6 +38,10 @@ class AddInPlace(BasicOperation): ...@@ -37,6 +38,10 @@ class AddInPlace(BasicOperation):
# Operation expects buffer for output tensor # Operation expects buffer for output tensor
num_extra_inputs: int = 1 num_extra_inputs: int = 1
def __init__(self, *, in_place: bool = False):
super().__init__()
self._in_place = in_place
def op_forward(self, *args, **kwargs) -> None: def op_forward(self, *args, **kwargs) -> None:
raise RuntimeError( raise RuntimeError(
"{self.__class__.__name__} operation has " "{self.__class__.__name__} operation has "
...@@ -59,12 +64,17 @@ class AddInPlace(BasicOperation): ...@@ -59,12 +64,17 @@ class AddInPlace(BasicOperation):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
output = basic_op_extra_inputs[0][0].detach() extra_input = basic_op_extra_inputs[0][0]
output += input_ if self._in_place:
extra_input = extra_input.detach()
extra_input += input_
output = extra_input
else:
output = extra_input + input_
return output, [()] return output, [()]
def fuser_backward( def fuser_backward(
......
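For reference, a standalone sketch of the two AddExtraInput behaviours, mirroring the fuser_forward logic above (the helper below is illustrative and not part of the ops API):

```python
# Hedged sketch of AddExtraInput semantics outside the operation fuser.
import torch


def add_extra_input(main_input: torch.Tensor, extra_input: torch.Tensor,
                    in_place: bool = False) -> torch.Tensor:
    if in_place:
        # Mirrors the old AddInPlace behaviour: accumulate into the extra
        # input's storage and return it. Detaching first sidesteps autograd's
        # in-place version tracking on the extra input, which is why this
        # mode is documented as an advanced feature.
        out = extra_input.detach()
        out += main_input
        return out
    # Default: plain out-of-place sum, safe for autograd.
    return extra_input + main_input
```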
...@@ -40,7 +40,7 @@ class AllGather(BasicOperation): ...@@ -40,7 +40,7 @@ class AllGather(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor: ) -> torch.Tensor:
out: torch.Tensor out: torch.Tensor
......