Unverified Commit e0a2c963 authored by kk, committed by GitHub

Fix breakage when using custom_ar (#4052)
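The custom allreduce kernels now live under the `sgl_kernel.ops.allreduce` submodule, so the old `sgl_kernel.ops.*` call sites broke; this commit points each Python wrapper at the new path.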

parent 12f2e6c3
@@ -75,42 +75,42 @@ else:
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return sgl_kernel.ops.init_custom_ar(
+        return sgl_kernel.ops.allreduce.init_custom_ar(
             meta, rank_data, handles, offsets, rank, full_nvlink
         )
 
     def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        sgl_kernel.ops.all_reduce_reg(fa, inp, out)
+        sgl_kernel.ops.allreduce.all_reduce_reg(fa, inp, out)
 
     def all_reduce_unreg(
         fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
     ) -> None:
-        sgl_kernel.ops.all_reduce_unreg(fa, inp, reg_buffer, out)
+        sgl_kernel.ops.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
 
     def dispose(fa: int) -> None:
-        sgl_kernel.ops.dispose(fa)
+        sgl_kernel.ops.allreduce.dispose(fa)
 
     def meta_size() -> int:
-        return sgl_kernel.ops.meta_size()
+        return sgl_kernel.ops.allreduce.meta_size()
 
     def register_buffer(
         fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
     ) -> None:
-        return sgl_kernel.ops.register_buffer(fa, t, handles, offsets)
+        return sgl_kernel.ops.allreduce.register_buffer(fa, t, handles, offsets)
 
     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+        return sgl_kernel.ops.allreduce.get_graph_buffer_ipc_meta(fa)
 
     def register_graph_buffers(
         fa: int, handles: List[str], offsets: List[List[int]]
     ) -> None:
-        sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+        sgl_kernel.ops.allreduce.register_graph_buffers(fa, handles, offsets)
 
     def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return sgl_kernel.ops.allocate_meta_buffer(size)
+        return sgl_kernel.ops.allreduce.allocate_meta_buffer(size)
 
     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)
+        return sgl_kernel.ops.allreduce.get_meta_buffer_ipc_handle(inp)
 else:
     # TRTLLM custom allreduce
     ...
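For context, here is a minimal sketch of how the renamed wrappers compose into one allreduce round. It is illustrative only, not part of this commit: the cross-rank exchange of IPC handles/offsets is elided (the `handles`/`offsets` parameters stand in for metadata that must be gathered from every rank, e.g. via torch.distributed), and the buffer sizes, dtype, and `full_nvlink=True` are assumed values.

import torch

# Illustrative sketch (not from this commit): one custom-allreduce round
# using the wrappers defined above. Assumes CUDA is available and that
# `handles`/`offsets` were already exchanged across ranks out of band.
def custom_ar_round(rank: int, handles: list, offsets: list) -> torch.Tensor:
    # Shared metadata buffer, sized by the kernel library itself.
    meta = allocate_meta_buffer(meta_size())
    # Per-rank scratch buffer; 8 MiB is an assumed size for illustration.
    rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda")

    # Create the custom allreduce context; `fa` is an opaque int handle.
    fa = init_custom_ar(meta, rank_data, handles, offsets, rank, True)
    try:
        inp = torch.ones(4096, dtype=torch.float16, device="cuda")
        out = torch.empty_like(inp)
        # Register the input buffer, then take the registered-buffer path.
        # In a real run these handles/offsets would describe `inp` on every
        # rank; the placeholders are reused here only to keep the sketch short.
        register_buffer(fa, inp, handles, offsets)
        all_reduce_reg(fa, inp, out)
        return out
    finally:
        # Always release the kernel-side context: init_custom_ar returns a
        # raw int rather than a Python object, so nothing frees it for you.
        dispose(fa)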