Unverified Commit 88c88654 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix pybind strings for RMSNorm (#372)


Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e056664f
......@@ -30,11 +30,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD");
m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "LN FWD FP8");
m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "LN FWD FP8");
m.def("rmsnorm_bwd", &rmsnorm_bwd, "LN BWD");
m.def("rmsnorm_fwd", &rmsnorm_fwd, "LN FWD");
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "LN FWD");
m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8");
m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8");
m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD");
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD");
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment