Unverified commit d40846d4, authored by Zhiqiang Xie and committed by GitHub
Browse files

breakdown kernel update (#8334)

parent 145482f4
...@@ -10,13 +10,10 @@ def transfer_kv_per_layer( ...@@ -10,13 +10,10 @@ def transfer_kv_per_layer(
dst_v: torch.Tensor, dst_v: torch.Tensor,
src_indices: torch.Tensor, src_indices: torch.Tensor,
dst_indices: torch.Tensor, dst_indices: torch.Tensor,
io_backend: str,
page_size: int,
item_size: int, item_size: int,
block_quota: int = 2, block_quota: int = 2,
num_warps_per_block: int = 32, num_warps_per_block: int = 32,
): ):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_per_layer( torch.ops.sgl_kernel.transfer_kv_per_layer(
src_k, src_k,
dst_k, dst_k,
...@@ -24,16 +21,10 @@ def transfer_kv_per_layer( ...@@ -24,16 +21,10 @@ def transfer_kv_per_layer(
dst_v, dst_v,
src_indices, src_indices,
dst_indices, dst_indices,
item_size * src_k.element_size(), # todo, hot fix for compatibility item_size,
block_quota, block_quota,
num_warps_per_block, num_warps_per_block,
) )
elif io_backend == "direct":
torch.ops.sgl_kernel.transfer_kv_direct(
[src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size
)
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_per_layer_pf_lf( def transfer_kv_per_layer_pf_lf(
...@@ -69,13 +60,11 @@ def transfer_kv_all_layer( ...@@ -69,13 +60,11 @@ def transfer_kv_all_layer(
dst_v_layers: torch.Tensor, dst_v_layers: torch.Tensor,
src_indices: torch.Tensor, src_indices: torch.Tensor,
dst_indices: torch.Tensor, dst_indices: torch.Tensor,
io_backend: str,
item_size: int, item_size: int,
num_layers: int, num_layers: int,
block_quota: int = 2, block_quota: int = 2,
num_warps_per_block: int = 32, num_warps_per_block: int = 32,
): ):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_all_layer( torch.ops.sgl_kernel.transfer_kv_all_layer(
src_k_layers, src_k_layers,
dst_k_layers, dst_k_layers,
...@@ -88,10 +77,6 @@ def transfer_kv_all_layer( ...@@ -88,10 +77,6 @@ def transfer_kv_all_layer(
block_quota, block_quota,
num_warps_per_block, num_warps_per_block,
) )
elif io_backend == "direct":
raise NotImplementedError("Deprecated interface")
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_all_layer_lf_pf( def transfer_kv_all_layer_lf_pf(
...@@ -139,28 +124,19 @@ def transfer_kv_per_layer_mla( ...@@ -139,28 +124,19 @@ def transfer_kv_per_layer_mla(
dst: torch.Tensor, dst: torch.Tensor,
src_indices: torch.Tensor, src_indices: torch.Tensor,
dst_indices: torch.Tensor, dst_indices: torch.Tensor,
io_backend: str,
page_size: int,
item_size: int, item_size: int,
block_quota: int = 2, block_quota: int = 2,
num_warps_per_block: int = 32, num_warps_per_block: int = 32,
): ):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_per_layer_mla( torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
src, src,
dst, dst,
src_indices, src_indices,
dst_indices, dst_indices,
item_size * src.element_size(), # todo, hot fix for compatibility item_size,
block_quota, block_quota,
num_warps_per_block, num_warps_per_block,
) )
elif io_backend == "direct":
torch.ops.sgl_kernel.transfer_kv_direct(
[src], [dst], src_indices, dst_indices, page_size
)
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_per_layer_mla_pf_lf( def transfer_kv_per_layer_mla_pf_lf(
...@@ -190,13 +166,11 @@ def transfer_kv_all_layer_mla( ...@@ -190,13 +166,11 @@ def transfer_kv_all_layer_mla(
dst_layers: torch.Tensor, dst_layers: torch.Tensor,
src_indices: torch.Tensor, src_indices: torch.Tensor,
dst_indices: torch.Tensor, dst_indices: torch.Tensor,
io_backend: str,
item_size: int, item_size: int,
num_layers: int, num_layers: int,
block_quota: int = 2, block_quota: int = 2,
num_warps_per_block: int = 32, num_warps_per_block: int = 32,
): ):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_all_layer_mla( torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
src_layers, src_layers,
dst_layers, dst_layers,
...@@ -207,10 +181,6 @@ def transfer_kv_all_layer_mla( ...@@ -207,10 +181,6 @@ def transfer_kv_all_layer_mla(
block_quota, block_quota,
num_warps_per_block, num_warps_per_block,
) )
elif io_backend == "direct":
raise NotImplementedError("Deprecated interface")
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_all_layer_mla_lf_pf( def transfer_kv_all_layer_mla_lf_pf(
......
...@@ -101,9 +101,7 @@ def test_transfer_kv( ...@@ -101,9 +101,7 @@ def test_transfer_kv(
dst_pool_kernel[layer_idx_to_test], dst_pool_kernel[layer_idx_to_test],
src_indices_device, src_indices_device,
dst_indices_device, dst_indices_device,
io_backend="kernel", item_size=item_size * dtype.itemsize,
page_size=page_size,
item_size=item_size,
) )
transfer_kv_direct( transfer_kv_direct(
[src_pool_host[layer_idx_to_test]], [src_pool_host[layer_idx_to_test]],
...@@ -138,7 +136,6 @@ def test_transfer_kv( ...@@ -138,7 +136,6 @@ def test_transfer_kv(
dst_layers_device, dst_layers_device,
src_indices_device, src_indices_device,
dst_indices_device, dst_indices_device,
io_backend="kernel",
item_size=item_size * dtype.itemsize, item_size=item_size * dtype.itemsize,
num_layers=num_layers, num_layers=num_layers,
) )
...@@ -173,9 +170,7 @@ def test_transfer_kv( ...@@ -173,9 +170,7 @@ def test_transfer_kv(
dst_v_pool_kernel[layer_idx_to_test], dst_v_pool_kernel[layer_idx_to_test],
src_indices_device, src_indices_device,
dst_indices_device, dst_indices_device,
io_backend="kernel", item_size=item_size * dtype.itemsize,
page_size=page_size,
item_size=item_size,
) )
transfer_kv_direct( transfer_kv_direct(
[src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]], [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
...@@ -235,7 +230,6 @@ def test_transfer_kv( ...@@ -235,7 +230,6 @@ def test_transfer_kv(
dst_v_layers_device, dst_v_layers_device,
src_indices_device, src_indices_device,
dst_indices_device, dst_indices_device,
io_backend="kernel",
item_size=item_size * dtype.itemsize, item_size=item_size * dtype.itemsize,
num_layers=num_layers, num_layers=num_layers,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment