Unverified Commit af6535e7 authored by Alex Sun's avatar Alex Sun Committed by GitHub
Browse files

[ROCm] Enable MTP (NextN) on AMD GPU (#4631)

parent 93cf7fc5
...@@ -4,9 +4,9 @@ from typing import List ...@@ -4,9 +4,9 @@ from typing import List
import torch import torch
from sglang.srt.utils import is_cuda_available from sglang.srt.utils import is_cuda_available, is_hip
if is_cuda_available(): if is_cuda_available() or is_hip():
from sgl_kernel import ( from sgl_kernel import (
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient, build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
) )
......
...@@ -14,7 +14,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict ...@@ -14,7 +14,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.utils import is_cuda_available from sglang.srt.utils import is_cuda_available, is_hip
if is_cuda_available(): if is_cuda_available():
from sgl_kernel import ( from sgl_kernel import (
...@@ -23,6 +23,8 @@ if is_cuda_available(): ...@@ -23,6 +23,8 @@ if is_cuda_available():
tree_speculative_sampling_target_only, tree_speculative_sampling_target_only,
verify_tree_greedy, verify_tree_greedy,
) )
elif is_hip():
from sgl_kernel import verify_tree_greedy
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
......
...@@ -17,7 +17,11 @@ ...@@ -17,7 +17,11 @@
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include "pytorch_extension_utils.h" #include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#endif
// parent_list [bs, topk * (depth - 1) + 1)] // parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1] // selected_index [bs, draft_token_num - 1]
......
#include <torch/library.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
CHECK_CUDA(x); \
CHECK_LAST_DIM_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
...@@ -65,6 +65,18 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { ...@@ -65,6 +65,18 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()"); "token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax); m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
m.def(
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor target_predict, int cuda_stream) -> ()");
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
} }
REGISTER_EXTENSION(common_ops) REGISTER_EXTENSION(common_ops)
...@@ -43,6 +43,7 @@ sources = [ ...@@ -43,6 +43,7 @@ sources = [
"csrc/moe/moe_align_kernel.cu", "csrc/moe/moe_align_kernel.cu",
"csrc/moe/moe_topk_softmax_kernels.cu", "csrc/moe/moe_topk_softmax_kernels.cu",
"csrc/torch_extension_rocm.cc", "csrc/torch_extension_rocm.cc",
"csrc/speculative/eagle_utils.cu",
] ]
cxx_flags = ["-O3"] cxx_flags = ["-O3"]
......
...@@ -54,7 +54,7 @@ class TestDeepseekV3MTP(unittest.TestCase): ...@@ -54,7 +54,7 @@ class TestDeepseekV3MTP(unittest.TestCase):
cls.model = "lmsys/sglang-ci-dsv3-test" cls.model = "lmsys/sglang-ci-dsv3-test"
cls.base_url = DEFAULT_URL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST
other_args = ["--trust-remote-code"] other_args = ["--trust-remote-code"]
if torch.cuda.is_available() and torch.version.cuda: if torch.cuda.is_available() and (torch.version.cuda or torch.version.hip):
other_args.extend( other_args.extend(
[ [
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment