Unverified Commit 36cc3ffd authored by Chunyuan WU's avatar Chunyuan WU Committed by GitHub
Browse files

[CPU] [sgl-kernel] set dispatch key of initialize to CatchAll (#7734)

parent 1bebd315
...@@ -342,7 +342,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -342,7 +342,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// all reduce // all reduce
m.def("initialize(int size, int rank) -> ()"); m.def("initialize(int size, int rank) -> ()");
m.impl("initialize", torch::kCPU, &initialize);
m.def("shm_allreduce(Tensor data, int reduce_op) -> ()"); m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce); m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
m.def("shm_allgather(Tensor data, int dim) -> Tensor"); m.def("shm_allgather(Tensor data, int dim) -> Tensor");
...@@ -360,6 +359,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -360,6 +359,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) { TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
m.impl("init_cpu_threads_env", init_cpu_threads_env); m.impl("init_cpu_threads_env", init_cpu_threads_env);
m.impl("initialize", &initialize);
} }
REGISTER_EXTENSION(common_ops) REGISTER_EXTENSION(common_ops)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment