Unverified Commit d40846d4 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

breakdown kernel update (#8334)

parent 145482f4
...@@ -10,30 +10,21 @@ def transfer_kv_per_layer( ...@@ -10,30 +10,21 @@ def transfer_kv_per_layer(
dst_v: torch.Tensor, dst_v: torch.Tensor,
src_indices: torch.Tensor, src_indices: torch.Tensor,
dst_indices: torch.Tensor, dst_indices: torch.Tensor,
io_backend: str,
page_size: int,
item_size: int, item_size: int,
block_quota: int = 2, block_quota: int = 2,
num_warps_per_block: int = 32, num_warps_per_block: int = 32,
): ):
if io_backend == "kernel": torch.ops.sgl_kernel.transfer_kv_per_layer(
torch.ops.sgl_kernel.transfer_kv_per_layer( src_k,
src_k, dst_k,
dst_k, src_v,
src_v, dst_v,
dst_v, src_indices,
src_indices, dst_indices,
dst_indices, item_size,
item_size * src_k.element_size(), # todo, hot fix for compatibility block_quota,
block_quota, num_warps_per_block,
num_warps_per_block, )
)
elif io_backend == "direct":
torch.ops.sgl_kernel.transfer_kv_direct(
[src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size
)
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_per_layer_pf_lf( def transfer_kv_per_layer_pf_lf(
...@@ -69,29 +60,23 @@ def transfer_kv_all_layer( ...@@ -69,29 +60,23 @@ def transfer_kv_all_layer(
dst_v_layers: torch.Tensor, dst_v_layers: torch.Tensor,
src_indices: torch.Tensor, src_indices: torch.Tensor,
dst_indices: torch.Tensor, dst_indices: torch.Tensor,
io_backend: str,
item_size: int, item_size: int,
num_layers: int, num_layers: int,
block_quota: int = 2, block_quota: int = 2,
num_warps_per_block: int = 32, num_warps_per_block: int = 32,
): ):
if io_backend == "kernel": torch.ops.sgl_kernel.transfer_kv_all_layer(
torch.ops.sgl_kernel.transfer_kv_all_layer( src_k_layers,
src_k_layers, dst_k_layers,
dst_k_layers, src_v_layers,
src_v_layers, dst_v_layers,
dst_v_layers, src_indices,
src_indices, dst_indices,
dst_indices, item_size,
item_size, num_layers,
num_layers, block_quota,
block_quota, num_warps_per_block,
num_warps_per_block, )
)
elif io_backend == "direct":
raise NotImplementedError("Deprecated interface")
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_all_layer_lf_pf( def transfer_kv_all_layer_lf_pf(
...@@ -139,28 +124,19 @@ def transfer_kv_per_layer_mla( ...@@ -139,28 +124,19 @@ def transfer_kv_per_layer_mla(
dst: torch.Tensor, dst: torch.Tensor,
src_indices: torch.Tensor, src_indices: torch.Tensor,
dst_indices: torch.Tensor, dst_indices: torch.Tensor,
io_backend: str,
page_size: int,
item_size: int, item_size: int,
block_quota: int = 2, block_quota: int = 2,
num_warps_per_block: int = 32, num_warps_per_block: int = 32,
): ):
if io_backend == "kernel": torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
torch.ops.sgl_kernel.transfer_kv_per_layer_mla( src,
src, dst,
dst, src_indices,
src_indices, dst_indices,
dst_indices, item_size,
item_size * src.element_size(), # todo, hot fix for compatibility block_quota,
block_quota, num_warps_per_block,
num_warps_per_block, )
)
elif io_backend == "direct":
torch.ops.sgl_kernel.transfer_kv_direct(
[src], [dst], src_indices, dst_indices, page_size
)
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_per_layer_mla_pf_lf( def transfer_kv_per_layer_mla_pf_lf(
...@@ -190,27 +166,21 @@ def transfer_kv_all_layer_mla( ...@@ -190,27 +166,21 @@ def transfer_kv_all_layer_mla(
dst_layers: torch.Tensor, dst_layers: torch.Tensor,
src_indices: torch.Tensor, src_indices: torch.Tensor,
dst_indices: torch.Tensor, dst_indices: torch.Tensor,
io_backend: str,
item_size: int, item_size: int,
num_layers: int, num_layers: int,
block_quota: int = 2, block_quota: int = 2,
num_warps_per_block: int = 32, num_warps_per_block: int = 32,
): ):
if io_backend == "kernel": torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
torch.ops.sgl_kernel.transfer_kv_all_layer_mla( src_layers,
src_layers, dst_layers,
dst_layers, src_indices,
src_indices, dst_indices,
dst_indices, item_size,
item_size, num_layers,
num_layers, block_quota,
block_quota, num_warps_per_block,
num_warps_per_block, )
)
elif io_backend == "direct":
raise NotImplementedError("Deprecated interface")
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_all_layer_mla_lf_pf( def transfer_kv_all_layer_mla_lf_pf(
......
...@@ -101,9 +101,7 @@ def test_transfer_kv( ...@@ -101,9 +101,7 @@ def test_transfer_kv(
dst_pool_kernel[layer_idx_to_test], dst_pool_kernel[layer_idx_to_test],
src_indices_device, src_indices_device,
dst_indices_device, dst_indices_device,
io_backend="kernel", item_size=item_size * dtype.itemsize,
page_size=page_size,
item_size=item_size,
) )
transfer_kv_direct( transfer_kv_direct(
[src_pool_host[layer_idx_to_test]], [src_pool_host[layer_idx_to_test]],
...@@ -138,7 +136,6 @@ def test_transfer_kv( ...@@ -138,7 +136,6 @@ def test_transfer_kv(
dst_layers_device, dst_layers_device,
src_indices_device, src_indices_device,
dst_indices_device, dst_indices_device,
io_backend="kernel",
item_size=item_size * dtype.itemsize, item_size=item_size * dtype.itemsize,
num_layers=num_layers, num_layers=num_layers,
) )
...@@ -173,9 +170,7 @@ def test_transfer_kv( ...@@ -173,9 +170,7 @@ def test_transfer_kv(
dst_v_pool_kernel[layer_idx_to_test], dst_v_pool_kernel[layer_idx_to_test],
src_indices_device, src_indices_device,
dst_indices_device, dst_indices_device,
io_backend="kernel", item_size=item_size * dtype.itemsize,
page_size=page_size,
item_size=item_size,
) )
transfer_kv_direct( transfer_kv_direct(
[src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]], [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
...@@ -235,7 +230,6 @@ def test_transfer_kv( ...@@ -235,7 +230,6 @@ def test_transfer_kv(
dst_v_layers_device, dst_v_layers_device,
src_indices_device, src_indices_device,
dst_indices_device, dst_indices_device,
io_backend="kernel",
item_size=item_size * dtype.itemsize, item_size=item_size * dtype.itemsize,
num_layers=num_layers, num_layers=num_layers,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment