Unverified commit 7e43feae authored by Kirthi Shankar Sivamani, committed by GitHub

Release GIL for PyTorch extensions (#1767)



* Disallow kwargs for pybind extensions and release GIL if possible
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Wrap nvte_* calls
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0e45e138
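Editor's note: the pattern throughout this diff is to release the CPython GIL around nvte_* calls and other C++/CUDA work that never touches Python objects, so that other Python threads (data loaders, gradient hooks) can run while kernels are being enqueued. Two pybind11 mechanisms are used: a scoped release inside function bodies (via the NVTE_SCOPED_GIL_RELEASE macro defined at the bottom of this diff) and py::call_guard<py::gil_scoped_release>() applied to whole bindings. A minimal self-contained sketch of both, assuming pybind11 and a hypothetical heavy_work function (not code from this commit):

#include <pybind11/pybind11.h>
namespace py = pybind11;

// Hypothetical: pure C++/CUDA work that touches no Python objects.
void heavy_work() { /* enqueue kernels, memsets, ... */ }

void mixed_fn() {
  // ... code that may use Python objects (GIL held) ...
  {
    py::gil_scoped_release release;  // GIL dropped here
    heavy_work();                    // other Python threads may run
  }                                  // GIL re-acquired on scope exit
}

PYBIND11_MODULE(sketch, m) {
  m.def("mixed_fn", &mixed_fn);
  // Release the GIL for the entire call, declared at binding time:
  m.def("heavy_work", &heavy_work, py::call_guard<py::gil_scoped_release>());
}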
......@@ -35,9 +35,13 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
auto my_quantizer_none = std::make_unique<NoneQuantizer>(py::none());
auto [te_output_act, out_act] =
my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
NVTE_SCOPED_GIL_RELEASE({
act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
// use te_output_act as input to compute the amax of the activated tensor
nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
if (my_quantizer_cs->with_amax_reduction) {
......@@ -47,17 +51,22 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
// sanity check: activation fusion is not supported for blockwise quantization yet,
// so raise an error here instead of silently going into act_func with wrong numerics
NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
} else {
act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE(
{ act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); });
}
return out;
......@@ -80,7 +89,9 @@ py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input,
auto [te_output, out] =
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
NVTE_SCOPED_GIL_RELEASE({
act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});
return out;
}
......
......@@ -5,6 +5,7 @@
************************************************************************/
#include "extensions.h"
#include "pybind.h"
constexpr int block_size = 512;
......@@ -41,13 +42,16 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
size_t num_rows_to_zero = max_tokens - start_row;
size_t total_bytes = num_rows_to_zero * fcd_size * element_size;
nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE(
{ nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); });
}
void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) {
NVTE_SCOPED_GIL_RELEASE({
nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
at::cuda::getCurrentCUDAStream());
});
}
// extract PhiloxCudaState from CUDA random number generator
......@@ -177,13 +181,15 @@ std::vector<py::object> fused_attn_fwd(
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_fwd(
te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
workspace.data(), at::cuda::getCurrentCUDAStream());
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
});
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -231,13 +237,15 @@ std::vector<py::object> fused_attn_fwd(
}
// execute the kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_fwd(
te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
workspace.data(), at::cuda::getCurrentCUDAStream());
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
});
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -456,13 +464,15 @@ std::vector<py::object> fused_attn_bwd(
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
window_size[0], window_size[1], deterministic, workspace.data(),
at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_bwd(
te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic,
workspace.data(), at::cuda::getCurrentCUDAStream());
});
// allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -470,13 +480,15 @@ std::vector<py::object> fused_attn_bwd(
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
window_size[0], window_size[1], deterministic, workspace.data(),
at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_bwd(
te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic,
workspace.data(), at::cuda::getCurrentCUDAStream());
});
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......
......@@ -32,8 +32,10 @@ std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_qu
auto dbias_tensor = makeTransformerEngineTensor(dbias);
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
void* workspace_data_ptr = nullptr;
if (workspace.shape().ndim > 0) {
......@@ -46,7 +48,9 @@ std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_qu
if (detail::IsFloat8CurrentScalingQuantizers(py_quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
......@@ -61,12 +65,17 @@ std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_qu
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_tensor.data(), quant_config, at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(out_tensor.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_tensor.set_amax(nullptr, DType::kFloat32, out_tensor.defaultShape);
}
nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
return {py::cast(dbias), out};
}
......
......@@ -52,7 +52,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
......@@ -69,7 +71,10 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
// so in nvte_quantize_v2 with current scaling, the quant config is not used again
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
......@@ -77,8 +82,10 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
return out;
}
......@@ -96,7 +103,9 @@ py::object dequantize(const py::handle& input, transformer_engine::DType otype)
auto [out_tensor, out] = q.create_tensor(shape, otype);
NVTE_SCOPED_GIL_RELEASE({
nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
});
return out;
}
......@@ -120,15 +129,19 @@ std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tens
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
return {py::cast(grad_bias), dact};
}
......
......@@ -4,7 +4,6 @@
* See LICENSE for license information.
************************************************************************/
#include <Python.h>
#include <pybind11/pybind11.h>
#include <optional>
......@@ -197,38 +196,52 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
// Direct GEMM call to the correct overlap
if (bulk_overlap) {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
te_pre_gelu_out, te_workspace, grad, accumulate,
use_split_accumulator, comm_type.value(), extra_output_tensor,
main_stream);
});
} else if (comm_type.value() == CommOverlapType::AG) {
if (comm_overlap->is_atomic_gemm()) {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator,
extra_output_tensor, main_stream);
});
} else {
comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
te_pre_gelu_out, te_workspace, grad, accumulate,
use_split_accumulator, extra_output_tensor, main_stream);
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator, extra_output_tensor,
main_stream);
});
}
} else {
if (comm_overlap->is_atomic_gemm()) {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator,
extra_output_tensor, main_stream);
});
} else {
comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
te_pre_gelu_out, te_workspace, grad, accumulate,
use_split_accumulator, extra_output_tensor, main_stream);
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator, extra_output_tensor,
main_stream);
});
}
}
} else {
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, num_math_sms, main_stream);
});
}
} else {
if (D_tensor.numel() != 0 && !accumulate) {
......@@ -303,10 +316,12 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream());
});
}
std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
......@@ -426,10 +441,12 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
wrappers.emplace_back(std::move(wsp));
}
// For now, we only have multi-stream cublas backend.
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(),
te_A_vector.size(), transa, transb, grad,
te_workspace_vector.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
});
return bias;
}
......@@ -52,10 +52,12 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dbeta_cu = makeTransformerEngineTensor(dbeta);
// This call populates tensors with the required config.
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -63,10 +65,12 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Actual call to bwd kernel.
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)};
}
......@@ -132,10 +136,12 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Query workspace size
transformer_engine::TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Allocate workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -143,10 +149,12 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
......@@ -154,7 +162,10 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr =
......@@ -169,7 +180,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
}
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
......@@ -177,8 +190,10 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
}
return {out, py::cast(mu), py::cast(rsigma)};
......@@ -205,10 +220,12 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dgamma_cu = makeTransformerEngineTensor(dgamma);
// This call populates tensors with the required config.
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -216,10 +233,12 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Actual call to bwd kernel.
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
return {py::cast(dx), py::cast(dgamma)};
}
......@@ -279,10 +298,12 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Query workspace size
transformer_engine::TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Allocate workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -290,10 +311,12 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
......@@ -301,7 +324,10 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr =
......@@ -316,7 +342,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
}
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
......@@ -324,8 +352,10 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
}
return {out, py::none(), py::cast(rsigma)};
......
......@@ -5,6 +5,7 @@
************************************************************************/
#include "extensions.h"
#include "pybind.h"
void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
......@@ -75,6 +76,8 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
"Number of input and padded row list must match");
// Launch TE kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(),
padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
});
}
......@@ -6,7 +6,6 @@
#include "pybind.h"
#include <Python.h>
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
......@@ -160,10 +159,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("quantizer"));
// Permutation functions
m.def("moe_permute_fwd", moe_permute_fwd);
m.def("moe_permute_bwd", moe_permute_bwd);
m.def("moe_unpermute_fwd", moe_unpermute_fwd);
m.def("moe_unpermute_bwd", moe_unpermute_bwd);
m.def("moe_permute_fwd", moe_permute_fwd, "MOE permute FWD",
py::call_guard<py::gil_scoped_release>());
m.def("moe_permute_bwd", moe_permute_bwd, "MOE permute BWD",
py::call_guard<py::gil_scoped_release>());
m.def("moe_unpermute_fwd", moe_unpermute_fwd, "MOE unpermute FWD",
py::call_guard<py::gil_scoped_release>());
m.def("moe_unpermute_bwd", moe_unpermute_bwd, "MOE unpermute BWD",
py::call_guard<py::gil_scoped_release>());
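For reference, py::call_guard<py::gil_scoped_release>() constructs the guard around the invocation of the bound function, after its arguments have been converted from Python, so it is roughly equivalent to releasing the GIL by hand at the top of the body (a sketch of pybind11 semantics, not code from this commit):

// Hand-written equivalent of one of the guarded bindings above:
// m.def("moe_permute_fwd", [](/* converted C++ args */) {
//   py::gil_scoped_release release;  // GIL dropped for the C++ body only
//   return moe_permute_fwd(/* args */);
// });
// The body must therefore not call back into the Python C API.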
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD",
......@@ -206,17 +209,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
m.def("compute_amax", &compute_amax, "Compute absolute max value in tensor", py::arg("input"),
py::arg("amax"), py::call_guard<py::gil_scoped_release>());
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction",
py::call_guard<py::gil_scoped_release>());
m.def("fp8_block_scaling_compute_partial_amax", &fp8_block_scaling_compute_partial_amax,
"Compute partial amax from master weights for fp8 block scaling", py::arg("tensor"),
py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"));
py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
py::call_guard<py::gil_scoped_release>());
m.def("fp8_block_scaling_partial_cast", &fp8_block_scaling_partial_cast,
"Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
py::arg("out_dtype"));
py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
py::call_guard<py::gil_scoped_release>());
......@@ -229,9 +234,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache");
m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD");
m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tesnor from BSHD to THD");
m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache",
py::call_guard<py::gil_scoped_release>());
m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD",
py::call_guard<py::gil_scoped_release>());
m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tesnor from BSHD to THD",
py::call_guard<py::gil_scoped_release>());
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
......
......@@ -68,8 +68,10 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
// Launch TE kernel
if (with_fused_kernel) {
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
});
} else {
for (size_t i = 0; i < py_output_objects_list.size(); i++) {
quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
......
......@@ -8,6 +8,8 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_
#include <Python.h>
#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
......@@ -18,6 +20,16 @@
namespace transformer_engine::pytorch {
#define NVTE_SCOPED_GIL_RELEASE(code_block) \
do { \
if (PyGILState_Check()) { \
pybind11::gil_scoped_release _gil_release; \
code_block \
} else { \
code_block \
} \
} while (false);
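Usage note: because of the PyGILState_Check() branch, the macro is also safe on threads that do not currently hold the GIL — the block then simply runs without a release guard. A representative call site from the hunks above:

// From mha_fill above — release the GIL while the memset is enqueued:
NVTE_SCOPED_GIL_RELEASE(
    { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); });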
extern PyTypeObject *Float8TensorPythonClass;
extern PyTypeObject *Float8TensorBasePythonClass;
extern PyTypeObject *Float8QuantizerClass;
......