"vscode:/vscode.git/clone" did not exist on "77aadfee6a891ab9fcfb780f87c693f7a5beeb8e"
Unverified Commit 9c3e95d9 authored by Hubert Lu, committed by GitHub

[AMD] Expand test coverage for AMD CI and enable apply_token_bitmask_inplace_cuda in sgl-kernel (#8268)
parent e52c3866
@@ -322,6 +322,7 @@ jobs:
   docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_align.py
   docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_topk_softmax.py
   docker exec -w /sglang-checkout/sgl-kernel/tests/speculative ci_sglang python3 -m pytest test_eagle_utils.py
+  docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_apply_token_bitmask_inplace.py

 pr-test-amd-finish:
   if: always()
...
@@ -32,10 +32,15 @@ from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )
-from sglang.srt.constrained.triton_ops.bitmask_ops import (
-    apply_token_bitmask_inplace_triton,
-)
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+if _is_hip:
+    from sgl_kernel import apply_token_bitmask_inplace_cuda
+else:
+    from sglang.srt.constrained.triton_ops.bitmask_ops import (
+        apply_token_bitmask_inplace_triton,
+    )

 logger = logging.getLogger(__name__)

@@ -94,7 +99,10 @@ class XGrammarGrammar(BaseGrammarObject):
     def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         if logits.device.type == "cuda":
-            apply_token_bitmask_inplace_triton(logits, vocab_mask)
+            if _is_hip:
+                apply_token_bitmask_inplace_cuda(logits, vocab_mask)
+            else:
+                apply_token_bitmask_inplace_triton(logits, vocab_mask)
         elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
             self.apply_vocab_mask_cpu(logits, vocab_mask)
         else:
...
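Both branches of the new dispatch implement the same in-place semantics: `vocab_mask` is a packed int32 bitmask with 32 tokens per element, and any token whose bit is cleared has its logit forced to negative infinity. A pure-PyTorch reference of that behavior, useful for cross-checking either kernel (a sketch only, assuming XGrammar's convention that a set bit marks an allowed token):

```python
import torch

def apply_token_bitmask_inplace_reference(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
    # bit j of bitmask[i, j // 32] gates token j of row i (LSB-first, assumed
    # XGrammar convention): a cleared bit forces logits[i, j] to -inf in place.
    batch, vocab = logits.shape
    shifts = torch.arange(32, device=logits.device, dtype=torch.int32)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1         # (batch, blocks, 32)
    allowed = bits.reshape(batch, -1)[:, :vocab].bool()  # unpack, trim padding
    logits.masked_fill_(~allowed, float("-inf"))
```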
@@ -114,6 +114,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
       "()");
   m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
+
+  /*
+   * From XGrammar
+   */
+  m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
+  m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
 }

 REGISTER_EXTENSION(common_ops)
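The schema above also accepts an optional `indices` tensor so callers can restrict masking to a subset of batch rows; the Python call sites in this change pass just `(logits, vocab_mask)`. A hedged usage sketch on a ROCm build (shapes and values are illustrative, and the set-bit-means-allowed convention is assumed from XGrammar):

```python
import torch
from sgl_kernel import apply_token_bitmask_inplace_cuda

vocab_size = 128
logits = torch.randn(2, vocab_size, device="cuda", dtype=torch.float16)
# One int32 block gates 32 tokens; start from all-ones (every token allowed).
bitmask = torch.full((2, (vocab_size + 31) // 32), -1, device="cuda", dtype=torch.int32)
bitmask[:, 0] &= ~1  # illustrative: clear bit 0, forbidding token 0 in both rows
apply_token_bitmask_inplace_cuda(logits, bitmask)  # indices omitted: mask every row
assert torch.isinf(logits[:, 0]).all() and (logits[:, 0] < 0).all()
```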
@@ -25,19 +25,24 @@
 #include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>

-#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
+#if !defined(USE_ROCM) && (!defined(CUDA_VERSION) || CUDA_VERSION < 12040)
 void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt) {
   TORCH_CHECK(false, "CUDA version must be >= 12.4 for ApplyTokenBitmaskInplace");
 }
 #else

 #ifndef CUDART_INF_FP16
+#ifndef USE_ROCM
 #define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
 #endif
+#endif

 #ifndef CUDART_INF_BF16
+#ifndef USE_ROCM
 #define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
 #endif
+#endif

 constexpr int32_t BITS_PER_BLOCK = 32;
 constexpr int32_t THREADS_PER_THREAD_BLOCK = 256;
...
@@ -49,12 +54,20 @@ __device__ T NegativeInfinity() {
 template <>
 __device__ __half NegativeInfinity<__half>() {
+#ifdef USE_ROCM
+  return __float2half(-INFINITY);
+#else
   return -CUDART_INF_FP16;
+#endif
 }

 template <>
 __device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>() {
+#ifdef USE_ROCM
+  return __nv_bfloat16(-INFINITY);
+#else
   return -CUDART_INF_BF16;
+#endif
 }

 template <typename T, typename PackedT>
...
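On ROCm the `CUDART_INF_*` constants are unavailable, so the specializations build negative infinity from a float literal instead; both routes should produce the same IEEE encodings the CUDA macros spell out (0x7C00 for fp16 infinity, 0x7F80 for bf16). A quick host-side check of those bit patterns:

```python
import torch

# fp16 +inf is 0x7C00 and bf16 +inf is 0x7F80: the encodings the CUDART_INF_*
# macros above build via __ushort_as_half / __ushort_as_bfloat16.
assert torch.tensor([0x7C00], dtype=torch.int16).view(torch.float16).item() == float("inf")
assert torch.tensor([0x7F80], dtype=torch.int16).view(torch.bfloat16).item() == float("inf")
# Negating (as the kernels do) just sets the sign bit: fp16 -inf is 0xFC00.
assert torch.tensor(-float("inf"), dtype=torch.float16).view(torch.int16).item() & 0xFFFF == 0xFC00
```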
@@ -48,6 +48,7 @@ sources = [
     "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/speculative/eagle_utils.cu",
     "csrc/common_extension_rocm.cc",
+    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu",
 ]

 cxx_flags = ["-O3"]
...
@@ -158,40 +158,66 @@ suites = {
 # Add AMD tests
 suite_amd = {
     "per-commit-amd": [
+        TestFile("lora/test_lora.py", 200),
+        TestFile("lora/test_lora_eviction.py", 200),
         TestFile("lora/test_lora_backend.py", 99),
         TestFile("lora/test_multi_lora_backend.py", 60),
         TestFile("lora/test_lora_cuda_graph.py", 250),
+        TestFile("lora/test_lora_qwen3.py", 97),
+        TestFile("models/test_embedding_models.py", 73),
+        TestFile("models/test_compressed_tensors_models.py", 42),
         TestFile("models/test_qwen_models.py", 82),
         TestFile("models/test_reward_models.py", 132),
+        TestFile("models/test_transformers_models.py", 320),
+        TestFile("openai_server/basic/test_protocol.py", 10),
+        TestFile("openai_server/basic/test_serving_chat.py", 10),
+        TestFile("openai_server/basic/test_serving_completions.py", 10),
+        TestFile("openai_server/basic/test_serving_embedding.py", 10),
         TestFile("openai_server/basic/test_openai_embedding.py", 141),
+        TestFile("openai_server/basic/test_openai_server.py", 149),
         TestFile("openai_server/features/test_enable_thinking.py", 70),
+        TestFile("openai_server/features/test_json_constrained.py", 98),
+        TestFile("openai_server/features/test_json_mode.py", 90),
+        TestFile("openai_server/features/test_openai_server_ebnf.py", 95),
+        # TestFile("openai_server/features/test_openai_server_hidden_states.py", 240),
         TestFile("openai_server/features/test_reasoning_content.py", 89),
+        TestFile("openai_server/function_call/test_openai_function_calling.py", 60),
+        TestFile("openai_server/function_call/test_tool_choice.py", 226),
         TestFile("openai_server/validation/test_large_max_new_tokens.py", 41),
+        TestFile("openai_server/validation/test_matched_stop.py", 60),
+        TestFile("openai_server/validation/test_openai_server_ignore_eos.py", 85),
         TestFile("openai_server/validation/test_request_length_validation.py", 31),
         TestFile("quant/test_block_int8.py", 22),
         TestFile("quant/test_awq_dequant.py", 2),
         TestFile("rl/test_update_weights_from_disk.py", 114),
+        # TestFile("rl/test_update_weights_from_tensor.py", 48),
         TestFile("test_abort.py", 51),
         TestFile("test_create_kvindices.py", 2),
         TestFile("test_chunked_prefill.py", 313),
+        TestFile("test_ebnf_constrained.py", 108),
         TestFile("test_eval_fp8_accuracy.py", 303),
         TestFile("test_function_call_parser.py", 10),
         TestFile("test_fused_moe.py", 30),
         TestFile("test_input_embeddings.py", 38),
+        TestFile("test_io_struct.py", 8),
+        TestFile("test_jinja_template_utils.py", 1),
+        TestFile("test_metrics.py", 32),
         TestFile("test_mla.py", 242),
         TestFile("test_mla_deepseek_v3.py", 221),
-        TestFile("test_metrics.py", 32),
         TestFile("test_no_chunked_prefill.py", 108),
         # TestFile("test_no_overlap_scheduler.py", 234), # Disabled temporarily and track in #7703
         TestFile("test_penalty.py", 41),
         TestFile("test_page_size.py", 60),
         TestFile("test_pytorch_sampling_backend.py", 66),
         TestFile("test_radix_attention.py", 105),
+        TestFile("test_regex_constrained.py", 64),
         TestFile("test_retract_decode.py", 54),
         TestFile("test_reasoning_parser.py", 5),
         TestFile("test_rope_rocm.py", 3),
         TestFile("test_server_args.py", 1),
         TestFile("test_skip_tokenizer_init.py", 117),
+        TestFile("test_srt_engine.py", 261),
+        TestFile("test_srt_endpoint.py", 130),
         TestFile("test_torch_compile.py", 76),
         TestFile("test_torch_compile_moe.py", 172),
         TestFile("test_torch_native_attention_backend.py", 123),
...
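For reference, the second argument of each `TestFile` entry is read here as an estimated runtime in seconds that the suite runner uses to balance files across parallel CI shards; that is an assumption about sglang's test harness, sketched minimally below:

```python
import dataclasses

@dataclasses.dataclass
class TestFile:
    # Assumed structure: path of the test file plus an estimated runtime in
    # seconds, used when packing files into CI shards of similar duration.
    name: str
    estimated_time: float = 60
```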