Commit f8c2af4c authored by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
......@@ -52,7 +52,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
......@@ -69,7 +71,10 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
// so in nvte_quantize_v2 with current scaling, the quant config is not used again
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
......@@ -77,8 +82,10 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
return out;
}
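// A minimal sketch of the FP8 current-scaling ordering used above. Names here are illustrative;
// quantize() additionally all-reduces the amax across the model-parallel group when requested and
// clears the output's amax pointer (te_output.set_amax(nullptr, ...)) before the final cast so the
// cast kernel does not update it atomically.
static void current_scaling_cast_sketch(const NVTETensor input, NVTETensor output,
                                        const NVTEQuantizationConfig quant_config,
                                        cudaStream_t stream) {
  nvte_compute_amax(input, output, stream);                    // amax <- max(|input|)
  nvte_compute_scale_from_amax(output, quant_config, stream);  // scale <- f(amax, epsilon, pow-2 option)
  nvte_quantize_v2(input, output, quant_config, stream);       // cast using the precomputed scale
}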
......@@ -96,7 +103,9 @@ py::object dequantize(const py::handle& input, transformer_engine::DType otype)
auto [out_tensor, out] = q.create_tensor(shape, otype);
NVTE_SCOPED_GIL_RELEASE({
nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
});
return out;
}
......@@ -120,15 +129,19 @@ std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tens
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
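// Note: this first call runs with `workspace` still empty, which only fills in the required
// workspace shape and dtype; the actual kernel launch happens in the second call below once
// the workspace has been allocated.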
NVTE_SCOPED_GIL_RELEASE({
func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
return {py::cast(grad_bias), dact};
}
......@@ -141,81 +141,79 @@ CommOverlap::CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType
num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm,
set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {}
void CommOverlap::set_buffer_params(py::handle quantizer) {
std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
my_quantizer->set_quantization_params(&_ubuf);
_ubuf_scale_inv_initialized = true;
}
/*
** Helper function to copy input to _ubuf
*/
void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
const auto &input_ = input.contiguous();
// Check element size
const size_t element_size = input.element_size();
NVTE_CHECK(_ubuf.element_size() == element_size,
"Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
"(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
" bytes)");
// Input data
const size_t input_size = input_.numel();
const void *src_ptr = input_.data_ptr();
// Userbuffers data
const size_t ubuf_size = _ubuf.numel();
void *dst_ptr = _ubuf.dptr();
if (local_chunk) {
if (input_tensor.numel() * _tp_size > _ubuf.numel())
NVTE_ERROR("input is larger than the local communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
NVTE_CHECK(input_size * _tp_size == ubuf_size,
"Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
"(input_size=", input_size, ", tensor_parallel_size=", _tp_size,
", ubuf_size=", ubuf_size, ")");
dst_ptr = (reinterpret_cast<char *>(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size);
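// e.g. ubuf_size = 8192 elements, _tp_size = 4, _tp_id = 2, element_size = 2 bytes
//      -> dst_ptr advances by (8192 / 4) * 2 * 2 = 8192 bytes, the start of rank 2's chunk.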
} else {
if (input_tensor.numel() > _ubuf.numel())
NVTE_ERROR("input is larger than the global communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK(input_size == ubuf_size,
"Tried to copy an invalid tensor into a Userbuffers buffer ",
"(input_size=", input_size, ", ubuf_size=", ubuf_size, ")");
}
// Copy either row or columnwise data into the communication buffer's columnwise data
// NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with
// the Userbuffers communicator.
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
// Copy data
auto stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input_tensor.dptr(),
input_tensor.numel() * input_tensor.element_size(),
NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
}
py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
std::optional<const std::vector<int64_t>> shape) {
using namespace te::pytorch;
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.dptr());
if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
std::vector<int64_t> torch_shape;
if (shape.has_value()) {
torch_shape = shape.value();
size_t requested = product(torch_shape);
auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
") does not match allocated buffer size (", expected, ")!");
at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
// Check buffer shape
const size_t ubuf_size = _ubuf.numel();
if (shape) {
const size_t requested_size = transformer_engine::pytorch::product(*shape);
if (local_chunk) {
NVTE_CHECK(requested_size * _tp_size == ubuf_size,
"Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")");
} else {
NVTE_CHECK(requested_size == ubuf_size,
"Invalid shape for a Userbuffers buffer (requested shape=", *shape,
", ubuf_size=", ubuf_size, ")");
}
} else {
int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
int64_t output_c_dim1 = _ubuf.size(1);
torch_shape = {output_c_dim0, output_c_dim1};
int64_t dim0 = _ubuf.size(0);
int64_t dim1 = _ubuf.size(1);
if (local_chunk) {
dim0 /= _tp_size;
}
shape = {dim0, dim1};
}
// Data pointer
void *ubuf_ptr = _ubuf.dptr();
if (local_chunk) {
ubuf_ptr = (reinterpret_cast<char *>(ubuf_ptr) +
(ubuf_size / _tp_size) * _tp_id * _ubuf.element_size());
}
auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
std::vector<size_t> te_shape;
for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
// Always output a rowwise-only QuantizedTensor
// TODO (Alp): This needs to produce an un-interleaved transpose when required.
auto is_internal = my_quantizer->internal;
auto uses_columnwise = my_quantizer->columnwise_usage;
my_quantizer->internal = false;
my_quantizer->columnwise_usage = false;
auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
my_quantizer->internal = is_internal;
my_quantizer->columnwise_usage = uses_columnwise;
return py_tensor;
// Construct PyTorch tensor
const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}
/***************************************************************************************************
......@@ -236,74 +234,69 @@ CommOverlapP2P::CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::Scal
comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm, aggregate) {}
void CommOverlapP2P::set_buffer_params(py::handle quantizer) {
std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
my_quantizer->set_quantization_params(&_ubuf);
for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]);
}
/*
** Copy input to _ubufs[0]
*/
void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
const auto &input_ = input.contiguous();
// Check element size
const size_t element_size = input.element_size();
NVTE_CHECK(_ubuf.element_size() == element_size,
"Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
"(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
" bytes)");
// Input data
const size_t input_size = input_.numel();
const void *src_ptr = input_.data_ptr();
// Userbuffers data
void *dst_ptr;
if (local_chunk) {
// Copy input to the target ubuf chunk by rank offset
if (input_tensor.numel() * _tp_size > _ubuf.numel())
NVTE_ERROR("input is larger than the local communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
input_tensor.numel() * input_tensor.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
NVTE_CHECK(_ubufs[_tp_id].numel() == input_size,
"Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
"(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
dst_ptr = _ubufs[_tp_id].dptr();
} else {
if (input_tensor.numel() > _ubuf.numel())
NVTE_ERROR("input is larger than the global communication buffer!");
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
input_tensor.numel() * input_tensor.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
NVTE_CHECK(_ubuf.numel() == input_size,
"Tried to copy an invalid tensor into a Userbuffers buffer ",
"(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")");
dst_ptr = _ubuf.dptr();
}
// Copy data
NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
cudaMemcpyDeviceToDevice,
(cudaStream_t)at::cuda::getCurrentCUDAStream()));
}
py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
std::optional<const std::vector<int64_t>> shape) {
using namespace te::pytorch;
char *ubuf_wt_ptr = reinterpret_cast<char *>(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr());
std::vector<int64_t> torch_shape;
if (shape.has_value()) {
torch_shape = shape.value();
size_t requested = product(torch_shape);
auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
") does not match allocated buffer size (", expected, ")!");
at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
// Check buffer shape
if (shape) {
const size_t requested_size = transformer_engine::pytorch::product(*shape);
if (local_chunk) {
NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(),
"Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
} else {
int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
int64_t output_c_dim1 = _ubuf.size(1);
torch_shape = {output_c_dim0, output_c_dim1};
NVTE_CHECK(requested_size == _ubuf.numel(),
"Invalid shape for a Userbuffers buffer (requested shape=", *shape,
", ubuf_size=", _ubuf.numel(), ")");
}
auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
std::vector<size_t> te_shape;
for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
// Always output a rowwise-only QuantizedTensor
// TODO (Alp): This needs to produce an un-interleaved transpose when required.
auto is_internal = my_quantizer->internal;
auto uses_columnwise = my_quantizer->columnwise_usage;
my_quantizer->internal = false;
my_quantizer->columnwise_usage = false;
auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
my_quantizer->internal = is_internal;
my_quantizer->columnwise_usage = uses_columnwise;
return py_tensor;
} else {
int64_t dim0 = _ubuf.size(0);
int64_t dim1 = _ubuf.size(1);
if (local_chunk) {
dim0 /= _tp_size;
}
shape = {dim0, dim1};
}
// Data pointer
void *ubuf_ptr = local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr();
// Construct PyTorch tensor
const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
size_t w, size_t start_offset, size_t block_len) {
TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16,
"tensor must be a float or bfloat16 tensor");
const TensorWrapper tensor_cu = makeTransformerEngineTensor(tensor);
TensorWrapper amax_cu = makeTransformerEngineTensor(amax);
nvte_fp8_block_scaling_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w,
amax.stride(0), amax.stride(1), start_offset,
block_len, at::cuda::getCurrentCUDAStream());
}
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
size_t h, size_t w, size_t start_offset, size_t block_len,
const transformer_engine::DType out_dtype) {
TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
TORCH_CHECK(
inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16,
"input must be a float or bfloat16 tensor");
TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
out_dtype == transformer_engine::DType::kFloat8E5M2,
"out_dtype must be kFloat8E4M3 or kFloat8E5M2");
const TensorWrapper inp_cu = makeTransformerEngineTensor(inp);
TensorWrapper out_cu = makeTransformerEngineTensor(out);
const TensorWrapper scale_cu = makeTransformerEngineTensor(scale);
nvte_fp8_block_scaling_partial_cast(
inp_cu.data(), out_cu.data(), scale_cu.data(), h, w, scale.stride(0), scale.stride(1),
start_offset, block_len, static_cast<NVTEDType>(out_dtype), at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
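// A minimal usage sketch of the two helpers above. The shapes and the scale policy are
// illustrative assumptions; callers normally derive `scale` from `amax` with their own logic.
static void fp8_block_scaling_partial_cast_example() {
  using transformer_engine::DType;
  using namespace transformer_engine::pytorch;
  const size_t h = 256, w = 256, start_offset = 0, block_len = 128;
  auto opts = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
  auto x = torch::randn({static_cast<int64_t>(h), static_cast<int64_t>(w)}, opts);
  // One amax entry per 128x128 block: ceil(256/128) x ceil(256/128) = 2 x 2.
  auto amax = torch::zeros({2, 2}, opts.dtype(torch::kFloat32));
  fp8_block_scaling_compute_partial_amax(x, amax, h, w, start_offset, block_len);
  // Assumed scale policy for illustration: E4M3 max (448) divided by the per-block amax.
  auto scale = torch::full_like(amax, 448.0f).div(amax.clamp_min(1e-12));
  auto out = torch::empty({static_cast<int64_t>(h), static_cast<int64_t>(w)},
                          opts.dtype(torch::kByte));
  fp8_block_scaling_partial_cast(x, out, scale, h, w, start_offset, block_len, DType::kFloat8E4M3);
}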
......@@ -4,7 +4,6 @@
* See LICENSE for license information.
************************************************************************/
#include <Python.h>
#include <pybind11/pybind11.h>
#include <optional>
......@@ -21,12 +20,12 @@
namespace {
void* get_data_ptr(MaybeTensor tensor) {
void* get_data_ptr(transformer_engine::pytorch::MaybeTensor tensor) {
if (tensor.has_value()) return tensor->data_ptr();
return nullptr;
}
size_t get_size(MaybeTensor tensor, int dim) {
size_t get_size(transformer_engine::pytorch::MaybeTensor tensor, int dim) {
if (tensor.has_value()) return static_cast<size_t>(tensor->size(dim));
return 0;
}
......@@ -167,8 +166,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);
// Workspace
auto te_workspace =
makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte);
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
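// e.g. with 132 SMs on the device and an SM margin of 16, the GEMM kernels are limited to
// 116 SMs, leaving the rest free for the overlapped communication kernels (figures are
// illustrative, not values set by this code).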
......@@ -197,38 +196,52 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
// Direct GEMM call to the correct overlap
if (bulk_overlap) {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
te_pre_gelu_out, te_workspace, grad, accumulate,
use_split_accumulator, comm_type.value(), extra_output_tensor,
main_stream);
});
} else if (comm_type.value() == CommOverlapType::AG) {
if (comm_overlap->is_atomic_gemm()) {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator,
extra_output_tensor, main_stream);
});
} else {
comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
te_pre_gelu_out, te_workspace, grad, accumulate,
use_split_accumulator, extra_output_tensor, main_stream);
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator, extra_output_tensor,
main_stream);
});
}
} else {
if (comm_overlap->is_atomic_gemm()) {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator,
extra_output_tensor, main_stream);
});
} else {
comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
te_pre_gelu_out, te_workspace, grad, accumulate,
use_split_accumulator, extra_output_tensor, main_stream);
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator, extra_output_tensor,
main_stream);
});
}
}
} else {
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, num_math_sms, main_stream);
});
}
} else {
if (D_tensor.numel() != 0 && !accumulate) {
......@@ -258,20 +271,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
return out;
}
} // namespace transformer_engine::pytorch
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, transformer_engine::DType B_type,
std::vector<int64_t> B_scaling_mode, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
bool transb, at::Tensor D, at::Tensor D_scale, DType D_type, at::Tensor D_amax,
at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, at::Tensor counter) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
// TODO: Handle scaling modes
NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING;
......@@ -286,12 +293,13 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
nvte_scaling_modeB);
// TODO: D_scale_inv cannot be nullptr when D_type is FP8.
auto te_D = makeTransformerEngineTensor(
D.data_ptr(), {static_cast<size_t>(D.size(0)), static_cast<size_t>(D.size(1))}, D_type,
D.data_ptr(),
std::vector<size_t>{static_cast<size_t>(D.size(0)), static_cast<size_t>(D.size(1))}, D_type,
D_amax.data_ptr(), D_scale.data_ptr(), nullptr);
auto te_bias =
makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))}, bias_type);
auto te_bias = makeTransformerEngineTensor(
bias.data_ptr(), std::vector<size_t>{static_cast<size_t>(bias.size(0))}, bias_type);
auto te_counter = makeTransformerEngineTensor(
counter.data_ptr(), {static_cast<size_t>(counter.size(0))}, DType::kInt32);
counter.data_ptr(), std::vector<size_t>{static_cast<size_t>(counter.size(0))}, DType::kInt32);
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
......@@ -299,24 +307,23 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
static_cast<size_t>(pre_gelu_out.size(1))};
auto te_pre_gelu_out = makeTransformerEngineTensor(
pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type()));
auto te_workspace =
makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte);
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream());
});
}
std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
std::optional<std::vector<at::Tensor>> D, DType D_type, std::vector<int64_t> m_splits,
std::vector<at::Tensor> bias, DType bias_type, bool single_output,
std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> wrappers;
......@@ -419,16 +426,19 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
wrappers.emplace_back(std::move(te_pre_gelu_out));
}
for (size_t i = 0; i < workspace.size(); i++) {
auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte);
auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
te_workspace_vector.emplace_back(wsp.data());
wrappers.emplace_back(std::move(wsp));
}
// For now, we only have multi-stream cublas backend.
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(),
te_A_vector.size(), transa, transb, grad,
te_workspace_vector.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
});
return bias;
}
......@@ -534,3 +544,5 @@ std::vector<at::Tensor> te_batchgemm_ts(
}
#endif
} // namespace transformer_engine::pytorch
......@@ -6,6 +6,8 @@
#include "extensions.h"
namespace transformer_engine::pytorch {
#ifdef USE_ROCM
size_t get_cublasLt_version() { int version = 10000000; return version; }
......@@ -15,3 +17,5 @@ size_t get_cublasLt_version() { return cublasLtGetVersion(); }
size_t get_cudnn_version() { return cudnnGetVersion(); }
#endif
} // namespace transformer_engine::pytorch
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction,
weight_decay, device_id, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_param_remainder_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1,
beta2, epsilon, step, mode, bias_correction, weight_decay, device_id,
at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, DType fp8_dtype) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode,
bias_correction, weight_decay, static_cast<NVTEDType>(fp8_dtype),
device_id, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_master_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float epsilon) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
auto ret = at::empty({1}, output.options());
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
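// Ceiling division: e.g. a tensor with numel = 1000 and chunk_size = 512 needs 2 chunks.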
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
output_per_tensor = at::empty({0}, float_options);
ret_per_tensor = at::empty({0}, float_options);
}
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto output_cu = makeTransformerEngineTensor(output);
auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, output_cu.data(), output_per_tensor_cu.data(),
ret_cu.data(), ret_per_tensor_cu.data(), per_tensor,
max_chunks_per_tensor, device_id, at::cuda::getCurrentCUDAStream());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
// Create output tensors for the multi-tensor L2 norm kernel.
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
output_per_tensor = at::empty({0}, float_options);
ret_per_tensor = at::empty({0}, float_options);
}
auto ret = at::empty({1}, output.options());
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto output_cu = makeTransformerEngineTensor(output);
auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_unscale_l2norm_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(),
inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id,
at::cuda::getCurrentCUDAStream());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
} // namespace transformer_engine::pytorch
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, wd, momentum, dampening, lr, nesterov, first_run,
wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -9,28 +9,11 @@
#include "pybind.h"
namespace transformer_engine::pytorch {
std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
py::handle quantizer) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; i++) {
size_t t = shape.data[i];
shape_vec.push_back(t);
}
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
return my_quantizer->create_tensor(shape_vec, dtype);
}
std::pair<TensorWrapper, py::object> createOutputTensor(std::vector<size_t> &shape, DType dtype,
py::handle quantizer) {
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
return my_quantizer->create_tensor(shape, dtype);
}
} // namespace transformer_engine::pytorch
std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch;
const auto &dz_ = dz.contiguous();
const auto &x_ = x.contiguous();
const auto &mu_ = mu.contiguous();
......@@ -40,7 +23,7 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_);
auto dbeta = at::empty_like(gamma_);
transformer_engine::TensorWrapper workspace;
TensorWrapper workspace;
auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x_);
......@@ -52,10 +35,12 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dbeta_cu = makeTransformerEngineTensor(dbeta);
// This call populates tensors with the required config.
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -63,10 +48,12 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Actual call to bwd kernel.
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)};
}
......@@ -76,8 +63,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
DType out_dtype, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail;
using namespace transformer_engine::pytorch;
using namespace transformer_engine;
// Input and param tensors
auto none = py::none();
......@@ -131,11 +116,13 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
// Query workspace size
transformer_engine::TensorWrapper workspace;
TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Allocate workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -143,10 +130,12 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
......@@ -154,7 +143,10 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr =
......@@ -169,7 +161,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
}
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
......@@ -177,8 +171,10 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
}
return {out, py::cast(mu), py::cast(rsigma)};
......@@ -187,7 +183,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &rsigma, const at::Tensor &gamma,
const int sm_margin, const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch;
const auto &dz_ = dz.contiguous();
const auto &x_ = x.contiguous();
const auto &rsigma_ = rsigma.contiguous();
......@@ -195,7 +190,7 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_);
transformer_engine::TensorWrapper workspace;
TensorWrapper workspace;
auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x_);
......@@ -205,10 +200,12 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dgamma_cu = makeTransformerEngineTensor(dgamma);
// This call populates tensors with the required config.
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -216,21 +213,20 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Actual call to bwd kernel.
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
return {py::cast(dx), py::cast(dgamma)};
}
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
py::object out, py::handle quantizer,
transformer_engine::DType out_dtype, const int sm_margin,
const bool zero_centered_gamma) {
py::object out, py::handle quantizer, DType out_dtype,
const int sm_margin, const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail;
using namespace transformer_engine::pytorch;
using namespace transformer_engine;
// Input and param tensors
auto none = py::none();
......@@ -278,11 +274,13 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
// Query workspace size
transformer_engine::TensorWrapper workspace;
TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Allocate workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -290,10 +288,12 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
......@@ -301,7 +301,10 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr =
......@@ -316,7 +319,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
}
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
......@@ -324,9 +329,13 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
}
return {out, py::none(), py::cast(rsigma)};
}
} // namespace transformer_engine::pytorch
......@@ -17,7 +17,8 @@
#include <torch/cuda.h>
#include <torch/extension.h>
namespace nvshmem_api {
namespace transformer_engine::pytorch {
void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
#ifdef NVTE_ENABLE_NVSHMEM
nvshmemx_init_attr_t attr = {};
......@@ -126,4 +127,5 @@ void nvshmem_finalize() {
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
} // namespace nvshmem_api
} // namespace transformer_engine::pytorch
......@@ -5,13 +5,13 @@
************************************************************************/
#include "extensions.h"
#include "pybind.h"
namespace transformer_engine::pytorch {
void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(),
"Number of input row list and padded row list must match.");
NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
......@@ -21,7 +21,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
// Extract properties from PyTorch tensors
std::vector<void*> input_dptr_list, output_dptr_list;
std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
std::vector<transformer_engine::DType> input_type_list;
std::vector<DType> input_type_list;
void* d_input_ptr = reinterpret_cast<void*>(input.data_ptr());
void* d_output_ptr = reinterpret_cast<void*>(output.data_ptr());
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
......@@ -51,9 +51,9 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
// Construct TE tensors
std::vector<NVTETensor> nvte_input_list, nvte_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
std::vector<TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype) -> NVTETensor {
DType dtype) -> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
return tensor_wrappers.back().data();
};
......@@ -75,6 +75,10 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
"Number of input and padded row list must match");
// Launch TE kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(),
padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
});
}
} // namespace transformer_engine::pytorch
......@@ -4,14 +4,13 @@
* See LICENSE for license information.
************************************************************************/
#include <cub/cub.cuh>
#include "extensions.h"
namespace transformer_engine::pytorch {
std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
using namespace transformer_engine::pytorch;
at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens,
std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
const int num_tokens = input.size(0);
int num_cols = input.size(1);
const int topK = indices.size(1);
......@@ -28,9 +27,8 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
size_t temp_storage_bytes = 0;
int *temp_ptr = nullptr;
cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr,
temp_ptr, max_expanded_token_num);
nvte_device_radix_sort_pairs(nullptr, &temp_storage_bytes, nullptr, nullptr, nullptr, nullptr,
max_expanded_token_num);
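// A null temp-storage pointer only writes the required size into temp_storage_bytes (mirroring
// the cub::DeviceRadixSort convention); the sort itself runs later with the allocated workspace[3].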
at::Tensor temp_storage = torch::empty(
temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
......@@ -40,17 +38,18 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
workspace.push_back(temp_storage);
}
int *indices_ptr = reinterpret_cast<int *>(getDataPtr(indices, 0));
int *sorted_indices_ptr = reinterpret_cast<int *>(getDataPtr(workspace[0], 0));
int *row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[1], 0));
int *sorted_row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[2], 0));
void *indices_ptr = getDataPtr(indices, 0);
void *sorted_indices_ptr = getDataPtr(workspace[0], 0);
void *row_id_ptr = getDataPtr(workspace[1], 0);
void *sorted_row_id_ptr = getDataPtr(workspace[2], 0);
void *d_temp_storage = getDataPtr(workspace[3], 0);
size_t temp_storage_bytes = std::numeric_limits<size_t>::max();
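// Presumably SIZE_MAX so the size check inside the sort treats the pre-allocated workspace[3]
// as large enough; the real requirement was queried and allocated on first use above.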
cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr,
sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr,
num_tokens * topK);
nvte_device_radix_sort_pairs(
d_temp_storage, &temp_storage_bytes, reinterpret_cast<int *>(indices_ptr),
reinterpret_cast<int *>(sorted_indices_ptr), reinterpret_cast<int *>(row_id_ptr),
reinterpret_cast<int *>(sorted_row_id_ptr), num_tokens * topK);
// Output buffer alloc
num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
......@@ -63,34 +62,33 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto input_cu = makeTransformerEngineTensor(
input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)}, dtype);
auto permuted_output_cu = makeTransformerEngineTensor(
permuted_output.data_ptr(),
{static_cast<size_t>(permuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype);
auto sorted_row_id_cu =
makeTransformerEngineTensor(sorted_row_id_ptr, {static_cast<size_t>(num_tokens * topK)},
transformer_engine::DType::kInt32);
input.data_ptr(),
std::vector<size_t>{static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)},
dtype);
auto permuted_output_cu =
makeTransformerEngineTensor(permuted_output.data_ptr(),
std::vector<size_t>{static_cast<size_t>(permuted_output.size(0)),
static_cast<size_t>(num_cols)},
dtype);
auto sorted_row_id_cu = makeTransformerEngineTensor(
sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_tokens * topK)},
DType::kInt32);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(),
row_id_map_cu.data(), transformer_engine::TensorWrapper().data(),
transformer_engine::TensorWrapper().data(),
transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols,
num_out_tokens, stream);
row_id_map_cu.data(), TensorWrapper().data(), TensorWrapper().data(),
TensorWrapper().data(), num_tokens, topK, num_cols, num_out_tokens, stream);
return std::make_tuple(permuted_output, row_id_map, workspace);
}
at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK) {
at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
at::Tensor prob, int64_t num_tokens, int64_t topK) {
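// The gradient of a permutation is the inverse permutation, so the backward
// pass simply reuses the unpermute forward kernel.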
return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK);
}
at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK) {
using namespace transformer_engine::pytorch;
at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
at::Tensor prob, int64_t num_tokens, int64_t topK) {
int num_cols = input.size(1);
// Output buffer alloc
......@@ -101,10 +99,14 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto input_cu = makeTransformerEngineTensor(
input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)}, dtype);
input.data_ptr(),
std::vector<size_t>{static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)},
dtype);
auto unpermuted_output_cu = makeTransformerEngineTensor(
unpermuted_output.data_ptr(),
{static_cast<size_t>(unpermuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype);
std::vector<size_t>{static_cast<size_t>(unpermuted_output.size(0)),
static_cast<size_t>(num_cols)},
dtype);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
auto prob_cu = makeTransformerEngineTensor(prob);
......@@ -115,9 +117,8 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
}
std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob) {
using namespace transformer_engine::pytorch;
const DType dtype, at::Tensor row_id_map,
at::Tensor prob) {
const int topK = (prob.numel() > 0) ? prob.size(1) : 1;
const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
int num_cols = input_bwd.size(1);
......@@ -132,21 +133,26 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto input_bwd_cu = makeTransformerEngineTensor(
input_bwd.data_ptr(), {static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)},
input_bwd.data_ptr(),
std::vector<size_t>{static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)},
dtype);
auto act_grad_cu = makeTransformerEngineTensor(
act_grad.data_ptr(), {static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)},
act_grad.data_ptr(),
std::vector<size_t>{static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)},
dtype);
auto input_fwd_cu = makeTransformerEngineTensor(
input_fwd.data_ptr(), {static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)},
input_fwd.data_ptr(),
std::vector<size_t>{static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)},
dtype);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
auto prob_cu = makeTransformerEngineTensor(prob);
auto prob_grad_cu = makeTransformerEngineTensor(prob_grad);
nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(),
nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), TensorWrapper().data(),
row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(),
num_tokens, topK, num_cols, 0, stream);
return std::make_tuple(act_grad, prob_grad);
}
} // namespace transformer_engine::pytorch
......@@ -6,7 +6,6 @@
#include "pybind.h"
#include <Python.h>
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
......@@ -111,10 +110,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.",
py::call_guard<py::gil_scoped_release>());
m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.",
py::call_guard<py::gil_scoped_release>());
m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
......@@ -160,85 +155,111 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("quantizer"));
// Permutation functions
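// Note: py::call_guard<py::gil_scoped_release> releases the Python GIL for the
// duration of each bound call, so these kernel-launching functions do not block
// other Python threads while they run.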
m.def("moe_permute_fwd", moe_permute_fwd);
m.def("moe_permute_bwd", moe_permute_bwd);
m.def("moe_unpermute_fwd", moe_unpermute_fwd);
m.def("moe_unpermute_bwd", moe_unpermute_bwd);
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD",
m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD",
m.def("moe_permute_bwd", transformer_engine::pytorch::moe_permute_bwd, "MOE permute BWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
"Scaled Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
"Scaled Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_forward", &scaled_upper_triang_masked_softmax_forward,
m.def("moe_unpermute_fwd", transformer_engine::pytorch::moe_unpermute_fwd, "MOE unpermute FWD",
py::call_guard<py::gil_scoped_release>());
m.def("moe_unpermute_bwd", transformer_engine::pytorch::moe_unpermute_bwd, "MOE unpermute BWD",
py::call_guard<py::gil_scoped_release>());
// Softmax functions
m.def("scaled_softmax_forward", &transformer_engine::pytorch::scaled_softmax_forward,
"Scaled Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_softmax_backward", &transformer_engine::pytorch::scaled_softmax_backward,
"Scaled Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_forward",
&transformer_engine::pytorch::scaled_masked_softmax_forward, "Scaled Masked Softmax FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_backward",
&transformer_engine::pytorch::scaled_masked_softmax_backward, "Scaled Masked Softmax BWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_forward",
&transformer_engine::pytorch::scaled_upper_triang_masked_softmax_forward,
"Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_backward", &scaled_upper_triang_masked_softmax_backward,
m.def("scaled_upper_triang_masked_softmax_backward",
&transformer_engine::pytorch::scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_aligned_causal_masked_softmax_forward",
&scaled_aligned_causal_masked_softmax_forward,
&transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_forward,
"Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_aligned_causal_masked_softmax_backward",
&scaled_aligned_causal_masked_softmax_backward,
&transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_backward,
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
py::call_guard<py::gil_scoped_release>());
// Other granular functions
m.def("layernorm_fwd", &layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"),
py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"),
py::arg("sm_margin"), py::arg("zero_centered_gamma"));
m.def("layernorm_bwd", &layernorm_bwd, "Backward of LayerNorm");
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"),
py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"),
py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
m.def("layernorm_bwd", &transformer_engine::pytorch::layernorm_bwd, "Backward of LayerNorm");
m.def("rmsnorm_fwd", &transformer_engine::pytorch::rmsnorm_fwd, "RMSNorm", py::arg("input"),
py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
"Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
py::arg("quantizer_list"), py::arg("otype"));
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
"Grouped GEMM");
#ifdef USE_ROCM
m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM"); /// rocblas
m.def("te_batchgemm_ts", &transformer_engine::pytorch::te_batchgemm_ts, "Batched GEMM"); /// rocblas
#endif
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend,
"Get Fused Attention backend", py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
"Compute absolute max value in tensor", py::arg("input"), py::arg("amax"),
py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
m.def("fused_amax_and_scale_update_after_reduction",
&transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction",
py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
py::call_guard<py::gil_scoped_release>());
m.def("fp8_block_scaling_compute_partial_amax",
&transformer_engine::pytorch::fp8_block_scaling_compute_partial_amax,
"Compute partial amax from master weights for fp8 block scaling", py::arg("tensor"),
py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
py::call_guard<py::gil_scoped_release>());
m.def("fp8_block_scaling_partial_cast",
&transformer_engine::pytorch::fp8_block_scaling_partial_cast,
"Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
"Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
// attention kernels
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
"Prepare QKV for Flash Attention", py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd,
"Backward of QKV preparation for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_fwd", &fused_attn_fwd,
m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache");
m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD");
m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tesnor from BSHD to THD");
m.def("copy_to_kv_cache", &transformer_engine::pytorch::copy_to_kv_cache,
"Copy new KV tokens to KV cache", py::call_guard<py::gil_scoped_release>());
m.def("convert_thd_to_bshd", &transformer_engine::pytorch::convert_thd_to_bshd,
"Convert a tensor from THD to BSHD", py::call_guard<py::gil_scoped_release>());
m.def("convert_bshd_to_thd", &transformer_engine::pytorch::convert_bshd_to_thd,
"Convert a tesnor from BSHD to THD", py::call_guard<py::gil_scoped_release>());
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_forward", &transformer_engine::pytorch::fused_rope_forward,
"Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
"Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
py::call_guard<py::gil_scoped_release>());
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
"Get cublasLt version", py::call_guard<py::gil_scoped_release>());
m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
py::call_guard<py::gil_scoped_release>());
m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
#ifdef USE_ROCM
......@@ -246,74 +267,82 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
// Support THD format for Context Parallel
m.def("thd_read_half_tensor", &thd_read_half_tensor,
m.def("thd_read_half_tensor", &transformer_engine::pytorch::thd_read_half_tensor,
"Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
"tensor",
py::call_guard<py::gil_scoped_release>());
m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
m.def("thd_second_half_lse_correction",
&transformer_engine::pytorch::thd_second_half_lse_correction,
"Correct the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
m.def("thd_read_second_half_lse", &transformer_engine::pytorch::thd_read_second_half_lse,
"Read the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
m.def("thd_out_correction", &thd_out_correction,
m.def("thd_out_correction", &transformer_engine::pytorch::thd_out_correction,
"Correct the THD format output of context parallelism in forward pass",
py::call_guard<py::gil_scoped_release>());
m.def("thd_grad_correction", &thd_grad_correction,
m.def("thd_grad_correction", &transformer_engine::pytorch::thd_grad_correction,
"Correct the THD format gradients of context parallelism in backward pass",
py::call_guard<py::gil_scoped_release>());
m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
m.def("thd_get_partitioned_indices", &transformer_engine::pytorch::thd_get_partitioned_indices,
"Generate partitioned indices for inputs in THD format",
py::call_guard<py::gil_scoped_release>());
// nvshmem functions
m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_nvshmem_backend,
"Initialize nvshmem backend with Pytorch distributed process groups",
py::call_guard<py::gil_scoped_release>());
m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_nvshmem_tensor,
"Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
m.def("nvshmem_send_on_current_stream",
&transformer_engine::pytorch::nvshmem_send_on_current_stream,
"Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
m.def("nvshmem_wait_on_current_stream",
&transformer_engine::pytorch::nvshmem_wait_on_current_stream,
"Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
"stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize,
"Clean up and finalize the NVSHMEM communication backend and free associated resources",
py::call_guard<py::gil_scoped_release>());
// multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
m.def("multi_tensor_unscale_l2norm",
&transformer_engine::pytorch::multi_tensor_unscale_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only "
"performed for L2 norm computation, and tensors are not updated)",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
m.def("multi_tensor_adam", &transformer_engine::pytorch::multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda,
m.def("multi_tensor_adam_param_remainder",
&transformer_engine::pytorch::multi_tensor_adam_param_remainder_cuda,
"Compute and apply gradient update to parameters for Adam optimizer"
"where the master parameters only store the remainder bits",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
m.def("multi_tensor_adam_fp8", &transformer_engine::pytorch::multi_tensor_adam_fp8_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
m.def("multi_tensor_adam_capturable",
&transformer_engine::pytorch::multi_tensor_adam_capturable_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support and LR scheduling",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
m.def("multi_tensor_adam_capturable_master",
&transformer_engine::pytorch::multi_tensor_adam_capturable_master_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support, LR scheduling and FP32 master weights",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
m.def("multi_tensor_sgd", &transformer_engine::pytorch::multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda,
m.def("multi_tensor_compute_scale_and_scale_inv",
&transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
"Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
// Data structures
......@@ -359,10 +388,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
.def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
py::arg("quantizer"), py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"),
py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
.def("set_buffer_params", &CommOverlap::set_buffer_params);
py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
py::arg("shape") = std::nullopt);
py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
......@@ -377,8 +405,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
py::arg("use_ce") = true, py::arg("aggregate") = false)
.def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
py::arg("quantizer"), py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"),
py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
.def("set_buffer_params", &CommOverlapP2P::set_buffer_params);
py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
py::arg("shape") = std::nullopt);
}
......@@ -12,10 +12,9 @@
#include "common/common.h"
#include "extensions.h"
void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
namespace transformer_engine::pytorch {
void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
auto input_tensor = tensor.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
......@@ -23,7 +22,7 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
TensorWrapper fake_te_output(
nullptr, te_input.shape(),
transformer_engine::DType::kFloat8E4M3, // It doesn't matter because we only compute amax.
DType::kFloat8E4M3, // It doesn't matter because we only compute amax.
amax.data_ptr<float>());
nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
......@@ -33,10 +32,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
const std::string& amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
DType fp8_dtype, float margin) {
size_t num_tensors = amax_histories.size();
std::vector<Tensor> t_amax_histories(num_tensors);
std::vector<Tensor> t_scales(num_tensors);
......@@ -63,3 +59,5 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -6,8 +6,9 @@
#include "extensions.h"
namespace transformer_engine::pytorch {
at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
......@@ -38,8 +39,6 @@ at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
float scale_factor) {
using namespace transformer_engine::pytorch;
auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous();
......@@ -65,8 +64,6 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_r
}
at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
......@@ -105,8 +102,6 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, floa
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
float scale_factor) {
using namespace transformer_engine::pytorch;
auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous();
......@@ -132,8 +127,6 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor so
}
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
......@@ -159,8 +152,6 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float sc
at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_,
float scale_factor) {
using namespace transformer_engine::pytorch;
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
......@@ -188,7 +179,6 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
}
at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
......@@ -220,8 +210,6 @@ at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float
at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_,
at::Tensor softmax_results_,
float scale_factor) {
using namespace transformer_engine::pytorch;
auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous();
......@@ -245,3 +233,5 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_
return output_grads;
}
} // namespace transformer_engine::pytorch
......@@ -13,13 +13,12 @@ namespace transformer_engine::pytorch {
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list,
transformer_engine::DType otype) {
std::vector<py::handle> quantizer_list, DType otype) {
init_extension();
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
std::vector<py::object> py_output_objects_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
std::vector<TensorWrapper> tensor_wrappers;
if (output_list.has_value()) {
py_output_objects_list = output_list.value();
}
......@@ -33,7 +32,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
auto input_tensor = makeTransformerEngineTensor(input_list[i]);
const NVTEShape input_shape = input_tensor.shape();
transformer_engine::TensorWrapper output_tensor;
TensorWrapper output_tensor;
if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
with_fused_kernel = false;
......@@ -68,8 +67,10 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
// Launch TE kernel
if (with_fused_kernel) {
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
});
} else {
for (size_t i = 0; i < py_output_objects_list.size(); i++) {
quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
......@@ -78,8 +79,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
return py_output_objects_list;
}
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
std::optional<at::Tensor> output) {
at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
init_extension();
const auto dim = input.dim();
......@@ -100,8 +100,8 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
}
if (M == 0 || N == 0) return out;
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
auto output_cu = makeTransformerEngineTensor(out.data_ptr(), {N, M}, otype);
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, N}, otype);
auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{N, M}, otype);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
......
......@@ -8,6 +8,8 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_
#include <Python.h>
#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
......@@ -18,6 +20,16 @@
namespace transformer_engine::pytorch {
#define NVTE_SCOPED_GIL_RELEASE(code_block) \
do { \
if (PyGILState_Check()) { \
pybind11::gil_scoped_release _gil_release; \
code_block \
} else { \
code_block \
} \
} while (false);
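// Minimal usage sketch (illustrative only; `te_input`/`te_output` stand in for
// TensorWrapper objects built elsewhere): the GIL, if currently held, is released
// around the enclosed kernel launch and re-acquired afterwards.
//
//   NVTE_SCOPED_GIL_RELEASE({
//     nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
//   });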
extern PyTypeObject *Float8TensorPythonClass;
extern PyTypeObject *Float8TensorBasePythonClass;
extern PyTypeObject *Float8QuantizerClass;
......
......@@ -9,7 +9,6 @@
#include "common.h"
#include "pybind.h"
#include "torch/torch.h"
#include "util.h"
namespace transformer_engine::pytorch {
......@@ -103,7 +102,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
}
const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts);
}
......@@ -215,7 +214,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
}
const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts);
}
......