change / sglang · Commit af6535e7 (unverified)

[ROCm] Enable MTP (NextN) on AMD GPU (#4631)

Authored Mar 24, 2025 by Alex Sun; committed by GitHub Mar 23, 2025
Parent: 93cf7fc5

Showing 7 changed files with 43 additions and 4 deletions (+43, -4)
python/sglang/srt/speculative/build_eagle_tree.py            +2  -2
python/sglang/srt/speculative/eagle_utils.py                 +3  -1
sgl-kernel/csrc/speculative/eagle_utils.cu                   +4  -0
sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h   +20 -0
sgl-kernel/csrc/torch_extension_rocm.cc                      +12 -0
sgl-kernel/setup_rocm.py                                     +1  -0
test/srt/test_mla_deepseek_v3.py                             +1  -1
python/sglang/srt/speculative/build_eagle_tree.py
@@ -4,9 +4,9 @@ from typing import List
 import torch
 
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda_available, is_hip
 
-if is_cuda_available():
+if is_cuda_available() or is_hip():
     from sgl_kernel import (
         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
     )
python/sglang/srt/speculative/eagle_utils.py
@@ -14,7 +14,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda_available, is_hip
 
 if is_cuda_available():
     from sgl_kernel import (
@@ -23,6 +23,8 @@ if is_cuda_available():
         tree_speculative_sampling_target_only,
         verify_tree_greedy,
     )
+elif is_hip():
+    from sgl_kernel import verify_tree_greedy
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
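The gate relies on the is_cuda_available()/is_hip() helpers from sglang.srt.utils. As a rough behavioral sketch (an assumption about what these helpers report, not their actual implementation), the distinction comes down to which backend the installed torch build targets:

# Behavioral sketch only: ROCm builds of torch expose torch.version.hip,
# while CUDA builds expose torch.version.cuda. The real sglang helpers may check more.
import torch

def looks_like_hip() -> bool:
    return torch.version.hip is not None

def looks_like_cuda() -> bool:
    return torch.cuda.is_available() and torch.version.cuda is not None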
sgl-kernel/csrc/speculative/eagle_utils.cu
@@ -17,7 +17,11 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#ifndef USE_ROCM
 #include "pytorch_extension_utils.h"
+#else
+#include "pytorch_extension_utils_rocm.h"
+#endif
 
 // parent_list [bs, topk * (depth - 1) + 1)]
 // selected_index [bs, draft_token_num - 1]
sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h (new file, mode 100644)
+#include <torch/library.h>
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_LAST_DIM_CONTIGUOUS(x) \
+  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at last dimension")
+
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
+#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
+  CHECK_CUDA(x);                           \
+  CHECK_LAST_DIM_CONTIGUOUS(x)
+
+#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
+
+#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
+#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
sgl-kernel/csrc/torch_extension_rocm.cc
@@ -65,6 +65,18 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
       "token_expert_indices, Tensor gating_output) -> ()");
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+
+  m.def(
+      "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
+      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
+      "Tensor target_predict, int cuda_stream) -> ()");
+  m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
+
+  m.def(
+      "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
+      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
+      "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
+  m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
 }
 
 REGISTER_EXTENSION(common_ops)
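Once these schemas are defined, the speculative ops land in the sgl_kernel namespace of the PyTorch dispatcher on ROCm builds too. A minimal sketch of confirming the registration from Python (assuming the sgl-kernel package is installed and importing it loads the compiled extension; this only prints the op handles, it does not launch the kernels):

import torch
import sgl_kernel  # assumed: importing the package loads the compiled extension library

# Printing the op packets shows they resolved against the TORCH_LIBRARY schemas above.
print(torch.ops.sgl_kernel.verify_tree_greedy)
print(torch.ops.sgl_kernel.build_tree_kernel_efficient)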
sgl-kernel/setup_rocm.py
@@ -43,6 +43,7 @@ sources = [
     "csrc/moe/moe_align_kernel.cu",
     "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/torch_extension_rocm.cc",
+    "csrc/speculative/eagle_utils.cu",
 ]
 
 cxx_flags = ["-O3"]
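For orientation, a minimal sketch of how a source list like this is typically wired into a PyTorch extension build. This is not the actual setup_rocm.py; the extension name and build flags here are assumptions, and on a ROCm torch install the builder hipifies the .cu sources automatically:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

sources = [
    "csrc/moe/moe_align_kernel.cu",
    "csrc/moe/moe_topk_softmax_kernels.cu",
    "csrc/torch_extension_rocm.cc",
    "csrc/speculative/eagle_utils.cu",  # the source added by this commit
]

setup(
    name="sgl-kernel",
    ext_modules=[
        CUDAExtension(  # hipified when torch is a ROCm build
            name="sgl_kernel.common_ops",
            sources=sources,
            extra_compile_args={"cxx": ["-O3"]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)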
test/srt/test_mla_deepseek_v3.py
@@ -54,7 +54,7 @@ class TestDeepseekV3MTP(unittest.TestCase):
         cls.model = "lmsys/sglang-ci-dsv3-test"
         cls.base_url = DEFAULT_URL_FOR_TEST
         other_args = ["--trust-remote-code"]
-        if torch.cuda.is_available() and torch.version.cuda:
+        if torch.cuda.is_available() and (torch.version.cuda or torch.version.hip):
             other_args.extend(
                 [
                     "--cuda-graph-max-bs",
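The relaxed guard works because torch.version.cuda is a version string on CUDA builds and None on ROCm, while torch.version.hip is the reverse, so either backend now takes the cuda-graph branch of the test. A quick standalone check of the same condition:

import torch

# Mirrors the updated test guard: true on both CUDA and ROCm builds with a visible GPU.
if torch.cuda.is_available() and (torch.version.cuda or torch.version.hip):
    backend = "cuda" if torch.version.cuda else "rocm"
    print(f"GPU backend detected: {backend}")
else:
    print("no supported GPU backend detected")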