Commit 44740c6c authored by yuguo


Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
...@@ -420,6 +420,8 @@ class NonPagedKVCacheManager(KVCacheManager):
            dtype=torch.int32,
            device=torch.cuda.current_device(),
        )
        # whether reindexing is needed, i.e. when batch seq_ids have changed
        self.need_reindex = True

    def allocate_memory(self, layer_number):
        """Allocate memory for the cache"""
...@@ -451,6 +453,7 @@ class NonPagedKVCacheManager(KVCacheManager):
        # step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that
        # they are contiguous and match the indexing in q
        prev_batch_size = len(self.sequences)
        prev_seq_ids = set(self.sequences.keys())
        unfinished_seqs = self.sequences.keys() & step_dict.keys()
        finished_seqs = self.sequences.keys() - unfinished_seqs
        unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs]
...@@ -478,6 +481,9 @@ class NonPagedKVCacheManager(KVCacheManager):
        for i in new_seqs:
            self.sequences[i] = step_dict[i]
        # Whether reindexing is needed
        self.need_reindex = set(self.sequences.keys()) != prev_seq_ids
        return self.sequences
    def step(
...@@ -538,7 +544,7 @@ class NonPagedKVCacheManager(KVCacheManager):
            ctx_len,
            self.max_seqlen,
            1,
            self.need_reindex,
        )
        k_cache = k_cache[:batch_size]
......
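The new flag means the KV cache is only re-indexed when the set of sequence ids in the batch actually changes between generation steps. A minimal standalone sketch of that check (hypothetical values, not the class itself):

```python
# Hypothetical sketch of the reindex condition added above.
prev_seq_ids = {0, 1, 2, 3}        # sequence ids tracked at the previous step
step_dict = {0: 1, 2: 1, 3: 1}     # sequence 1 finished, no new sequences arrived
need_reindex = set(step_dict.keys()) != prev_seq_ids
print(need_reindex)  # True -> k_cache/v_cache must be re-indexed this step
```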
...@@ -9,8 +9,8 @@ from typing import Any, Dict, Optional
import torch

from transformer_engine.debug.pytorch.debug_state import TEDebugState
from .tensor.quantized_tensor import QuantizedTensorBase
from .tensor.float8_tensor import Float8Tensor

__all__ = ["get_cpu_offload_context"]

...@@ -20,6 +20,9 @@ CPUOffloadEnabled = False
def mark_activation_offload(*tensors):
    """Set the type of the offloading needed for a tensor."""
    if TEDebugState.debug_enabled:
        raise RuntimeError("CPU offload is not supported in debug mode.")
    for tensor in tensors:
        if tensor is None:
            continue
......
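The added guard makes activation offloading and the TE debug tooling mutually exclusive, so callers hit an error early instead of offloading tensors the debug hooks still reference. A small hedged sketch of a caller-side check (the helper name is hypothetical):

```python
from transformer_engine.debug.pytorch.debug_state import TEDebugState

def offload_allowed() -> bool:
    """Hypothetical helper: offloading is rejected whenever debug mode is active."""
    return not TEDebugState.debug_enabled
```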
...@@ -33,6 +33,7 @@
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_router.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h>
...@@ -215,6 +216,8 @@ class Float8BlockQuantizer : public Quantizer {
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
class MXFP8Quantizer : public Quantizer {
...@@ -230,6 +233,8 @@ class MXFP8Quantizer : public Quantizer {
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
......
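Both quantizer classes gain a get_scale_shape helper that reports how large the scale buffer must be for a given data shape; the bulk-allocation code later in this commit relies on it. As a rough illustration only (the block length and any padding or alignment are assumptions, not something this header guarantees), block scaling boils down to a ceiling division along the quantized axis:

```python
import math

def naive_scale_shape(data_shape, block_len, columnwise=False):
    """Illustrative only: one scale per block_len elements along the last axis.
    The real get_scale_shape may transpose, pad, or align these dimensions."""
    rows, cols = data_shape
    if columnwise:
        rows, cols = cols, rows
    return (rows, math.ceil(cols / block_len))

print(naive_scale_shape((256, 1000), 128))  # (256, 8)
```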
...@@ -13,6 +13,38 @@
namespace transformer_engine::pytorch {
/***************************************************************************************************
* Router fusion
**************************************************************************************************/
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fwd(
at::Tensor logits, int topk, bool use_pre_softmax, c10::optional<int> num_groups,
c10::optional<int> group_topk, c10::optional<float> scaling_factor, std::string score_function,
c10::optional<at::Tensor> expert_bias);
at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts,
at::Tensor routing_map,
at::Tensor intermediate_output, at::Tensor grad_probs,
int topk, bool use_pre_softmax,
c10::optional<float> scaling_factor,
std::string score_function);
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
at::Tensor logits, int topk, std::string score_function);
at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
at::Tensor intermediate_output, at::Tensor grad_probs,
int topk, std::string score_function);
std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
at::Tensor tokens_per_expert,
int total_num_tokens, int num_experts,
int num_rows, int num_cols, int topk,
float coeff);
at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows,
int num_cols, at::Tensor grad_aux_loss);
/***************************************************************************************************
* Permutation
**************************************************************************************************/
...@@ -136,10 +168,6 @@ std::vector<at::Tensor> te_batchgemm_ts(
* Transpose
**************************************************************************************************/
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list, DType otype);
at::Tensor fp8_transpose(at::Tensor input, DType otype,
std::optional<at::Tensor> output = std::nullopt);
...@@ -210,10 +238,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
**************************************************************************************************/
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
std::optional<at::Tensor> noop_flag);
py::object dequantize(const py::handle &input, DType otype);
std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list);
std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<int> &split_sections,
std::vector<py::handle> quantizer_list);
/***************************************************************************************************
* Bias gradient fusions
**************************************************************************************************/
...@@ -395,6 +430,9 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list);
void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> unpadded_input_row_list);
/***************************************************************************************************
* NVSHMEM APIs
**************************************************************************************************/
......
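These declarations replace the removed fused_multi_quantize entry point. A hedged usage sketch of the new multi-tensor path from Python (the module name `transformer_engine_torch` and the origin of the quantizer objects are assumptions; quantizers normally come from the active FP8 recipe):

```python
import torch
import transformer_engine_torch as tex  # assumed name of the compiled extension

tensors = [torch.randn(128, 256, device="cuda") for _ in range(3)]
# quantizers: one TE quantizer object per tensor, built elsewhere from the fp8 recipe.
quantized = tex.multi_tensor_quantize(tensors, quantizers)
```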
...@@ -6,60 +6,51 @@
#include "transformer_engine/cast.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "../extensions.h" #include "../extensions.h"
#include "common.h" #include "common.h"
#include "pybind.h" #include "pybind.h"
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace pytorch {

namespace {

std::vector<size_t> get_tensor_shape(const TensorWrapper &tensor) {
const auto &shape = tensor.shape();
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}

void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py,
std::unique_ptr<Quantizer> &quantizer_cpp, TensorWrapper &output,
TensorWrapper &noop_flag) {
// Check tensor dims
NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output),
"Input tensor (shape=", get_tensor_shape(input),
") and output tensor (shape=", get_tensor_shape(output), ") do not match");
if (input.numel() == 0) {
return;
}
// Recipe-specific configuration
QuantizationConfigWrapper quant_config;
quant_config.set_noop_tensor(noop_flag.data());
if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); });
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
// construct torch tensor from NVTEBasicTensor without reallocating memory
at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
...@@ -72,37 +63,70 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel
output.set_amax(nullptr, DType::kFloat32, output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(quantizer_cpp.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
if (my_quantizer_bw->all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
}
// Perform quantization
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}

}  // namespace
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
std::optional<at::Tensor> noop_flag) {
// Convert quantizer to C++ object
auto quantizer_cpp = convert_quantizer(quantizer);
// Convert input tensor to C++ object
auto input_contiguous = tensor.contiguous();
const auto input_cpp = makeTransformerEngineTensor(input_contiguous);
// Initialize output tensor
TensorWrapper output_cpp;
py::object output_py;
if (output.is_none()) {
const auto shape = get_tensor_shape(input_cpp);
const auto fake_dtype = input_cpp.dtype();
std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype);
} else {
output_py = output;
output_cpp = makeTransformerEngineTensor(output_py, quantizer);
}
// Initialize no-op flag
TensorWrapper noop_flag_cpp;
if (noop_flag.has_value()) {
noop_flag_cpp = makeTransformerEngineTensor(*noop_flag);
}
// Perform quantization
quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp);
return output_py;
} }
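quantize() now funnels into the shared quantize_impl helper, which is reused below by multi_tensor_quantize and split_quantize, and the optional no-op flag is passed through as a device tensor. A hedged call sketch (the module name and the no-op convention, zero meaning "do quantize", are assumptions):

```python
import torch
import transformer_engine_torch as tex  # assumed name of the compiled extension

x = torch.randn(128, 256, device="cuda")
noop = torch.zeros(1, device="cuda")  # assumed convention: nonzero skips the kernel
# quantizer is a TE quantizer object built from the active recipe (not shown here).
out = tex.quantize(x, quantizer, None, noop)
```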
py::object dequantize(const py::handle &input, transformer_engine::DType otype) {
init_extension();
const auto none = py::none();
const auto &input_tensor = makeTransformerEngineTensor(input, none);
NoneQuantizer q(none);
const auto &shape = convertShape(input_tensor.shape());
auto [out_tensor, out] = q.create_tensor(shape, otype);
...@@ -113,9 +137,522 @@ py::object dequantize(const py::handle& input, transformer_engine::DType otype)
return out;
}
namespace {
void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
std::vector<py::handle> &quantizer_py_list,
std::vector<std::unique_ptr<Quantizer>> &quantizer_cpp_list,
std::vector<TensorWrapper> &output_list) {
// Check number of tensors
const size_t num_tensors = input_list.size();
NVTE_CHECK(quantizer_py_list.size() == num_tensors, "Expected ", num_tensors,
" Python quantizers, but got ", quantizer_py_list.size());
NVTE_CHECK(quantizer_cpp_list.size() == num_tensors, "Expected ", num_tensors,
" C++ quantizers, but got ", quantizer_cpp_list.size());
NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors,
" output tensors, but got ", output_list.size());
// Choose implementation
// Note: Currently only have fused kernel for FP8 delayed scaling
bool with_fused_kernel = true;
for (size_t i = 0; i < num_tensors; i++) {
if (!detail::IsFloat8Quantizers(quantizer_py_list[i].ptr())) {
with_fused_kernel = false;
break;
}
if (nvte_tensor_data(output_list[i].data()) == nullptr ||
nvte_tensor_columnwise_data(output_list[i].data()) == nullptr) {
with_fused_kernel = false;
break;
}
}
// Launch TE kernel
if (with_fused_kernel) {
// Fused kernel for multi-tensor quantize
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
for (size_t i = 0; i < num_tensors; ++i) {
nvte_tensor_input_list.push_back(input_list[i].data());
nvte_tensor_output_list.push_back(output_list[i].data());
}
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
});
} else {
// Quantize kernels individually
TensorWrapper dummy_noop_flag;
for (size_t i = 0; i < num_tensors; ++i) {
quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i],
dummy_noop_flag);
}
}
}
} // namespace
std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list) {
// Check number of tensors
const size_t num_tensors = tensor_list.size();
NVTE_CHECK(quantizer_list.size() == num_tensors, "Expected ", num_tensors,
" quantizers, but got ", quantizer_list.size());
// Convert quantizers to C++ objects
std::vector<std::unique_ptr<Quantizer>> quantizer_cpp_list;
for (size_t i = 0; i < num_tensors; i++) {
quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i]));
}
// Initialize input and output tensors
std::vector<TensorWrapper> input_cpp_list;
std::vector<TensorWrapper> output_cpp_list;
std::vector<py::object> output_py_list;
for (size_t i = 0; i < num_tensors; ++i) {
// Convert input tensor to C++ object
const auto &input_py = tensor_list[i];
NVTE_CHECK(input_py.is_contiguous(), "Input tensor ", i, " is not contiguous");
input_cpp_list.emplace_back(makeTransformerEngineTensor(input_py));
const auto &input_cpp = input_cpp_list.back();
const auto input_shape = input_cpp.shape();
const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type());
// Construct output tensor
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype);
output_cpp_list.emplace_back(std::move(output_cpp));
output_py_list.emplace_back(std::move(output_py));
}
// Perform multi-tensor quantization
multi_tensor_quantize_impl(input_cpp_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
return output_py_list;
}
namespace {
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp8_blockwise_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<Float8BlockQuantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
// Number of tensors
const size_t num_tensors = shape_list.size();
if (num_tensors == 0) {
return retval;
}
// Quantization parameters
const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage;
const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
const auto is_2D_scaled = scaling_mode == NVTE_BLOCK_SCALING_2D;
const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
constexpr size_t fp8_elem_size = 1;
constexpr size_t scale_elem_size = 4;
// Helper function to construct tensor view
// Note: Deleter holds a shared_ptr for the buffer, so the buffer
// will survive until all views are deleted.
auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
size_t offset, at::ScalarType dtype) -> at::Tensor {
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
// in the case where full buffer is empty because local rank receives no tokens for all the experts
// then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
// but in the case where some experts receive tokens, some not, we want to leverage from_blob
// as much as possible to avoid CPU overhead
if (buffer->data_ptr<uint8_t>() == nullptr) {
return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
}
return at::from_blob(
buffer->data_ptr<uint8_t>() + offset, shape_int64,
[buffer](void *) {}, // deleter holds shared_ptr
at::device(at::kCUDA).dtype(dtype));
};
// Allocate row-wise data
std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
if (rowwise_usage) {
// Tensor sizes
for (size_t i = 0; i < num_tensors; ++i) {
rowwise_data_shapes.emplace_back(shape_list[i]);
rowwise_scale_shapes.emplace_back(
quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
}
// Offsets in full buffer
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
}
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
rowwise_data_list.emplace_back(
make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8));
rowwise_scale_list.emplace_back(
make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kFloat32));
}
}
// Allocate column-wise data
std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list;
std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
if (columnwise_usage) {
// Tensor sizes
for (size_t i = 0; i < num_tensors; ++i) {
columnwise_data_shapes.emplace_back();
auto &shape = columnwise_data_shapes.back();
shape.push_back(shape_list[i].back());
for (size_t j = 0; j < shape_list[i].size() - 1; ++j) {
shape.push_back(shape_list[i][j]);
}
columnwise_scale_shapes.emplace_back(
quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
}
// Offsets in full buffer
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
}
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
columnwise_data_list.emplace_back(
make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8));
columnwise_scale_list.emplace_back(
make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kFloat32));
}
}
// Construct FP8 block-wise tensors
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject *>(Float8BlockwiseQTensorBasePythonClass));
for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
py::object columnwise_data =
(columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none());
py::object columnwise_scale =
(columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());
// Construct Python tensor
tensor_py_list.emplace_back(Float8BlockwiseQTensorClass(
rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype,
quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY));
// Construct C++ tensor
tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr,
nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode));
}
return retval;
}
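The bulk allocator packs every split's FP8 data (256-byte aligned) and then every split's scales (16-byte aligned) into one shared buffer, and hands out from_blob views whose deleters keep that buffer alive. A standalone sketch of the offset arithmetic mirrored from the loops above (hypothetical helper; sizes are in bytes):

```python
import math

def roundup(x, align):
    return math.ceil(x / align) * align

def pack_offsets(data_bytes, scale_bytes):
    """Return (data_offsets, scale_offsets, total_buffer_bytes)."""
    offset, data_offsets, scale_offsets = 0, [], []
    for n in data_bytes:              # FP8 data regions, 256 B aligned
        offset = roundup(offset, 256)
        data_offsets.append(offset)
        offset += n
    for n in scale_bytes:             # scale regions, 16 B aligned
        offset = roundup(offset, 16)
        scale_offsets.append(offset)
        offset += n
    return data_offsets, scale_offsets, offset

print(pack_offsets([1000, 3000], [40, 120]))  # ([0, 1024], [4032, 4080], 4200)
```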
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mxfp8_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<MXFP8Quantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
// Number of tensors
const size_t num_tensors = shape_list.size();
if (num_tensors == 0) {
return retval;
}
// Quantization parameters
const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage;
const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
constexpr size_t fp8_elem_size = 1;
constexpr size_t scale_elem_size = 1;
// Helper function to construct tensor view
// Note: Deleter holds a shared_ptr for the buffer, so the buffer
// will survive until all views are deleted.
auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
size_t offset, at::ScalarType dtype) -> at::Tensor {
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
// in the case where full buffer is empty because local rank receives no tokens for all the experts
// then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
// but in the case where some experts receive tokens, some not, we want to leverage from_blob
// as much as possible to avoid CPU overhead
if (buffer->data_ptr<uint8_t>() == nullptr) {
return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
}
return at::from_blob(
buffer->data_ptr<uint8_t>() + offset, shape_int64,
[buffer](void *) {}, // deleter holds shared_ptr
at::device(at::kCUDA).dtype(dtype));
};
// Allocate row-wise data
std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
if (rowwise_usage) {
// Tensor sizes
for (size_t i = 0; i < num_tensors; ++i) {
rowwise_data_shapes.emplace_back(shape_list[i]);
rowwise_scale_shapes.emplace_back(
quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
}
// Offsets in full buffer
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
}
// Allocate full buffer
// TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
auto buffer = std::make_shared<at::Tensor>(
at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// auto buffer = std::make_shared<at::Tensor>(
// at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
rowwise_data_list.emplace_back(
make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8));
rowwise_scale_list.emplace_back(
make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
}
}
// Allocate column-wise data
std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list;
std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
if (columnwise_usage) {
// Tensor sizes
for (size_t i = 0; i < num_tensors; ++i) {
// For MXFP8, the columnwise data doesn't need transpose
// because of TN, NT, NN layout support in SM100
columnwise_data_shapes.emplace_back(shape_list[i]);
columnwise_scale_shapes.emplace_back(
quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
}
// Offsets in full buffer
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
}
// Allocate full buffer
// TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
auto buffer = std::make_shared<at::Tensor>(
at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// auto buffer = std::make_shared<at::Tensor>(
// at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
columnwise_data_list.emplace_back(
make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8));
columnwise_scale_list.emplace_back(
make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
}
}
// Construct mxfp8 tensors
py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorBasePythonClass));
for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
py::object columnwise_data =
(columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none());
py::object columnwise_scale =
(columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());
// Construct Python tensor
tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data,
columnwise_scale, fp8_dtype,
quantizer_py_list[i]));
// Construct C++ tensor
tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr,
nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode));
}
return retval;
}
} // namespace
std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<int> &split_sections,
std::vector<py::handle> quantizer_list) {
init_extension();
// Check number of tensors
const size_t num_splits = split_sections.size();
NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ",
quantizer_list.size());
if (num_splits == 0) {
return {};
}
// Input tensor properties
auto input_py = tensor.contiguous();
uint8_t *input_dptr = reinterpret_cast<uint8_t *>(input_py.data_ptr());
auto input_dtype = GetTransformerEngineDType(input_py.scalar_type());
std::vector<size_t> input_shape;
size_t input_size = 1;
for (const auto &d : input_py.sizes()) {
input_shape.push_back(d);
input_size *= d;
}
NVTE_CHECK(input_shape.size() > 0, "Input tensor has 0 dims");
// Split input tensor along dim 0
std::vector<TensorWrapper> input_list;
std::vector<std::vector<size_t>> split_shapes;
size_t dim0_offset = 0;
const size_t dim0_stride =
input_shape[0] == 0 ? 0 : input_py.element_size() * input_size / input_shape[0];
for (size_t i = 0; i < num_splits; ++i) {
NVTE_CHECK(split_sections[i] >= 0, "Attempted to split tensor with shape=", input_shape,
" along dim 0 with split_sections=", split_sections);
NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0],
"Attempted to split tensor with shape=", input_shape,
" along dim 0 with split_sections=", split_sections);
split_shapes.push_back(input_shape);
auto &split_shape = split_shapes.back();
split_shape[0] = split_sections[i];
void *split_dptr = static_cast<void *>(input_dptr + dim0_offset * dim0_stride);
input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype));
dim0_offset += split_sections[i];
}
// Convert quantizers to C++ objects
std::vector<std::unique_ptr<Quantizer>> quantizer_cpp_list;
for (size_t i = 0; i < num_splits; i++) {
quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i]));
}
// For FP8 block-scaling, we construct output tensors with bulk allocations
// For MXFP8, we also use bulk allocations
bool use_fused_bulk_alloc = true;
for (size_t i = 0; i < quantizer_list.size(); i++) {
if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) &&
!detail::IsMXFP8Quantizers(quantizer_list[i].ptr())) {
use_fused_bulk_alloc = false;
break;
}
}
// Allocate output tensors
std::vector<TensorWrapper> output_cpp_list;
std::vector<py::object> output_py_list;
if (!use_fused_bulk_alloc) {
// Allocate output tensors individually
for (size_t i = 0; i < num_splits; ++i) {
auto [output_cpp, output_py] =
quantizer_cpp_list[i]->create_tensor(split_shapes[i], input_dtype);
output_cpp_list.emplace_back(std::move(output_cpp));
output_py_list.emplace_back(std::move(output_py));
}
} else {
// TODO(zhongbo): make a better api to make this part less hacky
bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr());
bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr());
if (is_fp8_blockwise) {
// FP8 block-scaling: construct output tensors with bulk allocations
std::vector<Float8BlockQuantizer *> blockwise_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
blockwise_quantizers.push_back(static_cast<Float8BlockQuantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers);
} else if (is_mxfp8) {
// MXFP8: construct output tensors with bulk allocations
std::vector<MXFP8Quantizer *> mxfp8_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
mxfp8_quantizers.push_back(static_cast<MXFP8Quantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers);
} else {
NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer");
}
}
// Perform multi-tensor quantization
multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
return output_py_list;
}
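Functionally, split_quantize is the fused counterpart of slicing the input along dim 0 and quantizing each slice with its own quantizer; the fused path adds bulk output allocation and, where the quantizers allow it, fused kernels. A hedged reference sketch of that equivalence (the `tex` module name is assumed, and the quantizer objects are assumed to be callable on torch tensors):

```python
import torch
import transformer_engine_torch as tex  # assumed name of the compiled extension

def split_quantize_reference(tensor, split_sections, quantizers):
    # Unfused reference behaviour: slice along dim 0, quantize each slice separately.
    chunks = torch.split(tensor, split_sections, dim=0)
    return [q(chunk) for q, chunk in zip(quantizers, chunks)]

# Fused path added by this commit:
# outs = tex.split_quantize(tensor, split_sections, quantizers)
```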
template <void (*func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor,
cudaStream_t)>
std::vector<py::object> dbias_dact(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
...@@ -125,7 +662,7 @@ std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tens
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype());
auto act_input_tensor = makeTransformerEngineTensor(act_input);
const auto &shape = convertShape(grad_tensor.shape());
auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype());
auto dbias_tensor = makeTransformerEngineTensor(grad_bias);
...@@ -149,29 +686,30 @@ std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tens
return {py::cast(grad_bias), dact};
}
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dgelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dsilu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_drelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dqgelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dsrelu>(grad_output, act_input, quantizer);
}
}  // namespace pytorch
}  // namespace transformer_engine
...@@ -81,4 +81,77 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, ...@@ -81,4 +81,77 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
}); });
} }
void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> unpadded_input_row_list) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(input_row_list.size() == unpadded_input_row_list.size(),
"Number of input row list and unpadded row list entries must match.");
NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2.");
const auto num_tensors = input_row_list.size();
// Extract properties from PyTorch tensors
std::vector<void*> input_dptr_list, output_dptr_list;
std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
std::vector<transformer_engine::DType> input_type_list;
void* d_input_ptr = reinterpret_cast<void*>(input.data_ptr());
void* d_output_ptr = reinterpret_cast<void*>(output.data_ptr());
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
input_dptr_list.push_back(d_input_ptr);
output_dptr_list.push_back(d_output_ptr);
// Move the input pointer to the next split.
char* input_char_ptr = reinterpret_cast<char*>(d_input_ptr);
const size_t input_dptr_offset =
input_row_list[tensor_id] * input.size(1) * input.element_size();
input_char_ptr += input_dptr_offset;
d_input_ptr = reinterpret_cast<void*>(input_char_ptr);
input_shape_list.push_back({input_row_list[tensor_id], static_cast<size_t>(input.size(1))});
input_type_list.push_back(GetTransformerEngineDType(input.scalar_type()));
// Move the output pointer to the next split.
char* output_char_ptr = reinterpret_cast<char*>(d_output_ptr);
const size_t output_dptr_offset =
unpadded_input_row_list[tensor_id] * output.size(1) * output.element_size();
output_char_ptr += output_dptr_offset;
d_output_ptr = reinterpret_cast<void*>(output_char_ptr);
output_shape_list.push_back(
{unpadded_input_row_list[tensor_id], static_cast<size_t>(output.size(1))});
}
// Construct TE tensors
std::vector<NVTETensor> nvte_input_list, nvte_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype) -> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
return tensor_wrappers.back().data();
};
std::vector<int> unpadded_num_rows_list;
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue;
nvte_input_list.emplace_back(
make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i]));
nvte_output_list.emplace_back(
make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i]));
unpadded_num_rows_list.emplace_back(unpadded_input_row_list[i]);
}
// Check tensor lists
NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(),
"Number of input and output tensors must match");
NVTE_CHECK(unpadded_num_rows_list.size() == nvte_input_list.size(),
"Number of input tensors and unpadded row list entries must match");
// Launch TE kernel
nvte_multi_unpadding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(),
unpadded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
}
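fused_multi_row_unpadding walks a single 2-D buffer that holds several padded splits and, for each split, copies only the first "unpadded" rows into a densely packed output. A hedged torch reference of the same data movement (hypothetical helper, not the fused kernel):

```python
import torch

def multi_row_unpad_reference(inp, input_rows, unpadded_rows):
    """inp: 2-D tensor whose rows are the concatenation of padded splits."""
    outs, offset = [], 0
    for padded, unpadded in zip(input_rows, unpadded_rows):
        outs.append(inp[offset:offset + unpadded])  # keep only the real rows
        offset += padded                            # skip the padding rows
    return torch.cat(outs, dim=0)

x = torch.arange(12.0).reshape(6, 2)  # two splits of 3 padded rows each
print(multi_row_unpad_reference(x, [3, 3], [2, 1]).shape)  # torch.Size([3, 2])
```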
}  // namespace transformer_engine::pytorch
...@@ -12,7 +12,9 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <optional>
#include <vector>
#include "../common.h" #include "../common.h"
#include "../extensions.h" #include "../extensions.h"
...@@ -206,10 +208,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize,
"Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
"Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
py::arg("quantizer_list"));
m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
"Grouped GEMM"); "Grouped GEMM");
#ifdef USE_ROCM #ifdef USE_ROCM
...@@ -242,6 +245,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -242,6 +245,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>()); py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
"Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>()); "Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding,
"Fused Multi-tensor unpadding", py::call_guard<py::gil_scoped_release>());
// attention kernels
m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
...@@ -266,6 +271,32 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
"Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
// fused router
m.def("fused_topk_with_score_function_fwd",
&transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"),
py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"),
py::arg("scaling_factor"), py::arg("score_function"), py::arg("expert_bias"),
"Fused topk softmax fwd");
m.def("fused_topk_with_score_function_bwd",
&transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"),
py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"),
py::arg("grad_probs"), py::arg("topk"), py::arg("use_pre_softmax"),
py::arg("scaling_factor"), py::arg("score_function"), "Fused topk softmax bwd");
m.def("fused_score_for_moe_aux_loss_fwd",
&transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"),
py::arg("topk"), py::arg("score_function"), "Fused topk softmax fwd");
m.def("fused_score_for_moe_aux_loss_bwd",
&transformer_engine::pytorch::fused_score_for_moe_aux_loss_bwd, py::arg("num_tokens"),
py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"),
py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd");
m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd,
py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"),
py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"),
py::arg("coeff"), "Fused aux loss fwd");
m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd,
py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"),
py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
// Misc
m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
"Get cublasLt version", py::call_guard<py::gil_scoped_release>());
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#include "common.h"
namespace transformer_engine::pytorch {
static std::map<std::string, int> score_function_map = {{"sigmoid", 0}, {"softmax", 1}};
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fwd(
at::Tensor logits, int topk, bool use_pre_softmax, c10::optional<int> num_groups,
c10::optional<int> group_topk, c10::optional<float> scaling_factor, std::string score_function,
c10::optional<at::Tensor> expert_bias) {
int num_tokens = logits.size(0);
int num_experts = logits.size(1);
// Check if the input is valid
TORCH_CHECK(num_tokens > 0 && num_experts > 0,
"num_tokens and num_experts must be greater than 0");
// Expert bias only happens at the sigmoid case
if (expert_bias.has_value()) {
TORCH_CHECK(score_function == "sigmoid",
"score_function must be sigmoid when expert_bias is not None");
}
// Check if the score function is valid
TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid",
"score_function must be softmax or sigmoid for router fusion");
if (score_function == "sigmoid") {
use_pre_softmax = false; // Pre-softmax only happens at the softmax case
}
// Reformat the input to make it compatible with the kernel
int group_topk_value = group_topk.has_value() ? group_topk.value() : -1;
int num_groups_value = num_groups.has_value() ? num_groups.value() : -1;
float scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;
// Construct the output tensor
at::Tensor probs =
at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));
at::Tensor routing_map =
at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA));
// Intermediate output is used to store the output of the softmax/sigmoid function
at::Tensor intermediate_output =
at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));
auto logits_cu = makeTransformerEngineTensor(logits);
auto probs_cu = makeTransformerEngineTensor(probs);
auto routing_map_cu = makeTransformerEngineTensor(routing_map);
auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
auto expert_bias_cu = TensorWrapper(); // empty expert_bias_cu tensor
if (expert_bias.has_value()) {
expert_bias_cu = makeTransformerEngineTensor(expert_bias.value());
}
nvte_fused_topk_with_score_function_forward(
logits_cu.data(), num_tokens, num_experts, topk, use_pre_softmax, num_groups_value,
group_topk_value, scaling_factor_value, score_function_map[score_function],
expert_bias_cu.data(), probs_cu.data(), routing_map_cu.data(), intermediate_output_cu.data(),
at::cuda::getCurrentCUDAStream());
return std::make_tuple(probs, routing_map, intermediate_output);
}
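From Python, this forward entry point returns the routing probabilities, the boolean routing map, and the intermediate softmax/sigmoid output that the backward pass consumes. A hedged usage sketch (the module name `transformer_engine_torch` is assumed; argument values are illustrative only):

```python
import torch
import transformer_engine_torch as tex  # assumed name of the compiled extension

logits = torch.randn(16, 8, device="cuda")  # [num_tokens, num_experts]
probs, routing_map, intermediate = tex.fused_topk_with_score_function_fwd(
    logits,
    2,          # topk
    True,       # use_pre_softmax
    None,       # num_groups
    None,       # group_topk
    None,       # scaling_factor
    "softmax",  # score_function
    None,       # expert_bias
)
# probs and routing_map are [num_tokens, num_experts]; routing_map is boolean.
```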
at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts,
at::Tensor routing_map,
at::Tensor intermediate_output, at::Tensor grad_probs,
int topk, bool use_pre_softmax,
c10::optional<float> scaling_factor,
std::string score_function) {
// Get the value of the parameters
auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;
auto score_function_value = score_function_map[score_function];
// Init the output tensor
at::Tensor grad_logits = at::empty(
{num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA));
auto routing_map_cu = makeTransformerEngineTensor(routing_map);
auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
auto grad_probs_cu = makeTransformerEngineTensor(grad_probs);
auto grad_logits_cu = makeTransformerEngineTensor(grad_logits);
nvte_fused_topk_with_score_function_backward(
routing_map_cu.data(), intermediate_output_cu.data(), grad_probs_cu.data(), num_tokens,
num_experts, topk, use_pre_softmax, scaling_factor_value, score_function_value,
grad_logits_cu.data(), at::cuda::getCurrentCUDAStream());
return grad_logits;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
at::Tensor logits, int topk, std::string score_function) {
int num_tokens = logits.size(0);
int num_experts = logits.size(1);
// Check if the input is valid
TORCH_CHECK(num_tokens > 0 && num_experts > 0,
"num_tokens and num_experts must be greater than 0");
TORCH_CHECK(topk > 0, "topk must be greater than 0");
// Check if the score function is valid
TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid",
"score_function must be softmax or sigmoid for router fusion");
int score_function_value = score_function_map[score_function];
// Construct the output tensor
at::Tensor scores =
at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));
at::Tensor routing_map =
at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA));
at::Tensor intermediate_output =
at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));
auto logits_cu = makeTransformerEngineTensor(logits);
auto scores_cu = makeTransformerEngineTensor(scores);
auto routing_map_cu = makeTransformerEngineTensor(routing_map);
auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
nvte_fused_score_for_moe_aux_loss_forward(
logits_cu.data(), num_tokens, num_experts, topk, score_function_value, scores_cu.data(),
routing_map_cu.data(), intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream());
return std::make_tuple(scores, routing_map, intermediate_output);
}
at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
at::Tensor intermediate_output, at::Tensor grad_scores,
int topk, std::string score_function) {
// Get the value of the parameters
int score_function_value = score_function_map[score_function];
// Init the output tensor
at::Tensor grad_logits = at::empty(
{num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA));
auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
auto grad_scores_cu = makeTransformerEngineTensor(grad_scores);
auto grad_logits_cu = makeTransformerEngineTensor(grad_logits);
nvte_fused_score_for_moe_aux_loss_backward(
intermediate_output_cu.data(), grad_scores_cu.data(), num_tokens, num_experts, topk,
score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream());
return grad_logits;
}
std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
at::Tensor tokens_per_expert,
int total_num_tokens, int num_experts,
int num_rows, int num_cols, int topk,
float coeff) {
TORCH_CHECK(topk > 0, "topk must be greater than 0");
TORCH_CHECK(total_num_tokens > 0, "total_num_tokens must be greater than 0");
TORCH_CHECK(num_experts > 0, "num_experts must be greater than 0");
// Create the output tensor
at::Tensor aux_loss = at::empty({}, at::dtype(probs.scalar_type()).device(at::kCUDA));
at::Tensor Const_buf = at::empty({}, at::dtype(at::kFloat).device(at::kCUDA));
auto probs_cu = makeTransformerEngineTensor(probs);
auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert);
auto aux_loss_cu = makeTransformerEngineTensor(aux_loss);
auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens,
num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(),
Const_buf_cu.data(), at::cuda::getCurrentCUDAStream());
return std::make_tuple(aux_loss, Const_buf);
}
at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows,
int num_cols, at::Tensor grad_aux_loss) {
// Create the output tensor
at::Tensor grad_probs =
at::empty({num_rows, num_cols}, at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA));
auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert);
auto grad_aux_loss_cu = makeTransformerEngineTensor(grad_aux_loss);
auto grad_probs_cu = makeTransformerEngineTensor(grad_probs);
// Meta data for the kernel
nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_rows,
num_cols, grad_aux_loss_cu.data(), grad_probs_cu.data(),
at::cuda::getCurrentCUDAStream());
return grad_probs;
}
} // namespace transformer_engine::pytorch
...@@ -4,80 +4,16 @@ ...@@ -4,80 +4,16 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <pybind.h>
#include <optional> #include <optional>
#include <vector>
#include "../extensions.h" #include "../extensions.h"
#include "pybind.h" #include "pybind.h"
namespace transformer_engine::pytorch { namespace transformer_engine {
namespace pytorch {
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list, DType otype) {
init_extension();
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
std::vector<py::object> py_output_objects_list;
std::vector<TensorWrapper> tensor_wrappers;
if (output_list.has_value()) {
py_output_objects_list = output_list.value();
}
// Choose implementation
// Note: Currently only have fused kernel for FP8 cast-transpose
bool with_fused_kernel = true;
// create TE tensors from input
for (size_t i = 0; i < input_list.size(); i++) {
auto input_tensor = makeTransformerEngineTensor(input_list[i]);
const NVTEShape input_shape = input_tensor.shape();
TensorWrapper output_tensor;
if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
with_fused_kernel = false;
}
if (output_list == std::nullopt) {
std::unique_ptr<Quantizer> quantizer = convert_quantizer(quantizer_list[i]);
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
py::object o;
std::tie(output_tensor, o) = quantizer->create_tensor(output_shape, otype);
py_output_objects_list.push_back(o);
} else {
output_tensor = makeTransformerEngineTensor((*output_list)[i], quantizer_list[i]);
}
if (input_tensor.numel() == 0) continue;
nvte_tensor_output_list.emplace_back(output_tensor.data());
nvte_tensor_input_list.emplace_back(input_tensor.data());
tensor_wrappers.emplace_back(std::move(input_tensor));
tensor_wrappers.emplace_back(std::move(output_tensor));
}
// Check tensor lists
NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(),
"Number of input and output tensors must match");
for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) {
with_fused_kernel = false;
break;
}
}
// Launch TE kernel
if (with_fused_kernel) {
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
});
} else {
for (size_t i = 0; i < py_output_objects_list.size(); i++) {
quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
}
}
return py_output_objects_list;
}
at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) { at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
init_extension(); init_extension();
...@@ -108,4 +44,5 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor ...@@ -108,4 +44,5 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor
return out; return out;
} }
} // namespace transformer_engine::pytorch } // namespace pytorch
} // namespace transformer_engine
...@@ -283,10 +283,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -283,10 +283,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const { const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
using namespace pybind11::literals; using namespace pybind11::literals;
std::vector<int64_t> torch_shape; std::vector<int64_t> torch_shape;
size_t numel = 1;
for (auto s : shape) { for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s)); torch_shape.emplace_back(static_cast<int64_t>(s));
numel *= s;
} }
TensorWrapper tensor(this->get_scaling_mode()); TensorWrapper tensor(this->get_scaling_mode());
...@@ -296,10 +294,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -296,10 +294,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
size_t m_dim = numel / k_dim;
size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
Float8BlockScaleTensorFormat data_format = Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY); : Float8BlockScaleTensorFormat::GEMM_READY);
...@@ -310,30 +304,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -310,30 +304,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
} else { } else {
data_rowwise = at::empty(torch_shape, opts); data_rowwise = at::empty(torch_shape, opts);
} }
size_t sinv0 = 0; auto scale_shape = get_scale_shape(shape, false);
size_t sinv1 = 0; size_t sinv0 = scale_shape[0];
if (block_scaling_dim == 2) { size_t sinv1 = scale_shape[1];
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
// default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
// if the rowwise format is compact, the scaling factor is not be transposed
if (rowwise_compact) {
std::swap(sinv0, sinv1);
}
} else {
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor rowwise. "
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_inv_rowwise = scale_inv_rowwise =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts); at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape);
...@@ -364,27 +337,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -364,27 +337,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
columnwise_shape = shape; columnwise_shape = shape;
} }
} }
size_t sinv0 = 0; auto scale_shape = get_scale_shape(shape, true);
size_t sinv1 = 0; size_t sinv0 = scale_shape[0];
if (block_scaling_dim == 2) { size_t sinv1 = scale_shape[1];
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
// GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
// for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
// so no need to swap sinv0 and sinv1 here
} else {
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor columnwise. "
"Expected 1 or 2. Got ",
block_scaling_dim);
}
data_colwise = at::empty(torch_columnwise_shape, opts); data_colwise = at::empty(torch_columnwise_shape, opts);
scale_inv_colwise = scale_inv_colwise =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts); at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
...@@ -418,6 +373,81 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -418,6 +373,81 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
return {std::move(tensor), std::move(ret)}; return {std::move(tensor), std::move(ret)};
} }
std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape,
bool columnwise) const {
size_t numel = 1;
for (auto s : shape) {
numel *= s;
}
size_t k_dim = shape.size() == 0 ? 1u : shape.back();
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY);
std::vector<size_t> scale_shape;
bool rowwise_usage = !columnwise;
if (rowwise_usage) {
// rowwise scaling factor shape
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
      // the default rowwise scaling factor shape already transposes the scaling factor, so it is GEMM_READY
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
      // if the rowwise format is COMPACT, the scaling factor is not transposed
if (rowwise_compact) {
std::swap(sinv0, sinv1);
}
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor rowwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_shape = {sinv0, sinv1};
} else {
// columnwise scaling factor shape
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
      // GEMM_READY case: the scaling factor is [sinv0, sinv1], already transposed here for cuBLAS.
      // COMPACT case: 128x1 scaling is applied without transposing the columnwise data, so the
      // scaling factor is also [sinv0, sinv1] and there is no need to swap sinv0 and sinv1 here.
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor columnwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_shape = {sinv0, sinv1};
}
return scale_shape;
}
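As a cross-check, the scale-inverse shapes produced above can be restated in a few lines of Python (a sketch only; the C++ member function above is authoritative, with block_len mirroring the hard-coded kBlockLen = 128). For a (256, 1024) tensor with 1D GEMM_READY scaling this gives a rowwise scale_inv of shape (8, 256) and a columnwise scale_inv of shape (2, 1024):

def fp8_block_scale_shape(m_dim, k_dim, block_scaling_dim=1, compact=False,
                          columnwise=False, block_len=128):
    ceil_div = lambda a, b: (a + b - 1) // b
    roundup = lambda a, b: ceil_div(a, b) * b
    if block_scaling_dim == 2:  # 2D scaling is always GEMM_READY
        if not columnwise:
            return ceil_div(m_dim, block_len), roundup(ceil_div(k_dim, block_len), 4)
        return ceil_div(k_dim, block_len), roundup(ceil_div(m_dim, block_len), 4)
    # 1D scaling: GEMM_READY pads the non-blocked dim to a multiple of 4,
    # COMPACT keeps it as-is (and the rowwise factor is left untransposed).
    if not columnwise:
        s0, s1 = ceil_div(k_dim, block_len), (m_dim if compact else roundup(m_dim, 4))
        return (s1, s0) if compact else (s0, s1)
    return ceil_div(m_dim, block_len), (k_dim if compact else roundup(k_dim, 4))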
MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>(); this->dtype = quantizer.attr("dtype").cast<DType>();
} }
...@@ -450,11 +480,6 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor( ...@@ -450,11 +480,6 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv,
columnwise_scale_inv; // TODO(pgadzinski) - change columnwise_scale_inv; // TODO(pgadzinski) - change
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
auto last_dim = static_cast<size_t>(torch_shape.back());
NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
"MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
" (got shape=", torch_shape, ")");
at::Tensor data; at::Tensor data;
if (rowwise_usage) { if (rowwise_usage) {
...@@ -463,9 +488,10 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor( ...@@ -463,9 +488,10 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
} else { } else {
data = at::empty(torch_shape, opts); data = at::empty(torch_shape, opts);
} }
auto sinv0 = roundup(numel / last_dim, 128); auto scale_shape = get_scale_shape(shape, false);
auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); size_t sinv0 = scale_shape[0];
rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); size_t sinv1 = scale_shape[1];
rowwise_scale_inv = at::zeros({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, opts);
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv( tensor.set_rowwise_scale_inv(
rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
...@@ -473,10 +499,12 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor( ...@@ -473,10 +499,12 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
} }
if (columnwise_usage) { if (columnwise_usage) {
auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); auto scale_shape = get_scale_shape(shape, true);
auto sinv1 = roundup(last_dim, 128); size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
columnwise_data = at::empty(torch_shape, opts); columnwise_data = at::empty(torch_shape, opts);
columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); columnwise_scale_inv =
at::zeros({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, opts);
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
tensor.set_columnwise_scale_inv( tensor.set_columnwise_scale_inv(
...@@ -504,4 +532,35 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor( ...@@ -504,4 +532,35 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
return {std::move(tensor), std::move(ret)}; return {std::move(tensor), std::move(ret)};
} }
std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& shape,
bool columnwise) const {
size_t numel = 1;
for (auto s : shape) {
numel *= s;
}
auto last_dim = shape.back();
NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
"MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
" (got shape=", shape, ")");
std::vector<size_t> scale_shape;
bool rowwise_usage = !columnwise;
if (rowwise_usage) {
// rowwise scaling factor shape
size_t sinv0 = roundup(numel / last_dim, 128);
size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
scale_shape = {sinv0, sinv1};
} else {
// columnwise scaling factor shape
size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
size_t sinv1 = roundup(last_dim, 128);
scale_shape = {sinv0, sinv1};
}
return scale_shape;
}
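The MXFP8 rule admits the same kind of cross-check (sketch only; MXFP8_BLOCK_SIZE is assumed to be 32 here). A (256, 1024) tensor would get a rowwise scale_inv of shape (256, 32) and a columnwise scale_inv of shape (8, 1024):

def mxfp8_scale_shape(shape, columnwise, block_size=32):
    numel = 1
    for s in shape:
        numel *= s
    last_dim = shape[-1]
    assert last_dim % block_size == 0 and (numel // last_dim) % block_size == 0
    roundup = lambda a, b: ((a + b - 1) // b) * b
    if not columnwise:
        return roundup(numel // last_dim, 128), roundup(last_dim // block_size, 4)
    return roundup(numel // (last_dim * block_size), 4), roundup(last_dim, 128)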
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Export utilities for TransformerEngine"""
from contextlib import contextmanager
from typing import Generator
import torch
_IN_ONNX_EXPORT_MODE = False
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
@contextmanager
def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
"""
Context manager for exporting to ONNX.
.. code-block:: python
from transformer_engine.pytorch.export import onnx_export, te_translation_table
with onnx_export(enabled=True):
torch.onnx.export(model, dynamo=True, custom_translation_table=te_translation_table)
Parameters
----------
enabled: bool, default = `False`
whether or not to enable export
"""
global _IN_ONNX_EXPORT_MODE
onnx_export_state = _IN_ONNX_EXPORT_MODE
if (TORCH_MAJOR, TORCH_MINOR) < (2, 4):
raise RuntimeError("ONNX export is not supported for PyTorch versions less than 2.4")
try:
_IN_ONNX_EXPORT_MODE = enabled
yield
finally:
_IN_ONNX_EXPORT_MODE = onnx_export_state
def is_in_onnx_export_mode() -> bool:
"""Returns True if onnx export mode is enabled, False otherwise."""
return _IN_ONNX_EXPORT_MODE
def assert_warmed_up(module: torch.nn.Module) -> None:
"""Assert that the model has been warmed up before exporting to ONNX."""
assert hasattr(module, "forwarded_at_least_once"), (
"Model must be warmed up before exporting to ONNX, please run model with the"
" same recipe before exporting."
)
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4 or TORCH_MAJOR > 2:
# pylint: disable=unused-import
from .onnx_extensions import (
torch_onnx_gemm_inf_op,
onnx_quantize_fp8_op,
onnx_dequantize_fp8_op,
onnx_quantize_mxfp8_op,
onnx_dequantize_mxfp8_op,
onnx_layernorm,
onnx_attention_mask_func,
onnx_gemm,
te_translation_table,
)
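# A typical export flow using the helpers above might look like the sketch
# below (module names follow the onnx_export docstring; the model must have
# run at least one forward pass under the same recipe so that assert_warmed_up
# passes):
#
#     import torch
#     from transformer_engine.pytorch.export import (
#         onnx_export, assert_warmed_up, te_translation_table,
#     )
#
#     model(sample_input)                      # warm-up forward pass
#     assert_warmed_up(model)
#     with onnx_export(enabled=True):
#         torch.onnx.export(
#             model, (sample_input,), dynamo=True,
#             custom_translation_table=te_translation_table,
#         )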
...@@ -56,6 +56,8 @@ def check_fp8_support() -> Tuple[bool, str]: ...@@ -56,6 +56,8 @@ def check_fp8_support() -> Tuple[bool, str]:
def check_mxfp8_support() -> Tuple[bool, str]: def check_mxfp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available""" """Return if fp8 support is available"""
if get_device_compute_capability() >= (12, 0):
return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
if get_device_compute_capability() >= (10, 0): # blackwell and above if get_device_compute_capability() >= (10, 0): # blackwell and above
return True, "" return True, ""
return False, "Device compute capability 10.0 or higher required for MXFP8 execution." return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
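# For illustration, the check above resolves as follows (a sketch, not part of
# the library API):
#     ok, reason = check_mxfp8_support()
#     # compute capability 12.0+ : ok is False (all-layout MXFP8 GEMMs not ready yet)
#     # compute capability 10.x  : ok is True
#     # older devices            : ok is False (Blackwell or newer required)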
...@@ -79,7 +81,11 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ...@@ -79,7 +81,11 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
def get_default_fp8_recipe() -> Recipe: def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args.""" """FP8 recipe with default args."""
if get_device_compute_capability() >= (10, 0): # blackwell and above if check_mxfp8_support()[0]:
# This is a temporary restriction until MXFP8 is supported for all
# gemm layouts.
if get_device_compute_capability() >= (12, 0):
return Float8BlockScaling()
return MXFP8BlockScaling() return MXFP8BlockScaling()
return DelayedScaling() return DelayedScaling()
......
...@@ -21,6 +21,7 @@ from .fp8 import ( ...@@ -21,6 +21,7 @@ from .fp8 import (
from .distributed import get_all_rng_states, graph_safe_rng_available from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule from .module.base import TransformerEngineBaseModule
from .ops.op import BasicOperation from .ops.op import BasicOperation
from .utils import make_weak_ref
__all__ = ["make_graphed_callables"] __all__ = ["make_graphed_callables"]
...@@ -63,8 +64,10 @@ def _make_graphed_callables( ...@@ -63,8 +64,10 @@ def _make_graphed_callables(
fp8_weight_caching: bool = False, fp8_weight_caching: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
_order: Optional[List[int]] = None, _order: Optional[List[int]] = None,
_num_layers_per_chunk: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None, pool: Optional[Tuple[int, ...]] = None,
retain_graph_in_backward: bool = False, retain_graph_in_backward: bool = False,
_reuse_graph_input_output_buffers: bool = False,
) -> SingleOrTuple[Callable]: ) -> SingleOrTuple[Callable]:
""" """
Helper method for `make_graphed_callables` Helper method for `make_graphed_callables`
...@@ -110,29 +113,113 @@ def _make_graphed_callables( ...@@ -110,29 +113,113 @@ def _make_graphed_callables(
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py. # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py.
# Note: The model is assumed to consist of layers # Note: The model is assumed to consist of layers
# (corresponding to callables) that are grouped into # (corresponding to callables) that are grouped into
# equally-sized model chunks. _order is a list of chunk # model chunks. _num_layers_per_chunk is a list of integers
# indices (1-indexed) that indicates the order in which the # that indicates the number of layers in each model chunk.
# layers are evaluated. Positive values indicate forward # _order is a list of chunk indices (1-indexed) that
# passes and negative values indicate backward passes. Each # indicates the order in which the layers are evaluated.
# Positive values indicate forward passes and negative
# values indicate backward passes. Each
# entry in sample_args corresponds to one of the forward # entry in sample_args corresponds to one of the forward
# passes. # passes.
num_model_chunks = max(_order) num_model_chunks = max(_order)
num_microbatches = len(_order) // num_model_chunks // 2 num_microbatches = len(_order) // num_model_chunks // 2
assert num_model_chunks * num_microbatches * 2 == len(_order) assert num_model_chunks * num_microbatches * 2 == len(_order)
# Determine number of layers in each model chunk.
if _num_layers_per_chunk is None:
assert len(sample_args) * 2 >= len(_order) and ( assert len(sample_args) * 2 >= len(_order) and (
len(sample_args) * 2 % len(_order) == 0 len(sample_args) * 2 % len(_order) == 0
), f"{len(sample_args)} >= {len(_order)} and {len(sample_args)} % {len(_order)} == 0" ), (
f"{len(sample_args)} * 2 >= {len(_order)} and {len(sample_args)} * 2 %"
f" {len(_order)} == 0"
)
num_layers = len(sample_args) // num_model_chunks // num_microbatches num_layers = len(sample_args) // num_model_chunks // num_microbatches
assert len(callables) == num_model_chunks * num_layers, ( _num_layers_per_chunk = [num_layers] * num_model_chunks
f"Callables should have ({num_model_chunks * num_layers}) " else:
assert (
isinstance(_num_layers_per_chunk, int)
or len(_num_layers_per_chunk) == num_model_chunks
), (
"If _num_layers_per_chunk is provided, it must be an integer or a list of"
f" {num_model_chunks} integers, but got {_num_layers_per_chunk}."
)
if isinstance(_num_layers_per_chunk, int):
_num_layers_per_chunk = [_num_layers_per_chunk] * num_model_chunks
total_num_layers = sum(_num_layers_per_chunk)
assert len(callables) == total_num_layers, (
f"Callables should have ({total_num_layers}) "
+ f"entries when order input is provided but got {len(callables)}." + f"entries when order input is provided but got {len(callables)}."
) )
assert len(sample_args) == num_model_chunks * num_microbatches * num_layers, ( assert len(sample_args) == total_num_layers * num_microbatches, (
f"Expected {num_model_chunks * num_microbatches}" f"Expected {total_num_layers * num_microbatches}"
+ f"args tuple, but got {len(sample_args)}." + f"args tuple, but got {len(sample_args)}."
) )
# Calculate the starting index of each chunk in callables for future use.
_prefix_num_layers = [0]
for m_chunk in range(num_model_chunks):
num_layers = _num_layers_per_chunk[m_chunk]
_prefix_num_layers.append(_prefix_num_layers[-1] + num_layers)
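            # Worked example (illustrative numbers): with num_microbatches = 2 and
            # _num_layers_per_chunk = [3, 2], the prefix sums are [0, 3, 5], so chunk 0
            # owns callables[0:3] and chunk 1 owns callables[3:5]. The per-callable index
            # of layer l of chunk c in microbatch m is then
            #     _prefix_num_layers[c] * num_microbatches + m * _num_layers_per_chunk[c] + l
            # e.g. chunk 1, microbatch 0, layer 0 maps to 3 * 2 + 0 * 2 + 0 = 6, for an
            # interleaved schedule such as _order = [1, 2, -2, -1, 1, 2, -2, -1].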
assert len(sample_kwargs) == len(sample_args) assert len(sample_kwargs) == len(sample_args)
# Check reuse graph conditions and reorganize sample_args and sample_kwargs.
# Note: When capturing a graph, we hold onto the args and kwargs so we have static buffers
# when the graph is replayed. If two model chunk microbatches have no overlap between their
# forward and backward, then we can reduce memory usage by reusing the same static buffers.
if _reuse_graph_input_output_buffers:
assert (
_order is not None
), "`_order` must be provided when `_reuse_graph_input_output_buffers` is True."
assert (
is_training
), "`_reuse_graph_input_output_buffers` is only available in training mode."
assert isinstance(
sample_args, list
), "sample_args must be a list for _reuse_graph_input_output_buffers."
len_args = len(sample_args[0])
for i, arg in enumerate(sample_args):
assert len_args == len(
arg
), "Arguments must have same length and shape for `_reuse_graph_input_output_buffers`."
len_kwargs = len(sample_kwargs[0])
assert isinstance(
sample_kwargs, list
), "sample_kwargs must be a list for _reuse_graph_input_output_buffers."
for i, kwarg in enumerate(sample_kwargs):
assert len_kwargs == len(kwarg), (
"Keyword arguments must have same length and shape for"
" `_reuse_graph_input_output_buffers`."
)
# Reorganize args and kwargs for input tensor reuse.
fwd_sample_qs = {}
consumed_sample_q = []
fwd_idx = [0] * num_model_chunks
for c_id in _order:
m_chunk = abs(c_id) - 1
if c_id > 0:
sample_start_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk]
)
fwd_sample_idx = [
sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk])
]
fwd_sample_qs[m_chunk] = fwd_sample_qs.get(m_chunk, []) + fwd_sample_idx
for per_callable_fwd_idx in fwd_sample_idx:
if consumed_sample_q:
reuse_fwd_idx = consumed_sample_q.pop(0)
sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx]
sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx]
fwd_idx[m_chunk] += 1
else:
num_consumed_samples = min(
len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk]
)
consumed_sample_q += fwd_sample_qs[m_chunk][:num_consumed_samples]
fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:]
if fp8_weight_caching: if fp8_weight_caching:
# Initialize flag that controls FP8 weight updates # Initialize flag that controls FP8 weight updates
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
...@@ -185,10 +272,13 @@ def _make_graphed_callables( ...@@ -185,10 +272,13 @@ def _make_graphed_callables(
per_callable_module_params = [] per_callable_module_params = []
for m_chunk in range(num_model_chunks): for m_chunk in range(num_model_chunks):
for _ in range(num_microbatches): for _ in range(num_microbatches):
for l_no in range(num_layers): for l_no in range(_num_layers_per_chunk[m_chunk]):
per_callable_module_params.append( per_callable_module_params.append(
tuple(callables[m_chunk * num_layers + l_no].parameters()) tuple(callables[_prefix_num_layers[m_chunk] + l_no].parameters())
if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module) if isinstance(
callables[_prefix_num_layers[m_chunk] + l_no],
torch.nn.Module,
)
else () else ()
) )
assert len(per_callable_module_params) == len(flatten_sample_args) assert len(per_callable_module_params) == len(flatten_sample_args)
...@@ -227,10 +317,10 @@ def _make_graphed_callables( ...@@ -227,10 +317,10 @@ def _make_graphed_callables(
for c_id in _order: for c_id in _order:
if c_id > 0: if c_id > 0:
m_chunk = c_id - 1 m_chunk = c_id - 1
for l_no in range(num_layers): for l_no in range(_num_layers_per_chunk[m_chunk]):
func = callables[m_chunk * num_layers + l_no] func = callables[_prefix_num_layers[m_chunk] + l_no]
func_idx = (m_chunk * num_microbatches * num_layers) + ( func_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
fwd_idx[m_chunk] * num_layers + l_no fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
) )
warmup_func_idx.append(func_idx) warmup_func_idx.append(func_idx)
warmup_func.append(func) warmup_func.append(func)
...@@ -255,7 +345,7 @@ def _make_graphed_callables( ...@@ -255,7 +345,7 @@ def _make_graphed_callables(
args = sample_args[func_idx] args = sample_args[func_idx]
kwargs = sample_kwargs[func_idx] kwargs = sample_kwargs[func_idx]
static_input_surface = per_callable_static_input_surfaces[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx]
for _ in range(num_warmup_iters): for warmup_iter in range(num_warmup_iters):
hooks = [] hooks = []
for module in func.modules(): for module in func.modules():
hook = module.register_forward_hook(hook_fn) hook = module.register_forward_hook(hook_fn)
...@@ -271,6 +361,34 @@ def _make_graphed_callables( ...@@ -271,6 +361,34 @@ def _make_graphed_callables(
only_inputs=True, only_inputs=True,
allow_unused=allow_unused_input, allow_unused=allow_unused_input,
) )
                    # Filter out module params that receive a None grad in grad_inputs and
                    # remove them from static_input_surface. This ensures that the backward
                    # hooks registered on these params are not spuriously triggered.
num_required_grad_sample_args = sum(
arg.requires_grad for arg in flatten_sample_args[func_idx]
)
required_grad_input_idx = []
for i, arg in enumerate(static_input_surface):
if arg.requires_grad:
required_grad_input_idx.append(i)
module_params_with_grad = []
for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
if (
grad_inputs[grad_inputs_idx] is not None
and grad_inputs_idx >= num_required_grad_sample_args
):
module_params_with_grad.append(static_input_surface[inputs_idx])
if len(module_params_with_grad) != len(per_callable_module_params[func_idx]):
assert warmup_iter == 0, (
"no-grad params should only be used as inputs in the first warmup"
" iteration"
)
per_callable_module_params[func_idx] = tuple(module_params_with_grad)
static_input_surface = flatten_sample_args[func_idx] + tuple(
module_params_with_grad
)
per_callable_static_input_surfaces[func_idx] = static_input_surface
else: else:
grad_inputs = None grad_inputs = None
del outputs, grad_inputs del outputs, grad_inputs
...@@ -292,14 +410,16 @@ def _make_graphed_callables( ...@@ -292,14 +410,16 @@ def _make_graphed_callables(
per_callable_static_grad_inputs = [None] * len(flatten_sample_args) per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
fwd_idx = [0] * num_model_chunks fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks
static_grad_outputs = None
previous_per_callable_bwd_idx = None
for c_id in _order: for c_id in _order:
if c_id > 0: if c_id > 0:
# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
m_chunk = c_id - 1 m_chunk = c_id - 1
for l_no in range(num_layers): for l_no in range(_num_layers_per_chunk[m_chunk]):
func = callables[m_chunk * num_layers + l_no] func = callables[_prefix_num_layers[m_chunk] + l_no]
per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + ( per_callable_fwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
fwd_idx[m_chunk] * num_layers + l_no fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
) )
args = sample_args[per_callable_fwd_idx] args = sample_args[per_callable_fwd_idx]
kwargs = sample_kwargs[per_callable_fwd_idx] kwargs = sample_kwargs[per_callable_fwd_idx]
...@@ -314,14 +434,17 @@ def _make_graphed_callables( ...@@ -314,14 +434,17 @@ def _make_graphed_callables(
else: else:
# Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
m_chunk = -c_id - 1 m_chunk = -c_id - 1
for l_no in list(reversed(range(num_layers))): for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) + ( per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
bwd_idx[m_chunk] * num_layers + l_no bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
) )
static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx] static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]
static_outputs = per_callable_static_outputs[per_callable_bwd_idx] static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx] bwd_graph = bwd_graphs[per_callable_bwd_idx]
# For now, assumes all static_outputs require grad # For now, assumes all static_outputs require grad
if not _reuse_graph_input_output_buffers or static_grad_outputs is None:
# Note for _reuse_graph_input_output_buffers: grad output is only used
# within backward, so we can reuse the same static buffers every time.
static_grad_outputs = tuple( static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs torch.empty_like(o) if o.requires_grad else None for o in static_outputs
) )
...@@ -350,6 +473,30 @@ def _make_graphed_callables( ...@@ -350,6 +473,30 @@ def _make_graphed_callables(
per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
# Weak ref the static outputs and static grad inputs that are no longer needed
                # in the following steps. These two types of tensors are both in the cudagraph
# mempool, so we just deallocate them and let PyTorch's memory allocator
# reuse them elsewhere.
if _reuse_graph_input_output_buffers:
# Weak ref the static outputs of the forward pass of this backward. It's
# no longer needed after the corresponding backward graph is built up.
per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref(
static_outputs
)
# Weak ref the static grad inputs of the previous backward pass.
# Note: After a backward pass, we assume Mcore will send the
# grad input to another pipeline parallel rank and that the
# communication is finished before the end of the next backward
# pass.
if previous_per_callable_bwd_idx is not None:
per_callable_static_grad_inputs[previous_per_callable_bwd_idx] = (
make_weak_ref(
per_callable_static_grad_inputs[previous_per_callable_bwd_idx]
)
)
previous_per_callable_bwd_idx = per_callable_bwd_idx
bwd_idx[m_chunk] += 1 bwd_idx[m_chunk] += 1
else: else:
# Capture forward graphs # Capture forward graphs
...@@ -593,7 +740,7 @@ def save_fp8_tensors( ...@@ -593,7 +740,7 @@ def save_fp8_tensors(
m.adjust_amax_history_length(fp8_recipe.amax_history_len) m.adjust_amax_history_length(fp8_recipe.amax_history_len)
module_tensors = m.get_fp8_meta_tensors() module_tensors = m.get_fp8_meta_tensors()
elif isinstance(m, BasicOperation): elif isinstance(m, BasicOperation):
m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe) m.pre_first_forward(recipe=fp8_recipe)
module_tensors = m._save_fp8_metas() module_tensors = m._save_fp8_metas()
fp8_tensors.append(module_tensors) fp8_tensors.append(module_tensors)
return fp8_tensors return fp8_tensors
...@@ -634,8 +781,10 @@ def make_graphed_callables( ...@@ -634,8 +781,10 @@ def make_graphed_callables(
fp8_group: Optional[dist_group_type] = None, fp8_group: Optional[dist_group_type] = None,
fp8_weight_caching: bool = False, fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None, _order: Optional[List[int]] = None,
_num_layers_per_chunk: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None, pool: Optional[Tuple[int, ...]] = None,
retain_graph_in_backward: bool = False, retain_graph_in_backward: bool = False,
_reuse_graph_input_output_buffers: bool = False,
) -> Union[Callable, Tuple[Callable, ...]]: ) -> Union[Callable, Tuple[Callable, ...]]:
""" """
Make CUDA graph version of Transformer Engine modules Make CUDA graph version of Transformer Engine modules
...@@ -664,6 +813,11 @@ def make_graphed_callables( ...@@ -664,6 +813,11 @@ def make_graphed_callables(
this graph may share memory with the indicated pool. this graph may share memory with the indicated pool.
retain_graph_in_backward: bool, default = `False` retain_graph_in_backward: bool, default = `False`
Whether to set retain_graph=True in backward graph capture. Whether to set retain_graph=True in backward graph capture.
_reuse_graph_input_output_buffers: bool, default = `False`
Reduce memory usage by reusing input/output data buffers between
graphs. Only supported with Mcore interleaved pipeline parallelism, i.e.
when `_order` is provided. All callables in `modules` are assumed to have
inputs and outputs with the same dtype and shape.
FP8-related parameters FP8-related parameters
---------------------- ----------------------
...@@ -702,10 +856,17 @@ def make_graphed_callables( ...@@ -702,10 +856,17 @@ def make_graphed_callables(
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
# FP8 wrapper. # FP8 wrapper.
old_call_funcs = {}
def wrap_autocast(block): def wrap_autocast(block):
old_forward = block.forward block_cls = type(block)
if block_cls in old_call_funcs:
return
def forward_func(*args, **kwargs): old_call_funcs[block_cls] = block_cls.__call__
# Wrap the original call function of the module class.
def call_func(*args, **kwargs):
with fp8_autocast( with fp8_autocast(
enabled=fp8_enabled, enabled=fp8_enabled,
calibrating=fp8_calibrating, calibrating=fp8_calibrating,
...@@ -713,10 +874,10 @@ def make_graphed_callables( ...@@ -713,10 +874,10 @@ def make_graphed_callables(
fp8_group=fp8_group, fp8_group=fp8_group,
_graph=True, _graph=True,
): ):
outputs = old_forward(*args, **kwargs) outputs = old_call_funcs[block_cls](*args, **kwargs)
return outputs return outputs
block.forward = forward_func block_cls.__call__ = call_func
forward_funcs = [] forward_funcs = []
for module in modules: for module in modules:
...@@ -747,8 +908,10 @@ def make_graphed_callables( ...@@ -747,8 +908,10 @@ def make_graphed_callables(
fp8_weight_caching=fp8_weight_caching, fp8_weight_caching=fp8_weight_caching,
sample_kwargs=sample_kwargs, sample_kwargs=sample_kwargs,
_order=_order, _order=_order,
_num_layers_per_chunk=_num_layers_per_chunk,
pool=pool, pool=pool,
retain_graph_in_backward=retain_graph_in_backward, retain_graph_in_backward=retain_graph_in_backward,
_reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers,
) )
# Ensures warmup does not affect numerics for ops such as dropout. # Ensures warmup does not affect numerics for ops such as dropout.
...@@ -758,6 +921,10 @@ def make_graphed_callables( ...@@ -758,6 +921,10 @@ def make_graphed_callables(
else: else:
torch.cuda.set_rng_state(original_rng_states) torch.cuda.set_rng_state(original_rng_states)
# Remove FP8 wrapper.
for module_cls, old_call in old_call_funcs.items():
module_cls.__call__ = old_call
# Restore FP8 state. # Restore FP8 state.
restore_fp8_tensors(modules, saved_fp8_tensors) restore_fp8_tensors(modules, saved_fp8_tensors)
......
...@@ -6,10 +6,10 @@ ...@@ -6,10 +6,10 @@
import os import os
from functools import wraps from functools import wraps
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
import torch import torch
from . import torch_version from . import torch_version
from .export import is_in_onnx_export_mode
from .utils import gpu_autocast_ctx from .utils import gpu_autocast_ctx
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
...@@ -47,7 +47,17 @@ if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1" ...@@ -47,7 +47,17 @@ if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"
# Decorator to disable Torch Dynamo # Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive) no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
import torch._dynamo
if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: (
f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
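# Usage sketch (the helper name below is illustrative only): the factory
# returns a decorator, so it is applied with a trailing call. Under ONNX
# export mode (PyTorch >= 2.1) the wrapped function is returned unchanged so
# the exporter can trace it; otherwise TorchDynamo is disabled for it,
# recursively by default.
#
#     @no_torch_dynamo()
#     def _eager_only_helper(x):
#         return x + 1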
def set_jit_fusion_options() -> None: def set_jit_fusion_options() -> None:
......
...@@ -13,6 +13,7 @@ import torch ...@@ -13,6 +13,7 @@ import torch
from .. import cpp_extensions as tex from .. import cpp_extensions as tex
from ..constants import TE_DType from ..constants import TE_DType
from ..utils import get_default_init_method from ..utils import get_default_init_method
from ..export import is_in_onnx_export_mode
import warnings import warnings
try: try:
from lightop import rmsnorm_forward,rmsnorm_backward from lightop import rmsnorm_forward,rmsnorm_backward
...@@ -173,6 +174,8 @@ def noop_cat( ...@@ -173,6 +174,8 @@ def noop_cat(
raise ValueError("Attempted to concatenate 0 tensors") raise ValueError("Attempted to concatenate 0 tensors")
if len(tensors) == 1: if len(tensors) == 1:
return tensors[0] return tensors[0]
if is_in_onnx_export_mode():
return torch.cat(tensors, dim=dim)
return _NoopCatFunc.apply(dim, *tensors) return _NoopCatFunc.apply(dim, *tensors)
......
...@@ -1035,6 +1035,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1035,6 +1035,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
to setup the forward aggregated amax reduction for every module to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one. just in case. The autocast exit will pick up the most recent one.
""" """
self.forwarded_at_least_once = True
# Activation recomputation is used and this is the second forward phase. # Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
......
...@@ -53,15 +53,16 @@ class _Fp8Padding(torch.autograd.Function): ...@@ -53,15 +53,16 @@ class _Fp8Padding(torch.autograd.Function):
if ctx.requires_dgrad: if ctx.requires_dgrad:
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
grad_output_mats = torch.split( in_features = grad_output.shape[-1]
grad_output.view(-1, grad_output.shape[-1]), ctx.padded_m_splits
            # Allocate the output tensor for the unpadded gradient
total_row = sum(ctx.m_splits)
grad_input = torch.empty(
[total_row, in_features], dtype=grad_output.dtype, device=grad_output.device
) )
grad_input = torch.cat(
[ tex.fused_multi_row_unpadding(
grad_output_mat[: ctx.m_splits[i]] grad_output.view(-1, in_features), grad_input, ctx.padded_m_splits, ctx.m_splits
for i, grad_output_mat in enumerate(grad_output_mats)
],
dim=0,
) )
return (grad_input, None, None, None) return (grad_input, None, None, None)
...@@ -73,11 +74,12 @@ class Fp8Padding(torch.nn.Module): ...@@ -73,11 +74,12 @@ class Fp8Padding(torch.nn.Module):
Parameters Parameters
---------- ----------
num_gemms: int num_gemms : int
number of GEMMs to be performed simutaneously. number of GEMMs to be performed simultaneously.
align_size: int, optional align_size : int, optional
the alignment size for the input tensor. If not provided, the alignment size will the alignment size for the input tensor. If not provided, the alignment size will
be determined by the FP8 recipe, 32 for MXFP8 and 16 for others. be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first
forward pass.
""" """
def __init__( def __init__(
...@@ -88,9 +90,6 @@ class Fp8Padding(torch.nn.Module): ...@@ -88,9 +90,6 @@ class Fp8Padding(torch.nn.Module):
super().__init__() super().__init__()
self.num_gemms = num_gemms self.num_gemms = num_gemms
if align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
else:
self.align_size = align_size self.align_size = align_size
@no_torch_dynamo() @no_torch_dynamo()
...@@ -111,6 +110,8 @@ class Fp8Padding(torch.nn.Module): ...@@ -111,6 +110,8 @@ class Fp8Padding(torch.nn.Module):
""" """
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
# FP8 padding calculate # FP8 padding calculate
padded_m_splits = [ padded_m_splits = [
......
...@@ -29,10 +29,13 @@ class _Fp8Unpadding(torch.autograd.Function): ...@@ -29,10 +29,13 @@ class _Fp8Unpadding(torch.autograd.Function):
is_grad_enabled: bool, is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits) in_features = inp.shape[-1]
out_ret = torch.cat(
[grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0 # Allocate the unpadded output tensor
) total_row = sum(m_splits)
out_ret = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device)
tex.fused_multi_row_unpadding(inp.view(-1, in_features), out_ret, padded_m_splits, m_splits)
if is_grad_enabled: if is_grad_enabled:
ctx.m_splits = m_splits ctx.m_splits = m_splits
...@@ -69,11 +72,12 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -69,11 +72,12 @@ class Fp8Unpadding(torch.nn.Module):
Parameters Parameters
---------- ----------
num_gemms: int num_gemms : int
number of GEMMs to be performed simutaneously. number of GEMMs to be performed simultaneously.
align_size: int, optional align_size : int, optional
the alignment size for the input tensor. If not provided, the alignment size will the alignment size for the input tensor. If not provided, the alignment size will
be determined by the FP8 recipe, 32 for MXFP8 and 16 for others. be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first
forward pass.
""" """
def __init__( def __init__(
...@@ -84,9 +88,6 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -84,9 +88,6 @@ class Fp8Unpadding(torch.nn.Module):
super().__init__() super().__init__()
self.num_gemms = num_gemms self.num_gemms = num_gemms
if align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
else:
self.align_size = align_size self.align_size = align_size
@no_torch_dynamo() @no_torch_dynamo()
...@@ -107,6 +108,8 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -107,6 +108,8 @@ class Fp8Unpadding(torch.nn.Module):
""" """
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
# FP8 padding calculate # FP8 padding calculate
padded_m_splits = [ padded_m_splits = [
......
...@@ -25,7 +25,6 @@ from ..fp8 import FP8GlobalStateManager ...@@ -25,7 +25,6 @@ from ..fp8 import FP8GlobalStateManager
from ..utils import ( from ..utils import (
divide, divide,
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_exec,
clear_tensor_data, clear_tensor_data,
init_method_constant, init_method_constant,
requires_grad, requires_grad,
...@@ -39,11 +38,12 @@ from ..distributed import ( ...@@ -39,11 +38,12 @@ from ..distributed import (
from ..cpp_extensions import ( from ..cpp_extensions import (
general_grouped_gemm, general_grouped_gemm,
) )
from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensorBase, QuantizedTensorBase,
Quantizer, Quantizer,
...@@ -80,6 +80,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -80,6 +80,7 @@ class _GroupedLinear(torch.autograd.Function):
is_grad_enabled: bool, is_grad_enabled: bool,
module, module,
skip_fp8_weight_update, skip_fp8_weight_update,
save_original_input,
*weights_and_biases, *weights_and_biases,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -88,25 +89,18 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -88,25 +89,18 @@ class _GroupedLinear(torch.autograd.Function):
weights = weights_and_biases[:num_gemms] weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:] biases = weights_and_biases[num_gemms:]
device = inp.device device = inp.device
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmats = torch.split(inp.view(-1, in_features), m_splits)
if fp8:
assert_dim_for_fp8_exec(*inputmats, *weights)
# Cast input to expected dtype
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = []
weight_requires_grad = weights[0].requires_grad weight_requires_grad = weights[0].requires_grad
# Configure quantizers
if save_original_input and isinstance(input_quantizers[0], Float8Quantizer):
raise ValueError("DelayedScaling recipe is not supported with save_original_input")
if input_quantizers[0] is not None: if input_quantizers[0] is not None:
for input_quantizer in input_quantizers: for input_quantizer in input_quantizers:
input_quantizer.set_usage( input_quantizer.set_usage(
rowwise=True, rowwise=True,
columnwise=(is_grad_enabled and weight_requires_grad), columnwise=(
is_grad_enabled and weight_requires_grad and not save_original_input
),
) )
columnwise_usage = is_grad_enabled and inp.requires_grad columnwise_usage = is_grad_enabled and inp.requires_grad
if not columnwise_usage: if not columnwise_usage:
...@@ -121,17 +115,25 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -121,17 +115,25 @@ class _GroupedLinear(torch.autograd.Function):
for output_quantizer in output_quantizers: for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False) output_quantizer.set_usage(rowwise=True, columnwise=False)
fprop_gemm_use_split_accumulator = _2X_ACC_FPROP # Initialize input tensors
if fp8: in_features = weights[0].size(-1)
recipe = FP8GlobalStateManager.get_fp8_recipe() if inp.size(-1) != in_features:
if hasattr(recipe, "fp8_gemm_fprop"): raise ValueError(
fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator f"Input tensor (shape={tuple(inp.size())}) is not compatible with "
inputmats = tex.fused_multi_quantize( f"weight tensor (shape={tuple(weights[0].size())})"
inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
) )
weights_fp8 = [] inp_view = inp.reshape(-1, in_features)
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype inputmats: list
if fp8:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
else:
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
# Initialize weights
weights_fp8: list
if fp8:
# FP8 cast to workspace buffer # FP8 cast to workspace buffer
weights_fp8 = []
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
for i in range(num_gemms): for i in range(num_gemms):
weight_fp8 = module.get_weight_workspace( weight_fp8 = module.get_weight_workspace(
...@@ -144,18 +146,29 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -144,18 +146,29 @@ class _GroupedLinear(torch.autograd.Function):
weights_fp8.append(weight_fp8) weights_fp8.append(weight_fp8)
else: else:
inputmats = inputmats_no_fp8
bias_dtype = activation_dtype
weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights]
# Initialize biases
bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16 # FP8 GEMM only supports BF16/FP16 bias
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
# Initialize output tensor
out = torch.empty( out = torch.empty(
[sum(m_splits), weights_fp8[0].size(0)], [sum(m_splits), weights_fp8[0].size(0)],
dtype=activation_dtype, dtype=activation_dtype,
device=device, device=device,
) )
# Choose whether to use split accumulator
use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
# Perform GEMM
_ = general_grouped_gemm( _ = general_grouped_gemm(
weights_fp8, weights_fp8,
inputmats, inputmats,
...@@ -166,7 +179,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -166,7 +179,7 @@ class _GroupedLinear(torch.autograd.Function):
m_splits=m_splits, m_splits=m_splits,
bias=biases, bias=biases,
use_bias=use_bias, use_bias=use_bias,
use_split_accumulator=fprop_gemm_use_split_accumulator, use_split_accumulator=use_split_accumulator,
) )
if fp8_calibration: if fp8_calibration:
...@@ -183,9 +196,15 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -183,9 +196,15 @@ class _GroupedLinear(torch.autograd.Function):
# TODO: update after #1638 is merged. # pylint: disable=fixme # TODO: update after #1638 is merged. # pylint: disable=fixme
if weight_requires_grad: if weight_requires_grad:
if save_original_input:
inputmats = [None] * num_gemms
inputmats[0] = inp
else:
for inputmat in inputmats: for inputmat in inputmats:
if isinstance(inputmat, QuantizedTensorBase): if isinstance(inputmat, QuantizedTensorBase):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
inputmats = [None] * num_gemms
if inp.requires_grad: if inp.requires_grad:
for weight in weights_fp8: for weight in weights_fp8:
if isinstance(weight, QuantizedTensorBase): if isinstance(weight, QuantizedTensorBase):
...@@ -202,9 +221,18 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -202,9 +221,18 @@ class _GroupedLinear(torch.autograd.Function):
ctx.weights_requires_grad = weights[0].requires_grad ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad: if fuse_wgrad_accumulation and ctx.weights_requires_grad:
ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)] # This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
# the main_grad buffer lazily before backprop
if hasattr(weights[0], "__fsdp_param__"):
# MCore FSDP creates main_grad lazily before backward
ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)]
else:
ctx.main_grad_funcs = [
lambda j=i: weights[j].main_grad for i in range(num_gemms)
]
else: else:
ctx.main_grads = [None] * num_gemms ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)]
ctx.device = device ctx.device = device
ctx.grad_output_quantizers = grad_output_quantizers ctx.grad_output_quantizers = grad_output_quantizers
ctx.m_splits = m_splits ctx.m_splits = m_splits
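The per-GEMM `main_grad` accessors stored above are callables so the gradient buffer is only resolved during backward (letting MCore FSDP create it lazily); the `j=i` default argument freezes the loop index at the moment each lambda is defined. A small standalone sketch of that closure-binding detail (the `Param` class and values are hypothetical, only for illustration):

```python
# Deferred attribute access with correct closure binding.
class Param:
    def __init__(self, grad):
        self.main_grad = grad

params = [Param(g) for g in ("g0", "g1", "g2")]

# Wrong: every lambda captures the variable i, so all resolve to the last index.
late_bound = [lambda: params[i].main_grad for i in range(3)]

# Right: `j=i` evaluates i when the lambda is created, freezing the index.
early_bound = [lambda j=i: params[j].main_grad for i in range(3)]

assert [f() for f in late_bound] == ["g2", "g2", "g2"]
assert [f() for f in early_bound] == ["g0", "g1", "g2"]
```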
...@@ -226,6 +254,8 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -226,6 +254,8 @@ class _GroupedLinear(torch.autograd.Function):
or FP8GlobalStateManager.is_first_fp8_module() or FP8GlobalStateManager.is_first_fp8_module()
) )
ctx.wgrad_store = wgrad_store ctx.wgrad_store = wgrad_store
ctx.save_original_input = save_original_input
ctx.input_quantizers = input_quantizers
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1]) return out.view(-1, *inp.shape[1:-1], out.shape[-1])
...@@ -240,7 +270,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -240,7 +270,7 @@ class _GroupedLinear(torch.autograd.Function):
weights = saved_tensors[N : 2 * N] weights = saved_tensors[N : 2 * N]
origin_weights = saved_tensors[2 * N : 3 * N] origin_weights = saved_tensors[2 * N : 3 * N]
biases = saved_tensors[3 * N : 4 * N] biases = saved_tensors[3 * N : 4 * N]
main_grads = ctx.main_grads main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
...@@ -248,36 +278,44 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -248,36 +278,44 @@ class _GroupedLinear(torch.autograd.Function):
w.main_grad = main_grads[i] w.main_grad = main_grads[i]
weights[i] = w weights[i] = w
# preprocess grad_output # Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
grad_output = grad_output.contiguous()
grad_output_mats = torch.split(
grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits
)
grad_output = [None] * ctx.num_gemms grad_output = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms
if ctx.fp8: if ctx.fp8:
if ctx.use_bias: if ctx.use_bias:
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
# for Float8BlockQuantizer. recipe = ctx.fp8_recipe
if ctx.fp8_recipe.float8_block_scaling(): if recipe.delayed() or recipe.float8_current_scaling() or recipe.mxfp8():
# Fused bias grad + quantize kernel
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0) grad_biases[i], grad_output[i] = tex.bgrad_quantize(
grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i]) grad_output_mats[i],
ctx.grad_output_quantizers[i],
)
else: else:
# Unfused bias grad and multi-tensor quantize
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
grad_biases[i], grad_output[i] = tex.bgrad_quantize( grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output_mats[i], ctx.grad_output_quantizers[i] grad_output = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
) )
else: else:
grad_output = tex.fused_multi_quantize( # Multi-tensor quantize
grad_output_mats, grad_output = tex.split_quantize(
None, grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers, ctx.grad_output_quantizers,
TE_DType[ctx.activation_dtype],
) )
else: else:
grad_output = grad_output_mats # Only split grad output. Grad bias is fused with
# wgrad GEMM.
grad_output = torch.split(
cast_if_needed(grad_output_view, ctx.activation_dtype),
ctx.m_splits,
)
if ctx.is_first_microbatch is not None: if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = ( accumulate_wgrad_into_param_main_grad = (
...@@ -334,6 +372,27 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -334,6 +372,27 @@ class _GroupedLinear(torch.autograd.Function):
torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
for w in weights for w in weights
] ]
if ctx.save_original_input:
inp = inputmats[0]
in_features = inp.shape[-1]
inp_view = inp.reshape(-1, in_features)
if ctx.input_quantizers[0] is not None:
for input_quantizer in ctx.input_quantizers:
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
input_quantizer.set_usage(rowwise=True, columnwise=True)
else:
input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmats: list
if ctx.fp8:
inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
else:
inputmats = torch.split(
cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits
)
grouped_gemm_wgrad = functools.partial( grouped_gemm_wgrad = functools.partial(
general_grouped_gemm, general_grouped_gemm,
out_dtype=ctx.activation_dtype, out_dtype=ctx.activation_dtype,
...@@ -429,6 +488,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -429,6 +488,7 @@ class _GroupedLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
*wgrad_list, *wgrad_list,
*grad_biases, *grad_biases,
) )
...@@ -479,6 +539,11 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -479,6 +539,11 @@ class GroupedLinear(TransformerEngineBaseModule):
would not fit in GPU memory. would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False` delay_wgrad_compute : bool, default = `False`
Whether to delay weight gradient computation Whether to delay weight gradient computation
save_original_input : bool, default = `False`
If set to `True`, the original input tensor is saved for backward instead of the
cast tensor. When the input tensor is shared by multiple modules, saving the
original tensor can reduce memory usage.
Not compatible with the FP8 DelayedScaling recipe.
Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and
`parallel_mode` are used to determine the shapes of weights and biases. `parallel_mode` are used to determine the shapes of weights and biases.
...@@ -506,6 +571,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -506,6 +571,7 @@ class GroupedLinear(TransformerEngineBaseModule):
ub_overlap_ag: bool = False, ub_overlap_ag: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
save_original_input: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -520,6 +586,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -520,6 +586,7 @@ class GroupedLinear(TransformerEngineBaseModule):
self.ub_overlap_rs = ub_overlap_rs self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag self.ub_overlap_ag = ub_overlap_ag
self.ub_name = ub_name self.ub_name = ub_name
self.save_original_input = save_original_input
assert ( assert (
not ub_overlap_rs and not ub_overlap_ag not ub_overlap_rs and not ub_overlap_ag
), "GroupedLinear doesn't support Userbuffer overlap." ), "GroupedLinear doesn't support Userbuffer overlap."
...@@ -735,6 +802,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -735,6 +802,7 @@ class GroupedLinear(TransformerEngineBaseModule):
torch.is_grad_enabled(), torch.is_grad_enabled(),
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
self.save_original_input,
*weight_tensors, *weight_tensors,
*bias_tensors, *bias_tensors,
) )
......
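For reference, a hedged usage sketch of the new `save_original_input` flag on `GroupedLinear`; the feature sizes, the row splits, and the plain BF16-free setup below are illustrative assumptions, not taken from this diff:

```python
import torch
import transformer_engine.pytorch as te

num_gemms = 4
layer = te.GroupedLinear(
    num_gemms,
    1024,                        # in_features
    1024,                        # out_features
    save_original_input=True,    # keep the uncast input for the wgrad GEMM
).cuda()

x = torch.randn(4096, 1024, device="cuda", requires_grad=True)
m_splits = [1024] * num_gemms    # one row range per GEMM, summing to x.shape[0]
y = layer(x, m_splits)
y.sum().backward()
```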
...@@ -68,6 +68,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize ...@@ -68,6 +68,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import ( from ..cpp_extensions import (
...@@ -454,7 +455,14 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -454,7 +455,14 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.requires_wgrad = weight.requires_grad ctx.requires_wgrad = weight.requires_grad
ctx.quantized_weight = quantized_weight ctx.quantized_weight = quantized_weight
if fuse_wgrad_accumulation and weight.requires_grad: if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad # This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
# the main_grad buffer lazily before backprop
if hasattr(weight, "__fsdp_param__"):
# MCore FSDP creates main_grad lazily before backward
ctx.main_grad_func = weight.get_main_grad
else:
ctx.main_grad_func = lambda: weight.main_grad
ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_output_quantizer = grad_output_quantizer
...@@ -500,7 +508,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -500,7 +508,7 @@ class _LayerNormLinear(torch.autograd.Function):
if return_layernorm_output: if return_layernorm_output:
if return_layernorm_output_gathered: if return_layernorm_output_gathered:
shape = list(inp.shape) shape = list(inp_shape)
shape[0] *= tp_size if with_input_all_gather else 1 shape[0] *= tp_size if with_input_all_gather else 1
return out, ln_out_return.view(shape) return out, ln_out_return.view(shape)
return out, ln_out_return.view(inp_shape) return out, ln_out_return.view(inp_shape)
...@@ -535,7 +543,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -535,7 +543,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Since main_grad can be modified inplace, it should not be a part of saved_tensors # Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = ( main_grad = (
ctx.main_grad ctx.main_grad_func()
if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
else None else None
) )
...@@ -1470,6 +1478,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1470,6 +1478,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output)
debug = TEDebugState.debug_enabled debug = TEDebugState.debug_enabled
if debug: if debug:
self._validate_name() self._validate_name()
...@@ -1493,12 +1503,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1493,12 +1503,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
) as inp: ) as inp:
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors() weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
quantizers = ( quantizers = (
self._get_quantizers(fp8_output, fp8_grad) self._get_quantizers(fp8_output, fp8_grad)
...@@ -1628,6 +1633,72 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1628,6 +1633,72 @@ class LayerNormLinear(TransformerEngineBaseModule):
for name, q in zip(names, original_quantizers) for name, q in zip(names, original_quantizers)
) )
def _get_weight_and_bias_tensors(self):
# Get concatenated weight and bias tensors
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
return weight_tensor, bias_tensor
def onnx_forward(
self,
inp: torch.Tensor,
fp8_output: bool,
) -> torch.Tensor:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
This simplified implementation is designed specifically for inference scenarios.
"""
from ..export import onnx_layernorm, onnx_gemm
assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
assert_warmed_up(self)
(
input_quantizer,
weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(fp8_output, fp8_grad=False)
inp_dtype = inp.dtype
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
ln_out, ln_out_return = onnx_layernorm(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self.eps,
self.normalization,
self.zero_centered_gamma,
inp_dtype,
self.return_layernorm_output,
input_quantizer,
)
if weight_quantizer is not None:
weight_tensor_quantized = weight_quantizer.onnx_quantize(weight_tensor)
weight_tensor = weight_quantizer.onnx_dequantize(weight_tensor_quantized)
weight_tensor = weight_tensor.to(inp_dtype)
if bias_tensor is not None:
bias_tensor = bias_tensor.to(inp_dtype)
output = onnx_gemm(weight_tensor, ln_out, bias_tensor if self.apply_bias else None)
if output_quantizer is not None:
raise NotImplementedError("ONNX export of quantized output is not supported")
if self.return_layernorm_output and self.return_bias:
return output, bias_tensor.to(inp_dtype), ln_out_return
if self.return_layernorm_output:
return output, ln_out_return
if self.return_bias:
return output, bias_tensor.to(inp_dtype)
return output
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_linear.""" """Customize quantizers based on current scaling recipe + layernorm_linear."""
assert ( assert (
......
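Finally, a hedged sketch of how the new `onnx_forward` path might be exercised; the `te.onnx_export` context manager, the need for a warm-up forward pass (suggested by `assert_warmed_up`), and the opset version are assumptions inferred from the imports in this diff rather than verified API:

```python
import torch
import transformer_engine.pytorch as te

model = te.LayerNormLinear(1024, 1024).cuda().eval()
inp = torch.randn(8, 1024, device="cuda")

with torch.inference_mode():
    model(inp)  # assumed warm-up pass so the module's quantizers are initialized

# Assumed export entry point; onnx_forward is dispatched when export mode is active.
with te.onnx_export(True):
    torch.onnx.export(model, (inp,), "layernorm_linear.onnx", opset_version=17)
```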