Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
711390a9
"vscode:/vscode.git/clone" did not exist on "ff42d33e9944832a19171967d2edd6c292bdb2d6"
Unverified
Commit
711390a9
authored
Aug 28, 2025
by
Hubert Lu
Committed by
GitHub
Aug 28, 2025
Browse files
[AMD] Support Hierarchical Caching on AMD GPUs (#8236)
parent
53430588
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
105 additions
and
32 deletions
+105
-32
.github/workflows/pr-test-amd.yml
.github/workflows/pr-test-amd.yml
+4
-3
sgl-kernel/csrc/common_extension_rocm.cc
sgl-kernel/csrc/common_extension_rocm.cc
+42
-0
sgl-kernel/csrc/kvcacheio/transfer.cu
sgl-kernel/csrc/kvcacheio/transfer.cu
+23
-13
sgl-kernel/include/pytorch_extension_utils_rocm.h
sgl-kernel/include/pytorch_extension_utils_rocm.h
+0
-0
sgl-kernel/python/sgl_kernel/kvcacheio.py
sgl-kernel/python/sgl_kernel/kvcacheio.py
+15
-8
sgl-kernel/setup_rocm.py
sgl-kernel/setup_rocm.py
+1
-0
test/srt/hicache/test_hicache.py
test/srt/hicache/test_hicache.py
+4
-2
test/srt/hicache/test_hicache_mla.py
test/srt/hicache/test_hicache_mla.py
+9
-4
test/srt/hicache/test_hicache_storage.py
test/srt/hicache/test_hicache_storage.py
+4
-2
test/srt/run_suite.py
test/srt/run_suite.py
+3
-0
No files found.
.github/workflows/pr-test-amd.yml
View file @
711390a9
...
...
@@ -223,7 +223,7 @@ jobs:
fail-fast
:
false
matrix
:
runner
:
[
linux-mi300-gpu-1
,
linux-mi325-gpu-1
]
part
:
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]
part
:
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]
runs-on
:
${{matrix.runner}}
steps
:
-
name
:
Checkout code
...
...
@@ -240,7 +240,7 @@ jobs:
-
name
:
Run test
timeout-minutes
:
50
run
:
|
bash scripts/ci/amd_ci_exec.sh python3 run_suite.py --suite per-commit-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size
7
bash scripts/ci/amd_ci_exec.sh python3 run_suite.py --suite per-commit-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size
8
unit-test-backend-2-gpu-amd
:
if
:
(github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
...
...
@@ -336,13 +336,14 @@ jobs:
bash scripts/ci/amd_ci_install_dependency.sh
-
name
:
Run test
timeout-minutes
:
1
0
timeout-minutes
:
1
4
run
:
|
docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_align.py
docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_topk_softmax.py
docker exec -w /sglang-checkout/sgl-kernel/tests/speculative ci_sglang python3 -m pytest test_eagle_utils.py
docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_apply_token_bitmask_inplace.py
docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_activation.py
docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_kvcacheio.py
pr-test-amd-finish
:
if
:
always()
...
...
sgl-kernel/csrc/common_extension_rocm.cc
View file @
711390a9
...
...
@@ -121,6 +121,48 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
*/
m
.
def
(
"apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()"
);
m
.
impl
(
"apply_token_bitmask_inplace_cuda"
,
&
ApplyTokenBitmaskInplace
);
/*
* From csrc/kvcacheio
*/
m
.
def
(
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_per_layer"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer
);
m
.
def
(
"transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_per_layer_pf_lf"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer_pf_lf
);
m
.
def
(
"transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
"Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
"num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_all_layer"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer
);
m
.
def
(
"transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
"Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
"num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_all_layer_lf_pf"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer_lf_pf
);
m
.
def
(
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
"block_quota, int num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_per_layer_mla"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer_mla
);
m
.
def
(
"transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int layer_id, "
"int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_per_layer_mla_pf_lf"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer_mla_pf_lf
);
m
.
def
(
"transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
"item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_all_layer_mla"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer_mla
);
m
.
def
(
"transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
"int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_all_layer_mla_lf_pf"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer_mla_lf_pf
);
m
.
def
(
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
"page_size) -> ()"
);
m
.
impl
(
"transfer_kv_direct"
,
torch
::
kCUDA
,
&
transfer_kv_direct
);
}
REGISTER_EXTENSION
(
common_ops
)
sgl-kernel/csrc/kvcacheio/transfer.cu
View file @
711390a9
...
...
@@ -4,21 +4,31 @@
#include <cstdint>
#ifndef USE_ROCM
#define WARP_SIZE 32
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#include "utils.h" // WARP_SIZE
#endif
__device__
__forceinline__
void
transfer_item_warp
(
int32_t
lane_id
,
const
void
*
src_addr
,
void
*
dst_addr
,
int64_t
item_size_bytes
)
{
// todo, different chunk size
int
total_chunks
=
item_size_bytes
/
8
;
const
int
64_t
*
src_8
=
reinterpret_cast
<
const
int64_t
*>
(
src_addr
);
int64_t
*
dst_8
=
reinterpret_cast
<
int64_t
*>
(
dst_addr
);
const
uint64_t
*
__restrict__
src
=
static_cast
<
const
uint64_t
*>
(
src_addr
);
u
int
64_t
*
__restrict__
dst
=
static_cast
<
uint64_t
*>
(
dst_addr
)
;
const
int
total_chunks
=
item_size_bytes
/
sizeof
(
uint64_t
);
#pragma unroll
for
(
int
j
=
lane_id
;
j
<
total_chunks
;
j
+=
32
)
{
const
int64_t
*
src_addr_lane
=
&
src_8
[
j
];
int64_t
*
dst_addr_lane
=
&
dst_8
[
j
];
int64_t
temp_val
;
asm
volatile
(
"ld.global.nc.b64 %0, [%1];"
:
"=l"
(
temp_val
)
:
"l"
(
src_addr_lane
)
:
"memory"
);
asm
volatile
(
"st.global.cg.b64 [%0], %1;"
::
"l"
(
dst_addr_lane
),
"l"
(
temp_val
)
:
"memory"
);
for
(
int
j
=
lane_id
;
j
<
total_chunks
;
j
+=
WARP_SIZE
)
{
#ifndef USE_ROCM
uint64_t
tmp
;
asm
volatile
(
"ld.global.nc.b64 %0,[%1];"
:
"=l"
(
tmp
)
:
"l"
(
src
+
j
)
:
"memory"
);
asm
volatile
(
"st.global.cg.b64 [%0],%1;"
::
"l"
(
dst
+
j
),
"l"
(
tmp
)
:
"memory"
);
#else
uint64_t
tmp
=
__builtin_nontemporal_load
(
src
+
j
);
__builtin_nontemporal_store
(
tmp
,
dst
+
j
);
#endif
}
}
...
...
@@ -78,8 +88,8 @@ __global__ void transfer_kernel_impl(
const
uintptr_t
*
__restrict__
src_v_layer_tbl
,
const
uintptr_t
*
__restrict__
dst_v_layer_tbl
)
{
int32_t
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int32_t
lane_id
=
tid
%
32
;
int32_t
warp_id
=
tid
/
32
;
int32_t
lane_id
=
tid
%
WARP_SIZE
;
int32_t
warp_id
=
tid
/
WARP_SIZE
;
for
(
int
i
=
0
;
i
<
items_per_warp
;
++
i
)
{
int64_t
item_id
=
warp_id
*
items_per_warp
+
i
;
...
...
@@ -139,7 +149,7 @@ void transfer_kv_launcher(
const
int64_t
items_per_warp
=
div_up
(
num_items
,
block_quota
*
num_warps_per_block
);
const
int32_t
num_blocks
=
div_up
(
num_items
,
items_per_warp
*
num_warps_per_block
);
dim3
grid_dim
(
num_blocks
,
1
,
1
);
const
int32_t
threads_per_block
=
num_warps_per_block
*
32
;
const
int32_t
threads_per_block
=
num_warps_per_block
*
WARP_SIZE
;
const
void
*
src_k_ptr
=
src_k
.
defined
()
?
src_k
.
data_ptr
()
:
nullptr
;
void
*
dst_k_ptr
=
dst_k
.
defined
()
?
dst_k
.
data_ptr
()
:
nullptr
;
...
...
sgl-kernel/
csrc/speculativ
e/pytorch_extension_utils_rocm.h
→
sgl-kernel/
includ
e/pytorch_extension_utils_rocm.h
View file @
711390a9
File moved
sgl-kernel/python/sgl_kernel/kvcacheio.py
View file @
711390a9
...
...
@@ -3,6 +3,13 @@ from typing import List
import
torch
def
is_hip
()
->
bool
:
return
torch
.
version
.
hip
is
not
None
_is_hip
=
is_hip
()
def
transfer_kv_per_layer
(
src_k
:
torch
.
Tensor
,
dst_k
:
torch
.
Tensor
,
...
...
@@ -12,7 +19,7 @@ def transfer_kv_per_layer(
dst_indices
:
torch
.
Tensor
,
item_size
:
int
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
16
if
_is_hip
else
32
,
):
torch
.
ops
.
sgl_kernel
.
transfer_kv_per_layer
(
src_k
,
...
...
@@ -38,7 +45,7 @@ def transfer_kv_per_layer_pf_lf(
item_size
:
int
,
src_layout_dim
:
int
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
16
if
_is_hip
else
32
,
):
torch
.
ops
.
sgl_kernel
.
transfer_kv_per_layer_pf_lf
(
src_k
,
...
...
@@ -65,7 +72,7 @@ def transfer_kv_all_layer(
item_size
:
int
,
num_layers
:
int
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
16
if
_is_hip
else
32
,
):
torch
.
ops
.
sgl_kernel
.
transfer_kv_all_layer
(
src_k_layers
,
...
...
@@ -92,7 +99,7 @@ def transfer_kv_all_layer_lf_pf(
dst_layout_dim
:
int
,
num_layers
:
int
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
16
if
_is_hip
else
32
,
):
torch
.
ops
.
sgl_kernel
.
transfer_kv_all_layer_lf_pf
(
src_k_layers
,
...
...
@@ -128,7 +135,7 @@ def transfer_kv_per_layer_mla(
dst_indices
:
torch
.
Tensor
,
item_size
:
int
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
16
if
_is_hip
else
32
,
):
torch
.
ops
.
sgl_kernel
.
transfer_kv_per_layer_mla
(
src
,
...
...
@@ -150,7 +157,7 @@ def transfer_kv_per_layer_mla_pf_lf(
item_size
:
int
,
src_layout_dim
:
int
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
16
if
_is_hip
else
32
,
):
torch
.
ops
.
sgl_kernel
.
transfer_kv_per_layer_mla_pf_lf
(
src
,
...
...
@@ -173,7 +180,7 @@ def transfer_kv_all_layer_mla(
item_size
:
int
,
num_layers
:
int
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
16
if
_is_hip
else
32
,
):
torch
.
ops
.
sgl_kernel
.
transfer_kv_all_layer_mla
(
src_layers
,
...
...
@@ -196,7 +203,7 @@ def transfer_kv_all_layer_mla_lf_pf(
dst_layout_dim
:
int
,
num_layers
:
int
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
16
if
_is_hip
else
32
,
):
torch
.
ops
.
sgl_kernel
.
transfer_kv_all_layer_mla_lf_pf
(
src_layers
,
...
...
sgl-kernel/setup_rocm.py
View file @
711390a9
...
...
@@ -49,6 +49,7 @@ sources = [
"csrc/moe/moe_align_kernel.cu"
,
"csrc/moe/moe_topk_softmax_kernels.cu"
,
"csrc/speculative/eagle_utils.cu"
,
"csrc/kvcacheio/transfer.cu"
,
]
cxx_flags
=
[
"-O3"
]
...
...
test/srt/hicache/test_hicache.py
View file @
711390a9
import
unittest
from
types
import
SimpleNamespace
from
sglang.srt.utils
import
kill_process_tree
from
sglang.srt.utils
import
is_hip
,
kill_process_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
...
...
@@ -11,6 +11,8 @@ from sglang.test.test_utils import (
popen_launch_server
,
)
_is_hip
=
is_hip
()
class
TestHiCache
(
CustomTestCase
):
@
classmethod
...
...
@@ -26,7 +28,7 @@ class TestHiCache(CustomTestCase):
"--mem-fraction-static"
,
0.7
,
"--hicache-size"
,
100
,
100
if
not
_is_hip
else
200
,
],
)
...
...
test/srt/hicache/test_hicache_mla.py
View file @
711390a9
import
unittest
from
types
import
SimpleNamespace
from
sglang.srt.utils
import
kill_process_tree
from
sglang.srt.utils
import
is_hip
,
kill_process_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_MLA_MODEL_NAME_FOR_TEST
,
...
...
@@ -11,6 +11,12 @@ from sglang.test.test_utils import (
popen_launch_server
,
)
_is_hip
=
is_hip
()
if
_is_hip
:
hicache_args
=
[
"--hicache-size"
,
200
]
else
:
hicache_args
=
[
"--hicache-ratio"
,
2
]
class
TestHierarchicalMLA
(
CustomTestCase
):
@
classmethod
...
...
@@ -24,9 +30,8 @@ class TestHierarchicalMLA(CustomTestCase):
other_args
=
[
"--trust-remote-code"
,
"--enable-hierarchical-cache"
,
"--hicache-ratio"
,
2
,
],
]
+
hicache_args
,
)
@
classmethod
...
...
test/srt/hicache/test_hicache_storage.py
View file @
711390a9
import
unittest
from
types
import
SimpleNamespace
from
sglang.srt.utils
import
kill_process_tree
from
sglang.srt.utils
import
is_hip
,
kill_process_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
...
...
@@ -11,6 +11,8 @@ from sglang.test.test_utils import (
popen_launch_server
,
)
_is_hip
=
is_hip
()
class
TestHiCache
(
CustomTestCase
):
@
classmethod
...
...
@@ -26,7 +28,7 @@ class TestHiCache(CustomTestCase):
"--mem-fraction-static"
,
0.7
,
"--hicache-size"
,
100
,
100
if
not
_is_hip
else
200
,
"--page-size"
,
"64"
,
"--hicache-storage-backend"
,
...
...
test/srt/run_suite.py
View file @
711390a9
...
...
@@ -162,6 +162,9 @@ suites = {
# Add AMD tests
suite_amd
=
{
"per-commit-amd"
:
[
TestFile
(
"hicache/test_hicache.py"
,
116
),
TestFile
(
"hicache/test_hicache_mla.py"
,
127
),
TestFile
(
"hicache/test_hicache_storage.py"
,
127
),
TestFile
(
"lora/test_lora.py"
,
200
),
TestFile
(
"lora/test_lora_eviction.py"
,
200
),
TestFile
(
"lora/test_lora_backend.py"
,
99
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment