Unverified commit 6ee92c4b, authored by Tim Moon and committed by GitHub
Browse files

[PyTorch] Release GIL in PyTorch extensions (#938)



Release GIL in PyTorch pybind11 functions
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 70d3251f
...@@ -11,151 +11,192 @@ ...@@ -11,151 +11,192 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Softmax functions // Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD"); m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD",
m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD"); py::call_guard<py::gil_scoped_release>());
m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward, m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
"Scaled Masked Softmax FWD"); "Scaled Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward, m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
"Scaled Masked Softmax BWD"); "Scaled Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_forward", &scaled_upper_triang_masked_softmax_forward, m.def("scaled_upper_triang_masked_softmax_forward", &scaled_upper_triang_masked_softmax_forward,
"Scaled Upper-Triangular Masked Softmax FWD"); "Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_backward", &scaled_upper_triang_masked_softmax_backward, m.def("scaled_upper_triang_masked_softmax_backward", &scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD"); "Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_aligned_causal_masked_softmax_forward", m.def("scaled_aligned_causal_masked_softmax_forward",
&scaled_aligned_causal_masked_softmax_forward, &scaled_aligned_causal_masked_softmax_forward,
"Scaled Bottom-Right Corner Aligned Masked Softmax FWD"); "Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_aligned_causal_masked_softmax_backward", m.def("scaled_aligned_causal_masked_softmax_backward",
&scaled_aligned_causal_masked_softmax_backward, &scaled_aligned_causal_masked_softmax_backward,
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD"); "Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
py::call_guard<py::gil_scoped_release>());
// Other granular functions // Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8", py::arg("input"), py::arg("weight"), m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8",
py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8", py::arg("input"), m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8",
py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"), py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"),
py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0); py::arg("scale_inv_offset") = 0);
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD"); m.def("layernorm_bwd", &layernorm_bwd, "LN BWD", py::call_guard<py::gil_scoped_release>());
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD"); m.def("layernorm_fwd", &layernorm_fwd, "LN FWD", py::call_guard<py::gil_scoped_release>());
m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD"); m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD",
m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8", py::arg("input"), py::arg("weight"), py::call_guard<py::gil_scoped_release>());
m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8",
py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"),
py::arg("sm_margin"), py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("sm_margin"), py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8", py::arg("input"), m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8",
py::arg("weight"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), py::arg("scale_inv"),
py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"),
py::arg("scale_inv_offset") = 0); py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD"); m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD", py::call_guard<py::gil_scoped_release>());
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD"); m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD", py::call_guard<py::gil_scoped_release>());
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD"); m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD",
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); py::call_guard<py::gil_scoped_release>());
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose",
py::call_guard<py::gil_scoped_release>());
m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop,
"Fused Cast + Transpose with noop option", py::arg("input"), py::arg("noop"), "Cast + Transpose with noop option", py::call_guard<py::gil_scoped_release>(),
py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("input_cast"), py::arg("input"), py::arg("noop"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
py::arg("input_transpose"), py::arg("otype"), py::arg("scale_offset") = 0, py::arg("input_cast"), py::arg("input_transpose"), py::arg("otype"),
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD", m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD",
py::arg("grad_output"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("scale"),
py::arg("otype"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, "Fused FP8 Transpose + BGRAD",
py::arg("grad_output"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
py::arg("otype"), py::arg("grad_bias_type"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu, m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, "Fused FP8 Transpose + BGRAD",
"Fused Cast + Transpose + BGRAD + DGELU", py::arg("grad_output"), py::arg("gelu_input"), py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("scale"),
py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("grad_bias_type"),
py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu,
"Fused Cast + Transpose + BGRAD + DGELU", py::call_guard<py::gil_scoped_release>(),
py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose"); "Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8"); m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard<py::gil_scoped_release>());
m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8"); m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8",
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8"); py::call_guard<py::gil_scoped_release>());
m.def("te_gemm", &te_gemm, "CublasLt GEMM"); m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard<py::gil_scoped_release>());
m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think
m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed QKV"); "Fused Attention FP8/BF16/FP16 FWD with packed QKV",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked, m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed QKV"); "Fused Attention FP8/BF16/FP16 BWD with packed QKV",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked, m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed KV"); "Fused Attention FP8/BF16/FP16 FWD with packed KV",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked, m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed KV"); "Fused Attention FP8/BF16/FP16 BWD with packed KV",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_fwd", &fused_attn_fwd, m.def("fused_attn_fwd", &fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_bwd", &fused_attn_bwd, m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V",
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); py::call_guard<py::gil_scoped_release>());
m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O",
py::call_guard<py::gil_scoped_release>());
m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O",
py::call_guard<py::gil_scoped_release>());
m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop, m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop,
"Transpose with FP8 I/O with noop option."); "Transpose with FP8 I/O with noop option.", py::call_guard<py::gil_scoped_release>());
m.def("gelu", &gelu, "GeLU with FP8 output"); m.def("gelu", &gelu, "GeLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("relu", &relu, "ReLU with FP8 output"); m.def("relu", &relu, "ReLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("geglu", &geglu, "GeGLU with FP8 output"); m.def("geglu", &geglu, "GeGLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("reglu", &reglu, "ReGLU with FP8 output"); m.def("reglu", &reglu, "ReGLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("swiglu", &swiglu, "SwiGLU with FP8 output"); m.def("swiglu", &swiglu, "SwiGLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("qgelu", &qgelu, "QuickGELU with FP8 output"); m.def("qgelu", &qgelu, "QuickGELU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("srelu", &srelu, "Squared ReLU with FP8 output"); m.def("srelu", &srelu, "Squared ReLU with FP8 output", py::call_guard<py::gil_scoped_release>());
m.def("dgelu", &dgelu, "Backward of GeLU"); m.def("dgelu", &dgelu, "Backward of GeLU", py::call_guard<py::gil_scoped_release>());
m.def("drelu", &drelu, "Backward of ReLU"); m.def("drelu", &drelu, "Backward of ReLU", py::call_guard<py::gil_scoped_release>());
m.def("dgeglu", &dgeglu, "Backward of GeGLU"); m.def("dgeglu", &dgeglu, "Backward of GeGLU", py::call_guard<py::gil_scoped_release>());
m.def("dreglu", &dreglu, "Backward of ReGLU"); m.def("dreglu", &dreglu, "Backward of ReGLU", py::call_guard<py::gil_scoped_release>());
m.def("dswiglu", &dswiglu, "Backward of SwiGLU"); m.def("dswiglu", &dswiglu, "Backward of SwiGLU", py::call_guard<py::gil_scoped_release>());
m.def("dqgelu", &dqgelu, "Backward of QuickGELU"); m.def("dqgelu", &dqgelu, "Backward of QuickGELU", py::call_guard<py::gil_scoped_release>());
m.def("dsrelu", &dsrelu, "Backward of Squared ReLU"); m.def("dsrelu", &dsrelu, "Backward of Squared ReLU", py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention"); m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention"); py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
py::call_guard<py::gil_scoped_release>());
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction, m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction"); "Update amax history and FP8 scale/scale_inv after reduction",
py::call_guard<py::gil_scoped_release>());
// fused apply rope // fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD"); m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD"); py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format"); m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format"); py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format",
py::call_guard<py::gil_scoped_release>());
// Misc // Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version"); py::call_guard<py::gil_scoped_release>());
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
py::call_guard<py::gil_scoped_release>());
// Support THD format for Context Parallel // Support THD format for Context Parallel
m.def("thd_read_half_tensor", &thd_read_half_tensor, m.def("thd_read_half_tensor", &thd_read_half_tensor,
"Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD " "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
"tensor"); "tensor",
py::call_guard<py::gil_scoped_release>());
m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction, m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
"Correct the second half of the softmax_lse"); "Correct the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
m.def("thd_read_second_half_lse", &thd_read_second_half_lse, m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
"Read the second half of the softmax_lse"); "Read the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
m.def("thd_out_correction", &thd_out_correction, m.def("thd_out_correction", &thd_out_correction,
"Correct the THD format output of context parallelism in forward pass"); "Correct the THD format output of context parallelism in forward pass",
py::call_guard<py::gil_scoped_release>());
m.def("thd_grad_correction", &thd_grad_correction, m.def("thd_grad_correction", &thd_grad_correction,
"Correct the THD format gradients of context parallelism in backward pass"); "Correct the THD format gradients of context parallelism in backward pass",
py::call_guard<py::gil_scoped_release>());
m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices, m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
"Generate partitioned indices for inputs in THD format"); "Generate partitioned indices for inputs in THD format",
py::call_guard<py::gil_scoped_release>());
// multi-tensor functions // multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda, m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors"); "Fused overflow check + scale for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors"); "Computes L2 norm for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda, m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only " "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only "
"performed for L2 norm computation, and tensors are not updated)"); "performed for L2 norm computation, and tensors are not updated)",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam", &multi_tensor_adam_cuda, m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer"); "Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda, m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support and LR scheduling"); "support and LR scheduling",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda, m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support, LR scheduling and FP32 master weights"); "support, LR scheduling and FP32 master weights",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors"); "Fused SGD optimizer for list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
// Data structures // Data structures
py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta") py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
...@@ -164,8 +205,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -164,8 +205,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
// comm+GEMM overlap w/ userbuffers // Communication functions to initialize Userbuffers communicators
m.def("set_ubuf_bootstrap_callbacks", &ubuf::set_ubuf_bootstrap_callbacks); // Note: Callbacks are not called, so safe to release GIL.
m.def("set_ubuf_bootstrap_callbacks", &ubuf::set_ubuf_bootstrap_callbacks,
py::call_guard<py::gil_scoped_release>());
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo") py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
...@@ -177,32 +220,57 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -177,32 +220,57 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P) .value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P)
.value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P); .value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P);
// Note: Can't release GIL in constructor since it may bootstrap
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap") py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, int, int, bool, int, bool, .def(py::init<torch::Tensor&, int, int, int, int, int, int, int, bool, int, bool,
torch::Tensor>()) torch::Tensor>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap) .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap,
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs) py::call_guard<py::gil_scoped_release>())
.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv) .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs,
.def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs) py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf) .def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv,
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf) py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output) .def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs,
.def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm) py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap); .def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output,
py::call_guard<py::gil_scoped_release>())
.def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm,
py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap,
py::call_guard<py::gil_scoped_release>());
// Note: Can't release GIL in constructor since it may bootstrap
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap") py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool, bool, .def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool, bool,
torch::Tensor>()) torch::Tensor>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag) .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag,
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs) py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag) .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs,
.def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs) py::call_guard<py::gil_scoped_release>())
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf) .def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag,
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output) py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf) .def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs,
.def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm) py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap) .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf,
.def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv); py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output,
py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm,
py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap,
py::call_guard<py::gil_scoped_release>())
.def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv,
py::call_guard<py::gil_scoped_release>());
py::enum_<transformer_engine::DType>(m, "DType", py::module_local()) py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte) .value("kByte", transformer_engine::DType::kByte)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment