Commit f8c2af4c authored by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
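
The hunks below repeatedly wrap kernel and cuBLAS launches in NVTE_SCOPED_GIL_RELEASE, whose name suggests that the Python GIL is dropped while the host thread performs the launch. The macro's definition is not part of this diff; the snippet below is only a sketch of the general shape such a guard can take, built on pybind11's gil_scoped_release, and is not necessarily how TransformerEngine implements it.

// Sketch only: a hypothetical scoped-GIL-release helper in the spirit of the
// NVTE_SCOPED_GIL_RELEASE macro used in this commit; the real definition may differ.
#include <pybind11/pybind11.h>

#define SKETCH_SCOPED_GIL_RELEASE(code)                        \
  do {                                                         \
    /* Drop the GIL for the lifetime of this scope. */         \
    pybind11::gil_scoped_release _gil_release;                 \
    /* Run the wrapped statements without holding the GIL. */  \
    code                                                       \
  } while (false)

A block is then passed the same way the diff does, e.g. SKETCH_SCOPED_GIL_RELEASE({ nvte_dequantize(in.data(), out.data(), stream); });
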
@@ -52,7 +52,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
   if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
     // my_quantizer here has to be a Float8CurrentScalingQuantizer
     auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
-    nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+    });
     // check if we need to do amax reudction (depending on model parallel configs)
     if (my_quantizer_cs->with_amax_reduction) {
       c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
@@ -69,7 +71,10 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
     // so in nvte_quantize_v2 with current scaling, the quant config is not used again
     quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
-    nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_compute_scale_from_amax(te_output.data(), quant_config,
+                                   at::cuda::getCurrentCUDAStream());
+    });
     // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
     te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
   } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
@@ -77,8 +82,10 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
   }
-  nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
-                   at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
+                     at::cuda::getCurrentCUDAStream());
+  });
   return out;
 }
@@ -96,7 +103,9 @@ py::object dequantize(const py::handle& input, transformer_engine::DType otype)
   auto [out_tensor, out] = q.create_tensor(shape, otype);
 
-  nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
+  });
 
   return out;
 }
@@ -120,15 +129,19 @@ std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tens
   // Query workspace size and allocate workspace
   transformer_engine::TensorWrapper workspace;
-  func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
-       workspace.data(), at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
+         workspace.data(), at::cuda::getCurrentCUDAStream());
+  });
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
   workspace =
       makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
 
   // Launch kernel
-  func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
-       workspace.data(), at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
+         workspace.data(), at::cuda::getCurrentCUDAStream());
+  });
   return {py::cast(grad_bias), dact};
 }
...
@@ -141,81 +141,79 @@ CommOverlap::CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType
                   num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm,
                   set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {}
 
-void CommOverlap::set_buffer_params(py::handle quantizer) {
-  std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
-  my_quantizer->set_quantization_params(&_ubuf);
-  _ubuf_scale_inv_initialized = true;
-}
-
 /*
 ** Helper function to copy input to _ubuf
 */
-void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
-  auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
-  auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
-  NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
-
-  char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
+void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
+  const auto &input_ = input.contiguous();
+
+  // Check element size
+  const size_t element_size = input.element_size();
+  NVTE_CHECK(_ubuf.element_size() == element_size,
+             "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
+             "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
+             " bytes)");
+
+  // Input data
+  const size_t input_size = input_.numel();
+  const void *src_ptr = input_.data_ptr();
+
+  // Userbuffers data
+  const size_t ubuf_size = _ubuf.numel();
+  void *dst_ptr = _ubuf.dptr();
   if (local_chunk) {
-    if (input_tensor.numel() * _tp_size > _ubuf.numel())
-      NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
+    NVTE_CHECK(input_size * _tp_size == ubuf_size,
+               "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
+               "(input_size=", input_size, ", tensor_parallel_size=", _tp_size,
+               ", ubuf_size=", ubuf_size, ")");
+    dst_ptr = (reinterpret_cast<char *>(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size);
   } else {
-    if (input_tensor.numel() > _ubuf.numel())
-      NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
+    NVTE_CHECK(input_size == ubuf_size,
+               "Tried to copy an invalid tensor into a Userbuffers buffer ",
+               "(input_size=", input_size, ", ubuf_size=", ubuf_size, ")");
   }
 
-  // Copy either row or columnwise data into the communication buffer's columnwise data
-  // NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with
-  // the Userbuffers communicator.
-  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
+  // Copy data
+  auto stream_main = at::cuda::getCurrentCUDAStream();
   NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
-  NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input_tensor.dptr(),
-                                  input_tensor.numel() * input_tensor.element_size(),
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
                                   cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
 }
 
-py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
-                                   std::optional<const std::vector<int64_t>> shape) {
-  using namespace te::pytorch;
-  char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.dptr());
-  if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
-
-  std::vector<int64_t> torch_shape;
-  if (shape.has_value()) {
-    torch_shape = shape.value();
-    size_t requested = product(torch_shape);
-    auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
-    NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
-               ") does not match allocated buffer size (", expected, ")!");
+at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
+  // Check buffer shape
+  const size_t ubuf_size = _ubuf.numel();
+  if (shape) {
+    const size_t requested_size = transformer_engine::pytorch::product(*shape);
+    if (local_chunk) {
+      NVTE_CHECK(requested_size * _tp_size == ubuf_size,
+                 "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
+                 ", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")");
+    } else {
+      NVTE_CHECK(requested_size == ubuf_size,
+                 "Invalid shape for a Userbuffers buffer (requested shape=", *shape,
+                 ", ubuf_size=", ubuf_size, ")");
+    }
   } else {
-    int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
-    int64_t output_c_dim1 = _ubuf.size(1);
-    torch_shape = {output_c_dim0, output_c_dim1};
+    int64_t dim0 = _ubuf.size(0);
+    int64_t dim1 = _ubuf.size(1);
+    if (local_chunk) {
+      dim0 /= _tp_size;
+    }
+    shape = {dim0, dim1};
   }
 
-  auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
-                                      at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
-
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  std::vector<size_t> te_shape;
-  for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
-
-  // Always output a rowwise-only QuantizedTensor
-  // TODO (Alp): This needs to produce an un-interleaved transpose when required.
-  auto is_internal = my_quantizer->internal;
-  auto uses_columnwise = my_quantizer->columnwise_usage;
-  my_quantizer->internal = false;
-  my_quantizer->columnwise_usage = false;
-  auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
-  my_quantizer->internal = is_internal;
-  my_quantizer->columnwise_usage = uses_columnwise;
-  return py_tensor;
+  // Data pointer
+  void *ubuf_ptr = _ubuf.dptr();
+  if (local_chunk) {
+    ubuf_ptr = (reinterpret_cast<char *>(ubuf_ptr) +
+                (ubuf_size / _tp_size) * _tp_id * _ubuf.element_size());
+  }
+
+  // Construct PyTorch tensor
+  const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
+  return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
 
 /***************************************************************************************************
@@ -236,74 +234,69 @@ CommOverlapP2P::CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::Scal
                       comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
                       atomic_gemm, aggregate) {}
 
-void CommOverlapP2P::set_buffer_params(py::handle quantizer) {
-  std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
-  my_quantizer->set_quantization_params(&_ubuf);
-  for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]);
-}
-
 /*
 ** Copy input to _ubufs[0]
 */
-void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
-  auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
-  auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
-  NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
-
-  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
+void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
+  const auto &input_ = input.contiguous();
+
+  // Check element size
+  const size_t element_size = input.element_size();
+  NVTE_CHECK(_ubuf.element_size() == element_size,
+             "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
+             "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
+             " bytes)");
+
+  // Input data
+  const size_t input_size = input_.numel();
+  const void *src_ptr = input_.data_ptr();
+
+  // Userbuffers data
+  void *dst_ptr;
   if (local_chunk) {
-    // Copy input to the target ubuf chunk by rank offset
-    if (input_tensor.numel() * _tp_size > _ubuf.numel())
-      NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
-                                    input_tensor.numel() * input_tensor.element_size(),
-                                    cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
+    NVTE_CHECK(_ubufs[_tp_id].numel() == input_size,
+               "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
+               "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
+    dst_ptr = _ubufs[_tp_id].dptr();
   } else {
-    if (input_tensor.numel() > _ubuf.numel())
-      NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
-                                    input_tensor.numel() * input_tensor.element_size(),
-                                    cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
+    NVTE_CHECK(_ubuf.numel() == input_size,
+               "Tried to copy an invalid tensor into a Userbuffers buffer ",
+               "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")");
+    dst_ptr = _ubuf.dptr();
   }
+
+  // Copy data
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
+                                  cudaMemcpyDeviceToDevice,
+                                  (cudaStream_t)at::cuda::getCurrentCUDAStream()));
 }
 
-py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
-                                      std::optional<const std::vector<int64_t>> shape) {
-  using namespace te::pytorch;
-  char *ubuf_wt_ptr = reinterpret_cast<char *>(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr());
-
-  std::vector<int64_t> torch_shape;
-  if (shape.has_value()) {
-    torch_shape = shape.value();
-    size_t requested = product(torch_shape);
-    auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
-    NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
-               ") does not match allocated buffer size (", expected, ")!");
+at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
  // Check buffer shape
+  if (shape) {
+    const size_t requested_size = transformer_engine::pytorch::product(*shape);
+    if (local_chunk) {
+      NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(),
+                 "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
+                 ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
+    } else {
+      NVTE_CHECK(requested_size == _ubuf.numel(),
+                 "Invalid shape for a Userbuffers buffer (requested shape=", *shape,
+                 ", ubuf_size=", _ubuf.numel(), ")");
+    }
   } else {
-    int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
-    int64_t output_c_dim1 = _ubuf.size(1);
-    torch_shape = {output_c_dim0, output_c_dim1};
+    int64_t dim0 = _ubuf.size(0);
+    int64_t dim1 = _ubuf.size(1);
+    if (local_chunk) {
+      dim0 /= _tp_size;
+    }
+    shape = {dim0, dim1};
   }
 
-  auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
-                                      at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
-
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  std::vector<size_t> te_shape;
-  for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
-
-  // Always output a rowwise-only QuantizedTensor
-  // TODO (Alp): This needs to produce an un-interleaved transpose when required.
-  auto is_internal = my_quantizer->internal;
-  auto uses_columnwise = my_quantizer->columnwise_usage;
-  my_quantizer->internal = false;
-  my_quantizer->columnwise_usage = false;
-  auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
-  my_quantizer->internal = is_internal;
-  my_quantizer->columnwise_usage = uses_columnwise;
-  return py_tensor;
+  // Data pointer
+  void *ubuf_ptr = local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr();
+
+  // Construct PyTorch tensor
+  const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
+  return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
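
For reference, a minimal usage sketch of the reworked plain-tensor buffer API above (copy_into_buffer and get_buffer now take at::Tensor and plain shapes rather than quantizer handles). The setup is assumed rather than taken from this commit: `overlap` is an already-constructed CommOverlap whose declaration is in scope, and `input` is a contiguous CUDA tensor sized as one tensor-parallel chunk of the Userbuffers buffer.

// Usage sketch only; assumes a constructed CommOverlap `overlap` and a suitable CUDA tensor.
#include <optional>
#include <torch/torch.h>

at::Tensor stage_local_chunk(CommOverlap &overlap, const at::Tensor &input) {
  // Copy this rank's chunk into the shared Userbuffers buffer (device-to-device copy).
  overlap.copy_into_buffer(input, /*local_chunk=*/true);
  // Alias the local chunk as a regular torch tensor; get_buffer wraps the raw pointer
  // with torch::from_blob, so no copy is made.
  return overlap.get_buffer(/*local_chunk=*/true, std::nullopt);
}
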
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
size_t w, size_t start_offset, size_t block_len) {
TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16,
"tensor must be a float or bfloat16 tensor");
const TensorWrapper tensor_cu = makeTransformerEngineTensor(tensor);
TensorWrapper amax_cu = makeTransformerEngineTensor(amax);
nvte_fp8_block_scaling_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w,
amax.stride(0), amax.stride(1), start_offset,
block_len, at::cuda::getCurrentCUDAStream());
}
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
size_t h, size_t w, size_t start_offset, size_t block_len,
const transformer_engine::DType out_dtype) {
TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
TORCH_CHECK(
inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16,
"input must be a float or bfloat16 tensor");
TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
out_dtype == transformer_engine::DType::kFloat8E5M2,
"out_dtype must be kFloat8E4M3 or kFloat8E5M2");
const TensorWrapper inp_cu = makeTransformerEngineTensor(inp);
TensorWrapper out_cu = makeTransformerEngineTensor(out);
const TensorWrapper scale_cu = makeTransformerEngineTensor(scale);
nvte_fp8_block_scaling_partial_cast(
inp_cu.data(), out_cu.data(), scale_cu.data(), h, w, scale.stride(0), scale.stride(1),
start_offset, block_len, static_cast<NVTEDType>(out_dtype), at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
@@ -4,7 +4,6 @@
  * See LICENSE for license information.
  ************************************************************************/
 
-#include <Python.h>
 #include <pybind11/pybind11.h>
 
 #include <optional>
@@ -21,12 +20,12 @@
 namespace {
 
-void* get_data_ptr(MaybeTensor tensor) {
+void* get_data_ptr(transformer_engine::pytorch::MaybeTensor tensor) {
   if (tensor.has_value()) return tensor->data_ptr();
   return nullptr;
 }
 
-size_t get_size(MaybeTensor tensor, int dim) {
+size_t get_size(transformer_engine::pytorch::MaybeTensor tensor, int dim) {
   if (tensor.has_value()) return static_cast<size_t>(tensor->size(dim));
   return 0;
 }
@@ -167,8 +166,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
       makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);
 
   // Workspace
-  auto te_workspace =
-      makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte);
+  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
+                                                  std::vector<size_t>{workspaceSize}, DType::kByte);
 
   // Set an external SM Margin to all the GEMMs.
   // This comes in handy when DP is overlapped with GEMMs
@@ -197,38 +196,52 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
       // Direct GEMM call to the correct overlap
       if (bulk_overlap) {
-        comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
-                                   te_pre_gelu_out, te_workspace, grad, accumulate,
-                                   use_split_accumulator, comm_type.value(), extra_output_tensor,
-                                   main_stream);
+        NVTE_SCOPED_GIL_RELEASE({
+          comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
+                                     te_pre_gelu_out, te_workspace, grad, accumulate,
+                                     use_split_accumulator, comm_type.value(), extra_output_tensor,
+                                     main_stream);
+        });
       } else if (comm_type.value() == CommOverlapType::AG) {
         if (comm_overlap->is_atomic_gemm()) {
-          comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
-                                               bias_tensor, te_pre_gelu_out, te_workspace, grad,
-                                               accumulate, use_split_accumulator,
-                                               extra_output_tensor, main_stream);
+          NVTE_SCOPED_GIL_RELEASE({
+            comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
+                                                 bias_tensor, te_pre_gelu_out, te_workspace, grad,
+                                                 accumulate, use_split_accumulator,
+                                                 extra_output_tensor, main_stream);
+          });
        } else {
-          comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
-                                         te_pre_gelu_out, te_workspace, grad, accumulate,
-                                         use_split_accumulator, extra_output_tensor, main_stream);
+          NVTE_SCOPED_GIL_RELEASE({
+            comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
+                                           bias_tensor, te_pre_gelu_out, te_workspace, grad,
+                                           accumulate, use_split_accumulator, extra_output_tensor,
+                                           main_stream);
+          });
         }
      } else {
        if (comm_overlap->is_atomic_gemm()) {
-          comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
-                                               bias_tensor, te_pre_gelu_out, te_workspace, grad,
-                                               accumulate, use_split_accumulator,
-                                               extra_output_tensor, main_stream);
+          NVTE_SCOPED_GIL_RELEASE({
+            comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
+                                                 bias_tensor, te_pre_gelu_out, te_workspace, grad,
+                                                 accumulate, use_split_accumulator,
+                                                 extra_output_tensor, main_stream);
+          });
        } else {
-          comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
-                                         te_pre_gelu_out, te_workspace, grad, accumulate,
-                                         use_split_accumulator, extra_output_tensor, main_stream);
+          NVTE_SCOPED_GIL_RELEASE({
+            comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
+                                           bias_tensor, te_pre_gelu_out, te_workspace, grad,
+                                           accumulate, use_split_accumulator, extra_output_tensor,
+                                           main_stream);
+          });
        }
      }
    } else {
      // Launch GEMM
-      nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
-                       te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
-                       accumulate, use_split_accumulator, num_math_sms, main_stream);
+      NVTE_SCOPED_GIL_RELEASE({
+        nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
+                         te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
+                         accumulate, use_split_accumulator, num_math_sms, main_stream);
+      });
    }
  } else {
    if (D_tensor.numel() != 0 && !accumulate) {
@@ -258,20 +271,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   return out;
 }
 
-}  // namespace transformer_engine::pytorch
-
-void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
-                    std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
-                    at::Tensor B_scale_inverse, transformer_engine::DType B_type,
-                    std::vector<int64_t> B_scaling_mode, bool transb, at::Tensor D,
-                    at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
-                    at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
-                    bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
-                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
-                    bool gemm_producer, at::Tensor counter) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
-
+void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
+                    std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
+                    at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
+                    bool transb, at::Tensor D, at::Tensor D_scale, DType D_type, at::Tensor D_amax,
+                    at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out, bool grad,
+                    at::Tensor workspace, size_t workspaceSize, bool accumulate,
+                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
+                    bool gemm_producer, at::Tensor counter) {
   // TODO: Handle scaling modes
   NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
   NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING;
@@ -286,12 +293,13 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
                                          nvte_scaling_modeB);
 
   // TODO: D_scale_inv cannot be nullptr when D_type is FP8.
   auto te_D = makeTransformerEngineTensor(
-      D.data_ptr(), {static_cast<size_t>(D.size(0)), static_cast<size_t>(D.size(1))}, D_type,
+      D.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(D.size(0)), static_cast<size_t>(D.size(1))}, D_type,
       D_amax.data_ptr(), D_scale.data_ptr(), nullptr);
-  auto te_bias =
-      makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))}, bias_type);
+  auto te_bias = makeTransformerEngineTensor(
+      bias.data_ptr(), std::vector<size_t>{static_cast<size_t>(bias.size(0))}, bias_type);
   auto te_counter = makeTransformerEngineTensor(
-      counter.data_ptr(), {static_cast<size_t>(counter.size(0))}, DType::kInt32);
+      counter.data_ptr(), std::vector<size_t>{static_cast<size_t>(counter.size(0))}, DType::kInt32);
 
   const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
                               ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
@@ -299,24 +307,23 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
                                             static_cast<size_t>(pre_gelu_out.size(1))};
   auto te_pre_gelu_out = makeTransformerEngineTensor(
       pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type()));
-  auto te_workspace =
-      makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte);
+  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
+                                                  std::vector<size_t>{workspaceSize}, DType::kByte);
 
-  nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
-                          te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
-                          accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
-                          gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
+                            te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
+                            accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
+                            gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream());
+  });
 }
 
 std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
-    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
-    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
-    transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
-    bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
-    bool use_split_accumulator, int math_sm_count) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
-
+    std::optional<std::vector<at::Tensor>> D, DType D_type, std::vector<int64_t> m_splits,
+    std::vector<at::Tensor> bias, DType bias_type, bool single_output,
+    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
+    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
   std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
       te_pre_gelu_out_vector, te_workspace_vector;
   std::vector<TensorWrapper> wrappers;
@@ -419,16 +426,19 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     wrappers.emplace_back(std::move(te_pre_gelu_out));
   }
   for (size_t i = 0; i < workspace.size(); i++) {
-    auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte);
+    auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
+                                           std::vector<size_t>{workspaceSize}, DType::kByte);
     te_workspace_vector.emplace_back(wsp.data());
     wrappers.emplace_back(std::move(wsp));
   }
 
   // For now, we only have multi-stream cublas backend.
-  nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
-                                te_bias_vector.data(), te_pre_gelu_out_vector.data(),
-                                te_A_vector.size(), transa, transb, grad,
-                                te_workspace_vector.data(), accumulate, use_split_accumulator,
-                                math_sm_count, at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
+                                  te_bias_vector.data(), te_pre_gelu_out_vector.data(),
+                                  te_A_vector.size(), transa, transb, grad,
+                                  te_workspace_vector.data(), accumulate, use_split_accumulator,
+                                  math_sm_count, at::cuda::getCurrentCUDAStream());
+  });
   return bias;
 }
@@ -534,3 +544,5 @@ std::vector<at::Tensor> te_batchgemm_ts(
 }
 
 #endif
+
+}  // namespace transformer_engine::pytorch
@@ -6,6 +6,8 @@
 #include "extensions.h"
 
+namespace transformer_engine::pytorch {
+
 #ifdef USE_ROCM
 size_t get_cublasLt_version() { int version = 10000000; return version; }
@@ -15,3 +17,5 @@ size_t get_cublasLt_version() { return cublasLtGetVersion(); }
 size_t get_cudnn_version() { return cudnnGetVersion(); }
 
 #endif
+
+}  // namespace transformer_engine::pytorch
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction,
weight_decay, device_id, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_param_remainder_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1,
beta2, epsilon, step, mode, bias_correction, weight_decay, device_id,
at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, DType fp8_dtype) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode,
bias_correction, weight_decay, static_cast<NVTEDType>(fp8_dtype),
device_id, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_master_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float epsilon) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
auto ret = at::empty({1}, output.options());
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
output_per_tensor = at::empty({0}, float_options);
ret_per_tensor = at::empty({0}, float_options);
}
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto output_cu = makeTransformerEngineTensor(output);
auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, output_cu.data(), output_per_tensor_cu.data(),
ret_cu.data(), ret_per_tensor_cu.data(), per_tensor,
max_chunks_per_tensor, device_id, at::cuda::getCurrentCUDAStream());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
// Create output tensors for multi scale L2 norm kernel.
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
output_per_tensor = at::empty({0}, float_options);
ret_per_tensor = at::empty({0}, float_options);
}
auto ret = at::empty({1}, output.options());
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto output_cu = makeTransformerEngineTensor(output);
auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_unscale_l2norm_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(),
inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id,
at::cuda::getCurrentCUDAStream());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
} // namespace transformer_engine::pytorch
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, wd, momentum, dampening, lr, nesterov, first_run,
wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
...@@ -9,28 +9,11 @@ ...@@ -9,28 +9,11 @@
#include "pybind.h" #include "pybind.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
py::handle quantizer) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; i++) {
size_t t = shape.data[i];
shape_vec.push_back(t);
}
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
return my_quantizer->create_tensor(shape_vec, dtype);
}
std::pair<TensorWrapper, py::object> createOutputTensor(std::vector<size_t> &shape, DType dtype,
py::handle quantizer) {
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
return my_quantizer->create_tensor(shape, dtype);
}
} // namespace transformer_engine::pytorch
std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma, const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int sm_margin, const at::Tensor &gamma, const int sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch;
const auto &dz_ = dz.contiguous(); const auto &dz_ = dz.contiguous();
const auto &x_ = x.contiguous(); const auto &x_ = x.contiguous();
const auto &mu_ = mu.contiguous(); const auto &mu_ = mu.contiguous();
...@@ -40,7 +23,7 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, ...@@ -40,7 +23,7 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dx = at::empty_like(x_); auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_); auto dgamma = at::empty_like(gamma_);
auto dbeta = at::empty_like(gamma_); auto dbeta = at::empty_like(gamma_);
transformer_engine::TensorWrapper workspace; TensorWrapper workspace;
auto dz_cu = makeTransformerEngineTensor(dz_); auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x_); auto x_cu = makeTransformerEngineTensor(x_);
...@@ -52,10 +35,12 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, ...@@ -52,10 +35,12 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dbeta_cu = makeTransformerEngineTensor(dbeta); auto dbeta_cu = makeTransformerEngineTensor(dbeta);
// This call populates tensors with the required config. // This call populates tensors with the required config.
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), NVTE_SCOPED_GIL_RELEASE({
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
zero_centered_gamma, at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Alloc space for Tensors. // Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
...@@ -63,10 +48,12 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, ...@@ -63,10 +48,12 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Actual call to bwd kernel. // Actual call to bwd kernel.
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), NVTE_SCOPED_GIL_RELEASE({
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
zero_centered_gamma, at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)}; return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)};
} }
...@@ -76,8 +63,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe ...@@ -76,8 +63,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
DType out_dtype, const int sm_margin, DType out_dtype, const int sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail; using namespace transformer_engine::pytorch::detail;
using namespace transformer_engine::pytorch;
using namespace transformer_engine;
// Input and param tensors // Input and param tensors
auto none = py::none(); auto none = py::none();
...@@ -131,11 +116,13 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe ...@@ -131,11 +116,13 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
// Query workspace size // Query workspace size
transformer_engine::TensorWrapper workspace; TensorWrapper workspace;
nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), NVTE_SCOPED_GIL_RELEASE({
mu_cu.data(), rsigma_cu.data(), workspace.data(), nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, mu_cu.data(), rsigma_cu.data(), workspace.data(),
zero_centered_gamma, at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Allocate workspace // Allocate workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
...@@ -143,10 +130,12 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe ...@@ -143,10 +130,12 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
   // Launch kernel
-  nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
-                     mu_cu.data(), rsigma_cu.data(), workspace.data(),
-                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
-                     zero_centered_gamma, at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
+                       mu_cu.data(), rsigma_cu.data(), workspace.data(),
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
+                       zero_centered_gamma, at::cuda::getCurrentCUDAStream());
+  });

   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
@@ -154,7 +143,10 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
     if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
       // my_quantizer here has to be a Float8CurrentScalingQuantizer
       auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
+      NVTE_SCOPED_GIL_RELEASE({
+        nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
+                          at::cuda::getCurrentCUDAStream());
+      });
       // check if we need to do amax reudction (depending on model parallel configs)
       if (my_quantizer_cs->with_amax_reduction) {
         c10::intrusive_ptr<dist_group_type> process_group_ptr =
@@ -169,7 +161,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
       }
       quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
       quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
-      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      NVTE_SCOPED_GIL_RELEASE({
+        nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      });
       // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
       out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
     } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
@@ -177,8 +171,10 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
       quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
       quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
     }
-    nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
-                     at::cuda::getCurrentCUDAStream());
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
+                       at::cuda::getCurrentCUDAStream());
+    });
   }

   return {out, py::cast(mu), py::cast(rsigma)};
@@ -187,7 +183,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
 std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                     const at::Tensor &rsigma, const at::Tensor &gamma,
                                     const int sm_margin, const bool zero_centered_gamma) {
-  using namespace transformer_engine::pytorch;
   const auto &dz_ = dz.contiguous();
   const auto &x_ = x.contiguous();
   const auto &rsigma_ = rsigma.contiguous();
@@ -195,7 +190,7 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
   auto dx = at::empty_like(x_);
   auto dgamma = at::empty_like(gamma_);
-  transformer_engine::TensorWrapper workspace;
+  TensorWrapper workspace;

   auto dz_cu = makeTransformerEngineTensor(dz_);
   auto x_cu = makeTransformerEngineTensor(x_);
@@ -205,10 +200,12 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
   auto dgamma_cu = makeTransformerEngineTensor(dgamma);

   // This call populates tensors with the required config.
-  nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
-                   dgamma_cu.data(), workspace.data(),
-                   at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
-                   zero_centered_gamma, at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
+                     dgamma_cu.data(), workspace.data(),
+                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
+                     zero_centered_gamma, at::cuda::getCurrentCUDAStream());
+  });

   // Alloc space for Tensors.
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -216,21 +213,20 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
       makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

   // Actual call to bwd kernel.
-  nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
-                   dgamma_cu.data(), workspace.data(),
-                   at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
-                   zero_centered_gamma, at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
+                     dgamma_cu.data(), workspace.data(),
+                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
+                     zero_centered_gamma, at::cuda::getCurrentCUDAStream());
+  });

   return {py::cast(dx), py::cast(dgamma)};
 }

 std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
-                                    py::object out, py::handle quantizer,
-                                    transformer_engine::DType out_dtype, const int sm_margin,
-                                    const bool zero_centered_gamma) {
+                                    py::object out, py::handle quantizer, DType out_dtype,
+                                    const int sm_margin, const bool zero_centered_gamma) {
   using namespace transformer_engine::pytorch::detail;
-  using namespace transformer_engine::pytorch;
-  using namespace transformer_engine;

   // Input and param tensors
   auto none = py::none();
@@ -278,11 +274,13 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;

   // Query workspace size
-  transformer_engine::TensorWrapper workspace;
-  nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
-                   workspace.data(),
-                   at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
-                   zero_centered_gamma, at::cuda::getCurrentCUDAStream());
+  TensorWrapper workspace;
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
+                     workspace.data(),
+                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
+                     zero_centered_gamma, at::cuda::getCurrentCUDAStream());
+  });

   // Allocate workspace
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -290,10 +288,12 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
       makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

   // Launch kernel
-  nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
-                   workspace.data(),
-                   at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
-                   zero_centered_gamma, at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
+                     workspace.data(),
+                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
+                     zero_centered_gamma, at::cuda::getCurrentCUDAStream());
+  });

   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
@@ -301,7 +301,10 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
     if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
       // my_quantizer here has to be a Float8CurrentScalingQuantizer
       auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
+      NVTE_SCOPED_GIL_RELEASE({
+        nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
+                          at::cuda::getCurrentCUDAStream());
+      });
       // check if we need to do amax reudction (depending on model parallel configs)
       if (my_quantizer_cs->with_amax_reduction) {
         c10::intrusive_ptr<dist_group_type> process_group_ptr =
@@ -316,7 +319,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
       }
       quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
       quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
-      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      NVTE_SCOPED_GIL_RELEASE({
+        nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      });
       // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
       out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
     } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
@@ -324,9 +329,13 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
       quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
       quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
     }
-    nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
-                     at::cuda::getCurrentCUDAStream());
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
+                       at::cuda::getCurrentCUDAStream());
+    });
   }

   return {out, py::none(), py::cast(rsigma)};
 }
+
+} // namespace transformer_engine::pytorch
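
The hunks above all make the same change: each nvte_* call is wrapped in NVTE_SCOPED_GIL_RELEASE so the Python GIL is not held while the host-side launch work runs. The macro's definition is not part of this diff; a minimal sketch of what such a scope-based wrapper could look like, assuming it is built on pybind11's gil_scoped_release, is:

// Hedged sketch only -- the real NVTE_SCOPED_GIL_RELEASE lives in the extension's
// headers, which are not shown in this diff.
#include <pybind11/pybind11.h>

#define NVTE_SCOPED_GIL_RELEASE(...)                    \
  do {                                                  \
    pybind11::gil_scoped_release scoped_gil_release___; \
    __VA_ARGS__                                         \
  } while (false)

Whatever the real definition is, the call sites only rely on the braces introducing a scope in which the GIL is released and then re-acquired on exit, which is why the kernel-launch arguments themselves are unchanged.
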
@@ -17,7 +17,8 @@
 #include <torch/cuda.h>
 #include <torch/extension.h>

-namespace nvshmem_api {
+namespace transformer_engine::pytorch {
+
 void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
 #ifdef NVTE_ENABLE_NVSHMEM
   nvshmemx_init_attr_t attr = {};
@@ -126,4 +127,5 @@ void nvshmem_finalize() {
       "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
 #endif
 }
-} // namespace nvshmem_api
+
+} // namespace transformer_engine::pytorch
@@ -5,13 +5,13 @@
 ************************************************************************/

 #include "extensions.h"
+#include "pybind.h"
+
+namespace transformer_engine::pytorch {

 void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                              std::vector<size_t> input_row_list,
                              std::vector<size_t> padded_input_row_list) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
-
   NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(),
              "Number of input row list and padded row list must match.");
   NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
@@ -21,7 +21,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
   // Extract properties from PyTorch tensors
   std::vector<void*> input_dptr_list, output_dptr_list;
   std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
-  std::vector<transformer_engine::DType> input_type_list;
+  std::vector<DType> input_type_list;

   void* d_input_ptr = reinterpret_cast<void*>(input.data_ptr());
   void* d_output_ptr = reinterpret_cast<void*>(output.data_ptr());
   for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
@@ -51,9 +51,9 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
   // Construct TE tensors
   std::vector<NVTETensor> nvte_input_list, nvte_output_list;
-  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
+  std::vector<TensorWrapper> tensor_wrappers;
   auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
-                                        transformer_engine::DType dtype) -> NVTETensor {
+                                        DType dtype) -> NVTETensor {
     tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
     return tensor_wrappers.back().data();
   };
@@ -75,6 +75,10 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
              "Number of input and padded row list must match");

   // Launch TE kernel
-  nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(),
-                     padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(),
+                       padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
+  });
 }
+
+} // namespace transformer_engine::pytorch
@@ -4,14 +4,13 @@
 * See LICENSE for license information.
 ************************************************************************/

-#include <cub/cub.cuh>
-
 #include "extensions.h"

+namespace transformer_engine::pytorch {
+
 std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
-    at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
-    int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
-  using namespace transformer_engine::pytorch;
+    at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens,
+    std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
   const int num_tokens = input.size(0);
   int num_cols = input.size(1);
   const int topK = indices.size(1);
@@ -28,9 +27,8 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
         torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));

     size_t temp_storage_bytes = 0;
-    int *temp_ptr = nullptr;
-    cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr,
-                                    temp_ptr, max_expanded_token_num);
+    nvte_device_radix_sort_pairs(nullptr, &temp_storage_bytes, nullptr, nullptr, nullptr, nullptr,
+                                 max_expanded_token_num);
     at::Tensor temp_storage = torch::empty(
         temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
@@ -40,17 +38,18 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
     workspace.push_back(temp_storage);
   }

-  int *indices_ptr = reinterpret_cast<int *>(getDataPtr(indices, 0));
-  int *sorted_indices_ptr = reinterpret_cast<int *>(getDataPtr(workspace[0], 0));
-  int *row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[1], 0));
-  int *sorted_row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[2], 0));
+  void *indices_ptr = getDataPtr(indices, 0);
+  void *sorted_indices_ptr = getDataPtr(workspace[0], 0);
+  void *row_id_ptr = getDataPtr(workspace[1], 0);
+  void *sorted_row_id_ptr = getDataPtr(workspace[2], 0);
   void *d_temp_storage = getDataPtr(workspace[3], 0);
   size_t temp_storage_bytes = std::numeric_limits<size_t>::max();

-  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr,
-                                  sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr,
-                                  num_tokens * topK);
+  nvte_device_radix_sort_pairs(
+      d_temp_storage, &temp_storage_bytes, reinterpret_cast<int *>(indices_ptr),
+      reinterpret_cast<int *>(sorted_indices_ptr), reinterpret_cast<int *>(row_id_ptr),
+      reinterpret_cast<int *>(sorted_row_id_ptr), num_tokens * topK);

   // Output buffer alloc
   num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
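
The two nvte_device_radix_sort_pairs calls above keep the CUB-style two-phase protocol that the removed code used: the first call passes a null temp-storage pointer and only fills in temp_storage_bytes, the buffer is then allocated, and the second call performs the sort. A self-contained illustration of that protocol, written directly against cub::DeviceRadixSort (which is what the old code called); the helper name sort_pairs_example is invented for the example:

#include <cub/cub.cuh>

cudaError_t sort_pairs_example(const int *keys_in, int *keys_out, const int *vals_in,
                               int *vals_out, int num_items, cudaStream_t stream) {
  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  // Pass 1: d_temp_storage == nullptr, so only temp_storage_bytes is written.
  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, keys_in, keys_out,
                                  vals_in, vals_out, num_items, 0, sizeof(int) * 8, stream);
  cudaError_t err = cudaMalloc(&d_temp_storage, temp_storage_bytes);
  if (err != cudaSuccess) return err;
  // Pass 2: same arguments with real scratch space, so the sort actually runs.
  err = cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, keys_in, keys_out,
                                        vals_in, vals_out, num_items, 0, sizeof(int) * 8, stream);
  cudaFree(d_temp_storage);
  return err;
}

nvte_device_radix_sort_pairs appears to wrap the same contract behind a C API, which is why the extension no longer needs to include <cub/cub.cuh> directly.
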
@@ -63,34 +62,33 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
   auto stream = at::cuda::getCurrentCUDAStream().stream();

   auto input_cu = makeTransformerEngineTensor(
-      input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)}, dtype);
-  auto permuted_output_cu = makeTransformerEngineTensor(
-      permuted_output.data_ptr(),
-      {static_cast<size_t>(permuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype);
-  auto sorted_row_id_cu =
-      makeTransformerEngineTensor(sorted_row_id_ptr, {static_cast<size_t>(num_tokens * topK)},
-                                  transformer_engine::DType::kInt32);
+      input.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)},
+      dtype);
+  auto permuted_output_cu =
+      makeTransformerEngineTensor(permuted_output.data_ptr(),
+                                  std::vector<size_t>{static_cast<size_t>(permuted_output.size(0)),
+                                                      static_cast<size_t>(num_cols)},
+                                  dtype);
+  auto sorted_row_id_cu = makeTransformerEngineTensor(
+      sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_tokens * topK)},
+      DType::kInt32);
   auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);

   nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(),
-               row_id_map_cu.data(), transformer_engine::TensorWrapper().data(),
-               transformer_engine::TensorWrapper().data(),
-               transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols,
-               num_out_tokens, stream);
+               row_id_map_cu.data(), TensorWrapper().data(), TensorWrapper().data(),
+               TensorWrapper().data(), num_tokens, topK, num_cols, num_out_tokens, stream);

   return std::make_tuple(permuted_output, row_id_map, workspace);
 }

-at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
-                           at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
-                           int64_t topK) {
+at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
+                           at::Tensor prob, int64_t num_tokens, int64_t topK) {
   return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK);
 }

-at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
-                             at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
-                             int64_t topK) {
-  using namespace transformer_engine::pytorch;
+at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
+                             at::Tensor prob, int64_t num_tokens, int64_t topK) {
   int num_cols = input.size(1);

   // Output buffer alloc
@@ -101,10 +99,14 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
   auto stream = at::cuda::getCurrentCUDAStream().stream();

   auto input_cu = makeTransformerEngineTensor(
-      input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)}, dtype);
+      input.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)},
+      dtype);
   auto unpermuted_output_cu = makeTransformerEngineTensor(
       unpermuted_output.data_ptr(),
-      {static_cast<size_t>(unpermuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype);
+      std::vector<size_t>{static_cast<size_t>(unpermuted_output.size(0)),
+                          static_cast<size_t>(num_cols)},
+      dtype);
   auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
   auto prob_cu = makeTransformerEngineTensor(prob);
@@ -115,9 +117,8 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
 }

 std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
-                                                     const transformer_engine::DType dtype,
-                                                     at::Tensor row_id_map, at::Tensor prob) {
-  using namespace transformer_engine::pytorch;
+                                                     const DType dtype, at::Tensor row_id_map,
+                                                     at::Tensor prob) {
   const int topK = (prob.numel() > 0) ? prob.size(1) : 1;
   const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
   int num_cols = input_bwd.size(1);
@@ -132,21 +133,26 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
   auto stream = at::cuda::getCurrentCUDAStream().stream();

   auto input_bwd_cu = makeTransformerEngineTensor(
-      input_bwd.data_ptr(), {static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)},
+      input_bwd.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)},
       dtype);
   auto act_grad_cu = makeTransformerEngineTensor(
-      act_grad.data_ptr(), {static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)},
+      act_grad.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)},
       dtype);
   auto input_fwd_cu = makeTransformerEngineTensor(
-      input_fwd.data_ptr(), {static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)},
+      input_fwd.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)},
       dtype);
   auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
   auto prob_cu = makeTransformerEngineTensor(prob);
   auto prob_grad_cu = makeTransformerEngineTensor(prob_grad);

-  nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(),
+  nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), TensorWrapper().data(),
                row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(),
                num_tokens, topK, num_cols, 0, stream);

   return std::make_tuple(act_grad, prob_grad);
 }
+
+} // namespace transformer_engine::pytorch
@@ -6,7 +6,6 @@
 #include "pybind.h"

-#include <Python.h>
 #include <pybind11/cast.h>
 #include <pybind11/detail/common.h>
 #include <pybind11/functional.h>
@@ -111,10 +110,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
         py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
         py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
-  m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.",
-        py::call_guard<py::gil_scoped_release>());
   m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
         py::arg("quantizer"));
   m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
@@ -160,85 +155,111 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("quantizer"));

   // Permutation functions
-  m.def("moe_permute_fwd", moe_permute_fwd);
-  m.def("moe_permute_bwd", moe_permute_bwd);
-  m.def("moe_unpermute_fwd", moe_unpermute_fwd);
-  m.def("moe_unpermute_bwd", moe_unpermute_bwd);
+  m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("moe_permute_bwd", transformer_engine::pytorch::moe_permute_bwd, "MOE permute BWD",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("moe_unpermute_fwd", transformer_engine::pytorch::moe_unpermute_fwd, "MOE unpermute FWD",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("moe_unpermute_bwd", transformer_engine::pytorch::moe_unpermute_bwd, "MOE unpermute BWD",
+        py::call_guard<py::gil_scoped_release>());

   // Softmax functions
-  m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
-        "Scaled Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
-        "Scaled Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_upper_triang_masked_softmax_forward", &scaled_upper_triang_masked_softmax_forward,
+  m.def("scaled_softmax_forward", &transformer_engine::pytorch::scaled_softmax_forward,
+        "Scaled Softmax FWD", py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_softmax_backward", &transformer_engine::pytorch::scaled_softmax_backward,
+        "Scaled Softmax BWD", py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_masked_softmax_forward",
+        &transformer_engine::pytorch::scaled_masked_softmax_forward, "Scaled Masked Softmax FWD",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_masked_softmax_backward",
+        &transformer_engine::pytorch::scaled_masked_softmax_backward, "Scaled Masked Softmax BWD",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_upper_triang_masked_softmax_forward",
+        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_forward,
         "Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_upper_triang_masked_softmax_backward", &scaled_upper_triang_masked_softmax_backward,
+  m.def("scaled_upper_triang_masked_softmax_backward",
+        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_backward,
         "Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
   m.def("scaled_aligned_causal_masked_softmax_forward",
-        &scaled_aligned_causal_masked_softmax_forward,
+        &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_forward,
         "Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
         py::call_guard<py::gil_scoped_release>());
   m.def("scaled_aligned_causal_masked_softmax_backward",
-        &scaled_aligned_causal_masked_softmax_backward,
+        &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_backward,
         "Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
         py::call_guard<py::gil_scoped_release>());

   // Other granular functions
-  m.def("layernorm_fwd", &layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"),
-        py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"),
-        py::arg("sm_margin"), py::arg("zero_centered_gamma"));
-  m.def("layernorm_bwd", &layernorm_bwd, "Backward of LayerNorm");
-  m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"),
-        py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
-        py::arg("zero_centered_gamma"));
-  m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
+  m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"),
+        py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
+        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
+  m.def("layernorm_bwd", &transformer_engine::pytorch::layernorm_bwd, "Backward of LayerNorm");
+  m.def("rmsnorm_fwd", &transformer_engine::pytorch::rmsnorm_fwd, "RMSNorm", py::arg("input"),
+        py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
+        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
+  m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
   m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
         "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
         py::arg("quantizer_list"), py::arg("otype"));
-  m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
+  m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
+        "Grouped GEMM");
 #ifdef USE_ROCM
-  m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM");  /// rocblas
+  m.def("te_batchgemm_ts", &transformer_engine::pytorch::te_batchgemm_ts, "Batched GEMM");  /// rocblas
 #endif
   m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
         py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
         py::call_guard<py::gil_scoped_release>());
-  m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
+  m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend,
+        "Get Fused Attention backend", py::call_guard<py::gil_scoped_release>());
+  m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
+        "Compute absolute max value in tensor", py::arg("input"), py::arg("amax"),
         py::call_guard<py::gil_scoped_release>());
-  m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
-  m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
+  m.def("fused_amax_and_scale_update_after_reduction",
+        &transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction,
         "Update amax history and FP8 scale/scale_inv after reduction",
         py::call_guard<py::gil_scoped_release>());
-  m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("fp8_block_scaling_compute_partial_amax",
+        &transformer_engine::pytorch::fp8_block_scaling_compute_partial_amax,
+        "Compute partial amax from master weights for fp8 block scaling", py::arg("tensor"),
+        py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
+        py::call_guard<py::gil_scoped_release>());
+  m.def("fp8_block_scaling_partial_cast",
+        &transformer_engine::pytorch::fp8_block_scaling_partial_cast,
+        "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
+        py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
+        py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
+  m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
+        "Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());

   // attention kernels
-  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
+  m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
+        "Prepare QKV for Flash Attention", py::call_guard<py::gil_scoped_release>());
+  m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd,
+        "Backward of QKV preparation for Flash Attention",
         py::call_guard<py::gil_scoped_release>());
-  m.def("fused_attn_fwd", &fused_attn_fwd,
+  m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd,
         "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
-  m.def("fused_attn_bwd", &fused_attn_bwd,
+  m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd,
         "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
-  m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache");
-  m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD");
-  m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tesnor from BSHD to THD");
+  m.def("copy_to_kv_cache", &transformer_engine::pytorch::copy_to_kv_cache,
+        "Copy new KV tokens to KV cache", py::call_guard<py::gil_scoped_release>());
+  m.def("convert_thd_to_bshd", &transformer_engine::pytorch::convert_thd_to_bshd,
+        "Convert a tensor from THD to BSHD", py::call_guard<py::gil_scoped_release>());
+  m.def("convert_bshd_to_thd", &transformer_engine::pytorch::convert_bshd_to_thd,
+        "Convert a tesnor from BSHD to THD", py::call_guard<py::gil_scoped_release>());

   // fused apply rope
-  m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("fused_rope_forward", &transformer_engine::pytorch::fused_rope_forward,
+        "Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
+  m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
+        "Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());

   // Misc
-  m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
+  m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
+        "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
+  m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
         py::call_guard<py::gil_scoped_release>());
   m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
 #ifdef USE_ROCM
@@ -246,74 +267,82 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #endif

   // Support THD format for Context Parallel
-  m.def("thd_read_half_tensor", &thd_read_half_tensor,
+  m.def("thd_read_half_tensor", &transformer_engine::pytorch::thd_read_half_tensor,
         "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
         "tensor",
         py::call_guard<py::gil_scoped_release>());
-  m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
+  m.def("thd_second_half_lse_correction",
+        &transformer_engine::pytorch::thd_second_half_lse_correction,
         "Correct the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
-  m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
+  m.def("thd_read_second_half_lse", &transformer_engine::pytorch::thd_read_second_half_lse,
         "Read the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
-  m.def("thd_out_correction", &thd_out_correction,
+  m.def("thd_out_correction", &transformer_engine::pytorch::thd_out_correction,
         "Correct the THD format output of context parallelism in forward pass",
         py::call_guard<py::gil_scoped_release>());
-  m.def("thd_grad_correction", &thd_grad_correction,
+  m.def("thd_grad_correction", &transformer_engine::pytorch::thd_grad_correction,
         "Correct the THD format gradients of context parallelism in backward pass",
         py::call_guard<py::gil_scoped_release>());
-  m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
+  m.def("thd_get_partitioned_indices", &transformer_engine::pytorch::thd_get_partitioned_indices,
         "Generate partitioned indices for inputs in THD format",
         py::call_guard<py::gil_scoped_release>());

   // nvshmem functions
-  m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
+  m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_nvshmem_backend,
         "Initialize nvshmem backend with Pytorch distributed process groups",
         py::call_guard<py::gil_scoped_release>());
-  m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
+  m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_nvshmem_tensor,
         "Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
-  m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
+  m.def("nvshmem_send_on_current_stream",
+        &transformer_engine::pytorch::nvshmem_send_on_current_stream,
         "Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
         py::call_guard<py::gil_scoped_release>());
-  m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
+  m.def("nvshmem_wait_on_current_stream",
+        &transformer_engine::pytorch::nvshmem_wait_on_current_stream,
         "Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
         "stream",
         py::call_guard<py::gil_scoped_release>());
-  m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
+  m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize,
         "Clean up and finalize the NVSHMEM communication backend and free associated resources",
         py::call_guard<py::gil_scoped_release>());

   // multi-tensor functions
-  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
+  m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
+  m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
+  m.def("multi_tensor_unscale_l2norm",
+        &transformer_engine::pytorch::multi_tensor_unscale_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only "
         "performed for L2 norm computation, and tensors are not updated)",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
+  m.def("multi_tensor_adam", &transformer_engine::pytorch::multi_tensor_adam_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda,
+  m.def("multi_tensor_adam_param_remainder",
+        &transformer_engine::pytorch::multi_tensor_adam_param_remainder_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer"
         "where the master parameters only store the remainder bits",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
+  m.def("multi_tensor_adam_fp8", &transformer_engine::pytorch::multi_tensor_adam_fp8_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
+  m.def("multi_tensor_adam_capturable",
+        &transformer_engine::pytorch::multi_tensor_adam_capturable_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
         "support and LR scheduling",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
+  m.def("multi_tensor_adam_capturable_master",
+        &transformer_engine::pytorch::multi_tensor_adam_capturable_master_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
         "support, LR scheduling and FP32 master weights",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
+  m.def("multi_tensor_sgd", &transformer_engine::pytorch::multi_tensor_sgd_cuda,
         "Fused SGD optimizer for list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda,
+  m.def("multi_tensor_compute_scale_and_scale_inv",
+        &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
         "Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());

   // Data structures
@@ -359,10 +388,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
           py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
       .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
-           py::arg("quantizer"), py::arg("local_chunk") = false)
-      .def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"),
-           py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
-      .def("set_buffer_params", &CommOverlap::set_buffer_params);
+           py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
+           py::arg("shape") = std::nullopt);

   py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
              transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
@@ -377,8 +405,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
              py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
              py::arg("use_ce") = true, py::arg("aggregate") = false)
       .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
-           py::arg("quantizer"), py::arg("local_chunk") = false)
-      .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"),
-           py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
-      .def("set_buffer_params", &CommOverlapP2P::set_buffer_params);
+           py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
+           py::arg("shape") = std::nullopt);
 }
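
Most of the bindings above now attach py::call_guard<py::gil_scoped_release>(), which tells pybind11 to construct a gil_scoped_release object around every call, so the GIL is dropped for the duration of the C++ body without modifying the bound function. A minimal, self-contained illustration of the same mechanism (the function and module names here are invented for the example):

#include <chrono>
#include <thread>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Simulates a long-running C++ routine that does not touch Python objects.
void busy_work() { std::this_thread::sleep_for(std::chrono::milliseconds(100)); }

PYBIND11_MODULE(gil_demo, m) {
  // Other Python threads can make progress while busy_work() executes.
  m.def("busy_work", &busy_work, "CPU-side work with the GIL released",
        py::call_guard<py::gil_scoped_release>());
}

The guard is only appropriate for functions that never call back into the Python C API while it is held released; functions that manipulate py::object arguments keep the default behaviour.
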
@@ -12,10 +12,9 @@
 #include "common/common.h"
 #include "extensions.h"

-void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
+namespace transformer_engine::pytorch {

+void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
   auto input_tensor = tensor.contiguous();

   const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
@@ -23,7 +22,7 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
   TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
   TensorWrapper fake_te_output(
       nullptr, te_input.shape(),
-      transformer_engine::DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
+      DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
       amax.data_ptr<float>());

   nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
@@ -33,10 +32,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
                                                  std::vector<at::Tensor> amax_histories,
                                                  std::vector<at::Tensor> scales,
                                                  const std::string& amax_compute_algo,
-                                                 transformer_engine::DType fp8_dtype,
-                                                 float margin) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
+                                                 DType fp8_dtype, float margin) {
   size_t num_tensors = amax_histories.size();
   std::vector<Tensor> t_amax_histories(num_tensors);
   std::vector<Tensor> t_scales(num_tensors);
@@ -63,3 +59,5 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
       amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
       at::cuda::getCurrentCUDAStream());
 }
+
+} // namespace transformer_engine::pytorch
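
For reference, compute_amax above only needs somewhere to write a single float: the TensorWrapper it builds has a null data pointer and an arbitrary FP8 dtype, and acts purely as a carrier for the amax pointer. A hedged sketch of how a caller inside the extension could drive it (the helper name run_amax_example is invented for illustration):

#include <ATen/ATen.h>

namespace transformer_engine::pytorch {

// Computes the absolute maximum of `input` into a fresh one-element float32
// CUDA tensor by reusing the function above.
at::Tensor run_amax_example(const at::Tensor& input) {
  at::Tensor amax = at::zeros({1}, input.options().dtype(at::kFloat));
  compute_amax(input, amax);  // launches nvte_compute_amax on the current stream
  return amax;
}

}  // namespace transformer_engine::pytorch

From Python, the same entry point is reachable through the m.def("compute_amax", ...) binding registered in the pybind hunk above.
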
...@@ -6,8 +6,9 @@ ...@@ -6,8 +6,9 @@
#include "extensions.h" #include "extensions.h"
namespace transformer_engine::pytorch {
at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16), (input.scalar_type() == at::ScalarType::BFloat16),
...@@ -38,8 +39,6 @@ at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { ...@@ -38,8 +39,6 @@ at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
float scale_factor) { float scale_factor) {
using namespace transformer_engine::pytorch;
auto output_grads = output_grad_.contiguous(); auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous(); auto softmax_results = softmax_results_.contiguous();
...@@ -65,8 +64,6 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_r ...@@ -65,8 +64,6 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_r
} }
at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) { at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16), (input.scalar_type() == at::ScalarType::BFloat16),
...@@ -105,8 +102,6 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, floa ...@@ -105,8 +102,6 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, floa
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
float scale_factor) { float scale_factor) {
using namespace transformer_engine::pytorch;
auto output_grads = output_grad_.contiguous(); auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous(); auto softmax_results = softmax_results_.contiguous();
...@@ -132,8 +127,6 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor so ...@@ -132,8 +127,6 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor so
} }
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) { at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16), (input.scalar_type() == at::ScalarType::BFloat16),
...@@ -159,8 +152,6 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float sc ...@@ -159,8 +152,6 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float sc
at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_, at::Tensor softmax_results_,
float scale_factor) { float scale_factor) {
using namespace transformer_engine::pytorch;
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();
...@@ -188,7 +179,6 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
}
at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) {
using namespace transformer_engine::pytorch;
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
...@@ -220,8 +210,6 @@ at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float
at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_,
                                                         at::Tensor softmax_results_,
                                                         float scale_factor) {
using namespace transformer_engine::pytorch;
  auto output_grads = output_grad_.contiguous();
  auto softmax_results = softmax_results_.contiguous();
...@@ -245,3 +233,5 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_
  return output_grads;
}
} // namespace transformer_engine::pytorch
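Note: the per-function using namespace transformer_engine::pytorch; directives removed above become redundant once the definitions sit inside that namespace, which the new closing brace terminates (the matching opening brace lies outside this excerpt). A minimal sketch of the resulting layout, illustrative only and not the literal file contents:

namespace transformer_engine::pytorch {

// Definitions placed inside the namespace no longer need a using-directive;
// unqualified helpers such as makeTransformerEngineTensor resolve here.
at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                   float scale_factor);  // declaration shown for brevity

}  // namespace transformer_engine::pytorch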
...@@ -13,13 +13,12 @@ namespace transformer_engine::pytorch {
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
                                             std::optional<std::vector<py::object>> output_list,
                                             std::vector<py::handle> quantizer_list,
                                             transformer_engine::DType otype) {
                                             std::vector<py::handle> quantizer_list, DType otype) {
  init_extension();
  std::vector<NVTETensor> nvte_tensor_input_list;
  std::vector<NVTETensor> nvte_tensor_output_list;
  std::vector<py::object> py_output_objects_list;
  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
  std::vector<TensorWrapper> tensor_wrappers;
  if (output_list.has_value()) {
    py_output_objects_list = output_list.value();
  }
...@@ -33,7 +32,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
    auto input_tensor = makeTransformerEngineTensor(input_list[i]);
    const NVTEShape input_shape = input_tensor.shape();
    transformer_engine::TensorWrapper output_tensor;
    TensorWrapper output_tensor;
    if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
      with_fused_kernel = false;
...@@ -68,8 +67,10 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
  // Launch TE kernel
  if (with_fused_kernel) {
    nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
                              nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
    NVTE_SCOPED_GIL_RELEASE({
      nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
                                nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
    });
  } else {
    for (size_t i = 0; i < py_output_objects_list.size(); i++) {
      quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
...@@ -78,8 +79,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
  return py_output_objects_list;
}
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
                         std::optional<at::Tensor> output) {
at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
  init_extension();
  const auto dim = input.dim();
...@@ -100,8 +100,8 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
  }
  if (M == 0 || N == 0) return out;
  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
  auto output_cu = makeTransformerEngineTensor(out.data_ptr(), {N, M}, otype);
  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, N}, otype);
  auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{N, M}, otype);
  nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
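A possible rationale for the explicit std::vector<size_t> spelling above (not stated in the diff, and the overloads below are purely hypothetical): naming the shape type keeps overload resolution unambiguous, since a bare braced list can match more than one shape-like parameter. A self-contained illustration:

#include <cstddef>
#include <vector>

struct Shape2D { std::size_t dims[2]; };                       // hypothetical shape struct
void make_tensor(void *, const std::vector<std::size_t> &) {}  // hypothetical overload A
void make_tensor(void *, const Shape2D &) {}                   // hypothetical overload B

void demo(void *ptr, std::size_t M, std::size_t N) {
  // make_tensor(ptr, {M, N});                       // braced list: ambiguous between A and B
  make_tensor(ptr, std::vector<std::size_t>{M, N});  // explicit type: overload A is selected
}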
......
...@@ -8,6 +8,8 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_
#include <Python.h>
#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
...@@ -18,6 +20,16 @@
namespace transformer_engine::pytorch {
#define NVTE_SCOPED_GIL_RELEASE(code_block) \
  do {                                            \
    if (PyGILState_Check()) {                     \
      pybind11::gil_scoped_release _gil_release;  \
      code_block                                  \
    } else {                                      \
      code_block                                  \
    }                                             \
  } while (false);
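A minimal usage sketch of the macro (hypothetical call site, reusing the nvte_transpose call shape from the transpose path above): when the calling thread holds the GIL, the block runs with the GIL released for the duration of the kernel launch; otherwise the block runs unchanged.

// Hypothetical call site: drop the GIL while the blocking kernel launch is
// enqueued so other Python threads can make progress.
NVTE_SCOPED_GIL_RELEASE({
  nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
});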
extern PyTypeObject *Float8TensorPythonClass;
extern PyTypeObject *Float8TensorBasePythonClass;
extern PyTypeObject *Float8QuantizerClass;
......
...@@ -9,7 +9,6 @@
#include "common.h"
#include "pybind.h"
#include "torch/torch.h"
#include "util.h"
namespace transformer_engine::pytorch {
...@@ -103,7 +102,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
  }
  const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
  at::Tensor columnwise_data;
  bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
  bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
  if (create_transpose) {
    columnwise_data = at::empty(columnwise_torch_shape, opts);
  }
...@@ -215,7 +214,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
  }
  const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
  at::Tensor columnwise_data;
  bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
  bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
  if (create_transpose) {
    columnwise_data = at::empty(columnwise_torch_shape, opts);
  }
......