Unverified Commit 6ee92c4b authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Release GIL in PyTorch extensions (#938)



Release GIL in PyTorch pybind11 functions
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 70d3251f
......@@ -11,151 +11,192 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD");
m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD");
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
"Scaled Masked Softmax FWD");
"Scaled Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
"Scaled Masked Softmax BWD");
"Scaled Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_forward", &scaled_upper_triang_masked_softmax_forward,
"Scaled Upper-Triangular Masked Softmax FWD");
"Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_backward", &scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD");
"Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_aligned_causal_masked_softmax_forward",
&scaled_aligned_causal_masked_softmax_forward,
"Scaled Bottom-Right Corner Aligned Masked Softmax FWD");
"Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_aligned_causal_masked_softmax_backward",
&scaled_aligned_causal_masked_softmax_backward,
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD");
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
py::call_guard<py::gil_scoped_release>());
// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8", py::arg("input"), py::arg("weight"),
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8",
py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8", py::arg("input"),
py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"),
py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"),
m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8",
py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"),
py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD");
m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8", py::arg("input"), py::arg("weight"),
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD", py::call_guard<py::gil_scoped_release>());
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD", py::call_guard<py::gil_scoped_release>());
m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD",
py::call_guard<py::gil_scoped_release>());
m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8",
py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"),
py::arg("sm_margin"), py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8", py::arg("input"),
py::arg("weight"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"),
py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD");
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD");
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8",
py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), py::arg("scale_inv"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD", py::call_guard<py::gil_scoped_release>());
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD", py::call_guard<py::gil_scoped_release>());
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD",
py::call_guard<py::gil_scoped_release>());
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose",
py::call_guard<py::gil_scoped_release>());
m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop,
"Fused Cast + Transpose with noop option", py::arg("input"), py::arg("noop"),
py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("input_cast"),
py::arg("input_transpose"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
"Cast + Transpose with noop option", py::call_guard<py::gil_scoped_release>(),
py::arg("input"), py::arg("noop"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
py::arg("input_cast"), py::arg("input_transpose"), py::arg("otype"),
py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD",
py::arg("grad_output"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
py::arg("otype"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, "Fused FP8 Transpose + BGRAD",
py::arg("grad_output"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
py::arg("otype"), py::arg("grad_bias_type"), py::arg("scale_offset") = 0,
py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("scale"),
py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu,
"Fused Cast + Transpose + BGRAD + DGELU", py::arg("grad_output"), py::arg("gelu_input"),
py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"),
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, "Fused FP8 Transpose + BGRAD",
py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("scale"),
py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("grad_bias_type"),
py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu,
"Fused Cast + Transpose + BGRAD + DGELU", py::call_guard<py::gil_scoped_release>(),
py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose");
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8");
m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8");
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8");
m.def("te_gemm", &te_gemm, "CublasLt GEMM");
"Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard<py::gil_scoped_release>());
m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8",
py::call_guard<py::gil_scoped_release>());
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard<py::gil_scoped_release>());
m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think
m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed QKV");
"Fused Attention FP8/BF16/FP16 FWD with packed QKV",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed QKV");
"Fused Attention FP8/BF16/FP16 BWD with packed QKV",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed KV");
"Fused Attention FP8/BF16/FP16 FWD with packed KV",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed KV");
"Fused Attention FP8/BF16/FP16 BWD with packed KV",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_fwd", &fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O");
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V",
py::call_guard<py::gil_scoped_release>());
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O",
py::call_guard<py::gil_scoped_release>());
m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O",
py::call_guard<py::gil_scoped_release>());
m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop,
"Transpose with FP8 I/O with noop option.");
m.def("gelu", &gelu, "GeLU with FP8 output");
m.def("relu", &relu, "ReLU with FP8 output");
m.def("geglu", &geglu, "GeGLU with FP8 output");
m.def("reglu", &reglu, "ReGLU with FP8 output");
m.def("swiglu", &swiglu, "SwiGLU with FP8 output");
m.def("qgelu", &qgelu, "QuickGELU with FP8 output");
m.def("srelu", &srelu, "Squared ReLU with FP8 output");
m.def("dgelu", &dgelu, "Backward of GeLU");
m.def("drelu", &drelu, "Backward of ReLU");
m.def("dgeglu", &dgeglu, "Backward of GeGLU");
m.def("dreglu", &dreglu, "Backward of ReGLU");
m.def("dswiglu", &dswiglu, "Backward of SwiGLU");
m.def("dqgelu", &dqgelu, "Backward of QuickGELU");
m.def("dsrelu", &dsrelu, "Backward of Squared ReLU");
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
"Transpose with FP8 I/O with noop option.", py::call_guard<py::gil_scoped_release>());
m.def("gelu", &gelu, "GeLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("relu", &relu, "ReLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("geglu", &geglu, "GeGLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("reglu", &reglu, "ReGLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("swiglu", &swiglu, "SwiGLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("qgelu", &qgelu, "QuickGELU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("srelu", &srelu, "Squared ReLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("dgelu", &dgelu, "Backward of GeLU", py::call_guard<py::gil_scoped_release>());
m.def("drelu", &drelu, "Backward of ReLU", py::call_guard<py::gil_scoped_release>());
m.def("dgeglu", &dgeglu, "Backward of GeGLU", py::call_guard<py::gil_scoped_release>());
m.def("dreglu", &dreglu, "Backward of ReGLU", py::call_guard<py::gil_scoped_release>());
m.def("dswiglu", &dswiglu, "Backward of SwiGLU", py::call_guard<py::gil_scoped_release>());
m.def("dqgelu", &dqgelu, "Backward of QuickGELU", py::call_guard<py::gil_scoped_release>());
m.def("dsrelu", &dsrelu, "Backward of Squared ReLU", py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
py::call_guard<py::gil_scoped_release>());
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction");
"Update amax history and FP8 scale/scale_inv after reduction",
py::call_guard<py::gil_scoped_release>());
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD");
m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD");
m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format");
m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format");
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format",
py::call_guard<py::gil_scoped_release>());
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version");
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
py::call_guard<py::gil_scoped_release>());
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
py::call_guard<py::gil_scoped_release>());
// Support THD format for Context Parallel
m.def("thd_read_half_tensor", &thd_read_half_tensor,
"Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
"tensor");
"tensor",
py::call_guard<py::gil_scoped_release>());
m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
"Correct the second half of the softmax_lse");
"Correct the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
"Read the second half of the softmax_lse");
"Read the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
m.def("thd_out_correction", &thd_out_correction,
"Correct the THD format output of context parallelism in forward pass");
"Correct the THD format output of context parallelism in forward pass",
py::call_guard<py::gil_scoped_release>());
m.def("thd_grad_correction", &thd_grad_correction,
"Correct the THD format gradients of context parallelism in backward pass");
"Correct the THD format gradients of context parallelism in backward pass",
py::call_guard<py::gil_scoped_release>());
m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
"Generate partitioned indices for inputs in THD format");
"Generate partitioned indices for inputs in THD format",
py::call_guard<py::gil_scoped_release>());
// multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors");
"Fused overflow check + scale for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors");
"Computes L2 norm for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only "
"performed for L2 norm computation, and tensors are not updated)");
"performed for L2 norm computation, and tensors are not updated)",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support and LR scheduling");
"support and LR scheduling",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support, LR scheduling and FP32 master weights");
"support, LR scheduling and FP32 master weights",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors");
"Fused SGD optimizer for list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
// Data structures
py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
......@@ -164,8 +205,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
// comm+GEMM overlap w/ userbuffers
m.def("set_ubuf_bootstrap_callbacks", &ubuf::set_ubuf_bootstrap_callbacks);
// Communication functions to initialize Userbuffers communicators
// Note: Callbacks are not called, so safe to release GIL.
m.def("set_ubuf_bootstrap_callbacks", &ubuf::set_ubuf_bootstrap_callbacks,
py::call_guard<py::gil_scoped_release>());
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
......@@ -177,32 +220,57 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P)
.value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P);
// Note: Can't release GIL in constructor since it may bootstrap
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, int, int, bool, int, bool,
torch::Tensor>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv)
.def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs)
.def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf)
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output)
.def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm)
.def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap);
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output,
py::call_guard<py::gil_scoped_release>())
.def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm,
py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap,
py::call_guard<py::gil_scoped_release>());
// Note: Can't release GIL in constructor since it may bootstrap
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool, bool,
torch::Tensor>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs)
.def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag)
.def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output)
.def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf)
.def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm)
.def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap)
.def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv);
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output,
py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm,
py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap,
py::call_guard<py::gil_scoped_release>())
.def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv,
py::call_guard<py::gil_scoped_release>());
py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment