Unverified commit c480a3f6, authored by Lianmin Zheng, committed by GitHub

Minor style fixes for sgl-kernel (#9289)

parent 6e316588
@@ -73,6 +73,20 @@ If you modify files protected by code owners, their approval is required to merg
- Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code.
- Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files.
- Prioritize extreme efficiency. SGLang is a runtime, and most of your code runs on the critical path for every request. Optimize every minor overhead as much as possible.
- Try to make functions as pure as possible. Avoid in-place modification of arguments (a sketch of these last two guidelines follows below).
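A minimal sketch of the last two points, with hypothetical helper names (nothing here is from the codebase):

```python
import torch

def scale_logits_impure(logits: torch.Tensor, temperature: torch.Tensor) -> None:
    # Anti-pattern: .item() forces a CPU-GPU synchronization, and the
    # in-place division mutates the caller's tensor.
    logits.div_(temperature.item())

def scale_logits_pure(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # Preferred: stays on the device (no sync) and returns a new tensor
    # instead of modifying the argument.
    return logits / temperature
```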
## How to update sgl-kernel
Since sglang and sgl-kernel are separate Python packages, our current GitHub CI infrastructure does not support updating a kernel and using it immediately within the same pull request (PR). To add a new kernel or modify an existing one in the sgl-kernel package, you must use multiple PRs.
Follow these steps:
1. Submit a PR to update the sgl-kernel source code without using it (e.g., [#8884](https://github.com/sgl-project/sglang/pull/8884/files)).
2. Bump the version of sgl-kernel (e.g., [#9220](https://github.com/sgl-project/sglang/pull/9220/files)).
- Once merged, this will trigger an automatic release of the sgl-kernel wheel to PyPI.
- If not urgent, you can wait for other people to release the wheel. A new version will typically be released within one week.
3. Apply the changes:
- Update the sgl-kernel version in `sglang/python/pyproject.toml` to use the modified kernels.
- Update the related caller code in sglang to use the new kernel (a quick version check is sketched below).
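Before updating the callers, it can help to confirm that the released wheel is what your environment actually has installed. A small sketch using the `packaging` library (the minimum version number is hypothetical):

```python
from packaging.version import Version

from sgl_kernel.version import __version__

REQUIRED = "0.3.0"  # hypothetical: the first version containing your new kernel
assert Version(__version__) >= Version(REQUIRED), (
    f"sgl-kernel {__version__} is too old; need >= {REQUIRED}"
)
```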
## Tips for newcomers
...
@@ -39,9 +39,9 @@ runtime_common = [
"pillow",
"prometheus-client>=0.20.0",
"psutil",
"pybase64",
"pydantic", "pydantic",
"pynvml", "pynvml",
"pybase64",
"python-multipart", "python-multipart",
"pyzmq>=25.1.2", "pyzmq>=25.1.2",
"sentencepiece", "sentencepiece",
......
@@ -12,7 +12,6 @@ from dataclasses import dataclass
import httpx
import numpy as np
import openai
import transformers
from datasets import load_dataset
from openai import AsyncOpenAI
from tqdm import tqdm
...
@@ -9,7 +9,6 @@ import argparse
import json
import os
import time
import urllib.parse
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional
...
@@ -5,7 +5,6 @@ import json
import logging
import os
import random
import signal
import socket
import subprocess
import sys
...
@@ -36,7 +36,7 @@ def read_records(files):
def run_one_request_internal(record):
    (req, output, replay_init_time, start_time, end_time, idx) = record
    time.sleep(max(0, start_time - (time.time() - replay_init_time)))
    time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))
    if "completion_tokens" in output.get("meta_info", {}):
        recorded_completion_tokens = output["meta_info"]["completion_tokens"]
@@ -121,6 +121,7 @@ if __name__ == "__main__":
    parser.add_argument("--parallel", type=int, default=512)
    parser.add_argument("--idx", type=int, default=None)
    parser.add_argument("--ignore-eos", action="store_true")
    parser.add_argument("--speed", type=float, default=1)
    args = parser.parse_args()
    set_ulimit()
...
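For context, a minimal standalone sketch of what the new `--speed` flag does to the replay schedule (the values are hypothetical):

```python
import time

def replay_delay(start_time: float, replay_init_time: float, speed: float) -> float:
    # Same rule as the modified line above: the recorded offset of each
    # request is divided by the speed factor, so speed=2.0 replays the
    # trace twice as fast and speed=0.5 at half speed.
    return max(0.0, (start_time - (time.time() - replay_init_time)) / speed)

# A request recorded 10 s into the trace, replayed at 2x, is issued
# about 5 s after the replay starts.
print(replay_delay(start_time=10.0, replay_init_time=time.time(), speed=2.0))
```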
@@ -223,17 +223,19 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
set(SOURCES
"csrc/allreduce/mscclpp_allreduce.cu"
"csrc/allreduce/custom_all_reduce.cu" "csrc/allreduce/custom_all_reduce.cu"
"csrc/allreduce/mscclpp_allreduce.cu"
"csrc/attention/cascade.cu" "csrc/attention/cascade.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/cutlass_mla_kernel.cu" "csrc/attention/cutlass_mla_kernel.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/attention/lightning_attention_decode_kernel.cu" "csrc/attention/lightning_attention_decode_kernel.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/elementwise/activation.cu" "csrc/elementwise/activation.cu"
"csrc/elementwise/cast.cu"
"csrc/elementwise/fused_add_rms_norm_kernel.cu" "csrc/elementwise/fused_add_rms_norm_kernel.cu"
"csrc/elementwise/rope.cu" "csrc/elementwise/rope.cu"
"csrc/common_extension.cc" "csrc/common_extension.cc"
"csrc/gemm/awq_kernel.cu" "csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu" "csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu" "csrc/gemm/dsv3_fused_a_gemm.cu"
@@ -257,7 +259,9 @@ set(SOURCES
"csrc/gemm/marlin/gptq_marlin_repack.cu"
"csrc/gemm/marlin/awq_marlin_repack.cu"
"csrc/gemm/gptq/gptq_kernel.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
@@ -276,14 +280,18 @@ set(SOURCES
"csrc/moe/prepare_moe_input.cu"
"csrc/moe/ep_moe_reorder_kernel.cu"
"csrc/moe/ep_moe_silu_and_mul_kernel.cu"
"csrc/memory/store.cu"
"csrc/kvcacheio/transfer.cu" "csrc/kvcacheio/transfer.cu"
"csrc/speculative/eagle_utils.cu" "csrc/speculative/eagle_utils.cu"
"csrc/speculative/packbit.cu" "csrc/speculative/packbit.cu"
"csrc/speculative/speculative_sampling.cu" "csrc/speculative/speculative_sampling.cu"
"csrc/memory/store.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
......
@@ -17,6 +17,7 @@ limitations under the License.
#include <torch/library.h>
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/allreduce
@@ -93,6 +94,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
m.def(
"downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
"mult, int offset, int cuda_stream) -> ()");
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
/*
* From csrc/gemm
*/
@@ -161,7 +167,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
// GPTQ related method
/*
* From csrc/gemm/gptq
*/
m.def(
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
@@ -183,6 +191,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
/*
* From csrc/moe
*/
@@ -229,6 +238,41 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
/*
* From csrc/moe/marlin_moe_wna16
*/
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
m.def(
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()");
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
m.def(
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From csrc/speculative
*/
@@ -306,25 +350,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
m.impl("store_kv_cache", &store_kv_cache);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
m.def(
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()");
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
m.def(
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From FlashInfer
*/
@@ -358,19 +383,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From Sparse Flash Attention
*/
...
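For reference, each `m.def(...)`/`m.impl(...)` pair registered above surfaces in Python under the `torch.ops.sgl_kernel` namespace once the compiled extension is loaded. A minimal sketch:

```python
import torch
import sgl_kernel  # importing the package loads the compiled extension

# The newly registered op is reachable as an OpOverloadPacket; the Python
# wrapper added in the elementwise module below calls it the same way.
print(torch.ops.sgl_kernel.downcast_fp8)
```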
@@ -33,6 +33,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
/*
* From csrc/allreduce
*/
...
#include "pytorch_extension_utils.h"
template <typename T>
struct ConvertToFP8 {
static __device__ __nv_fp8_storage_t convert_to_fp8(T value) {
return 0;
}
};
template <>
struct ConvertToFP8<__nv_bfloat16> {
static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) {
return __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
}
};
template <>
struct ConvertToFP8<half> {
static __device__ __nv_fp8_storage_t convert_to_fp8(half value) {
return __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
}
};
template <typename T>
struct ConvertFromFloat {
static __device__ T convert_from_float(float value) {
return 0;
}
};
template <>
struct ConvertFromFloat<__nv_bfloat16> {
static __device__ __nv_bfloat16 convert_from_float(float value) {
return __float2bfloat16(value);
}
};
template <>
struct ConvertFromFloat<half> {
static __device__ half convert_from_float(float value) {
return __float2half(value);
}
};
template <typename T>
__global__ void fused_downcast_kernel(
const T* cache_k,
const T* cache_v,
const float* k_scale,
const float* v_scale,
__nv_fp8_storage_t* output_k,
__nv_fp8_storage_t* output_v,
const int input_sl,
const int head,
const int dim,
const T max_fp8,
const T min_fp8,
const int64_t mult,
const int64_t offset,
const int64_t* loc) {
// TODO: change name
int token_idx = blockIdx.x;
int thread_idx = threadIdx.x;
int total_threads = blockDim.x;
T k_scale_val = ConvertFromFloat<T>::convert_from_float(k_scale[0]);
T v_scale_val = ConvertFromFloat<T>::convert_from_float(v_scale[0]);
T k_scale_inv = static_cast<T>(1.f) / k_scale_val;
T v_scale_inv = static_cast<T>(1.f) / v_scale_val;
auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); };
if (token_idx < input_sl) {
int out_seq_idx = loc[token_idx];
#pragma unroll
for (int i = thread_idx; i < head * dim; i += total_threads) {
int in_idx = token_idx * head * dim + i;
int out_idx = (out_seq_idx * mult + offset) * head * dim + i;
T k_val = cache_k[in_idx] * k_scale_inv;
k_val = clamp(k_val);
output_k[out_idx] = ConvertToFP8<T>::convert_to_fp8(k_val);
T v_val = cache_v[in_idx] * v_scale_inv;
v_val = clamp(v_val);
output_v[out_idx] = ConvertToFP8<T>::convert_to_fp8(v_val);
}
}
}
template <typename T>
void downcast_fp8_impl(
at::Tensor& k,
at::Tensor& v,
at::Tensor& k_out,
at::Tensor& v_out,
at::Tensor& k_scale,
at::Tensor& v_scale,
at::Tensor& loc,
int64_t mult,
int64_t offset,
cudaStream_t stream) {
CHECK_INPUT(k);
CHECK_INPUT(v);
CHECK_INPUT(k_out);
CHECK_INPUT(v_out);
CHECK_INPUT(k_scale);
CHECK_INPUT(v_scale);
CHECK_INPUT(loc);
int64_t input_sl = k.size(0);
int64_t head = k.size(1);
int64_t dim = k.size(2);
dim3 grid(input_sl * head);
int vec_size = 8;
dim3 block(std::min(int(dim) / vec_size, 1024));
const T max_fp8 = static_cast<T>(448.0f);
const T min_fp8 = static_cast<T>(-448.0f);
fused_downcast_kernel<T><<<grid, block, 0, stream>>>(
static_cast<const T*>(k.data_ptr()),
static_cast<const T*>(v.data_ptr()),
static_cast<const float*>(k_scale.data_ptr()),
static_cast<const float*>(v_scale.data_ptr()),
static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()),
static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()),
input_sl,
head,
dim,
max_fp8,
min_fp8,
mult,
offset,
static_cast<const int64_t*>(loc.data_ptr()));
cudaError_t status = cudaGetLastError();
TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status)));
}
void downcast_fp8(
at::Tensor& k,
at::Tensor& v,
at::Tensor& k_out,
at::Tensor& v_out,
at::Tensor& k_scale,
at::Tensor& v_scale,
at::Tensor& loc,
int64_t mult,
int64_t offset,
int64_t cuda_stream) {
CHECK_INPUT(k);
CHECK_INPUT(v);
CHECK_INPUT(k_out);
CHECK_INPUT(v_out);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
switch (k.scalar_type()) {
case at::ScalarType::BFloat16:
downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
break;
case at::ScalarType::Half:
downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
break;
default:
TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16.");
}
}
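To make the kernel's contract concrete, here is a pure-PyTorch reference with the same semantics (a sketch for illustration only, not bit-exact: the kernel scales and clamps in half/bfloat16 precision, and the fp8 storage dtype of the real caller's cache is assumed here to be `torch.float8_e4m3fn`):

```python
import torch

def downcast_fp8_reference(k, v, k_out, v_out, k_scale, v_scale, loc, mult=1, offset=0):
    # k/v: [tokens, heads, dim] in bf16/fp16; k_out/v_out: fp8 e4m3 caches of
    # shape [slots, heads, dim]; k_scale/v_scale: 1-element fp32 scales;
    # loc: [tokens] int64 destination slots.
    FP8_MAX = 448.0  # e4m3 finite range, matching the kernel's clamp
    dst = loc * mult + offset
    for src, scale, out in ((k, k_scale, k_out), (v, v_scale, v_out)):
        q = (src.float() / scale.float()).clamp(-FP8_MAX, FP8_MAX)
        # Write through a uint8 view, mirroring the kernel's byte storage.
        out.view(torch.uint8)[dst] = q.to(torch.float8_e4m3fn).view(torch.uint8)
```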
@@ -122,6 +122,95 @@ __global__ void build_tree_efficient(
}
}
// parent_list [bs, topk * (depth - 1) + 1]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask: [draft_token*num_bytes_per_item | .. ] = [bs*draft_token*num_bytes_per_item]
// positions [bs * draft_token]
// retrive_index [bs, draft_token]
// retrive_next_token [bs, draft_token]
// retrive_next_sibling [bs, draft_token]
__global__ void build_tree_efficient_partial_packed(
int64_t* parent_list,
int64_t* selected_index,
int64_t* verified_seq_len,
uint8_t* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num,
size_t num_bytes_per_item) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num) {
return;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item;
tree_mask[token_tree_idx] = 1; // little endian
int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;
int retrive_index_offset = bid * draft_token_num;
for (int i = draft_token_num - 1; i > 0; --i) {
int current_token_idx = retrive_index_offset + i;
retrive_index[bid * draft_token_num + i] = current_token_idx;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
int parent_position = 0;
if (parent_tb_idx > 0) {
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (; parent_position < draft_token_num; ++parent_position) {
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
++parent_position;
break;
}
}
}
if (parent_position == draft_token_num) {
printf(
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
continue;
}
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
retrive_next_token[bid * draft_token_num + parent_position] = i;
} else {
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
retrive_next_token[bid * draft_token_num + parent_position] = i;
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
}
}
retrive_index[bid * draft_token_num] = bid * draft_token_num;
} else {
int cur_position = tid - 1;
while (true) {
position += 1;
int byte_idx = (cur_position + 1) / 8;
int bit_idx = (cur_position + 1) % 8;
tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx);
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
}
positions[bid * draft_token_num + tid] = position + seq_len;
}
}
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
@@ -149,7 +238,19 @@ void build_tree_kernel_efficient(
} else if (draft_token_num > 8) {
num_bytes_per_item = 2;
}
throw std::runtime_error("Not implemented"); build_tree_efficient_partial_packed<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<uint8_t*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num),
num_bytes_per_item);
} else {
build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
...
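The packed layout the new kernel writes is easiest to see in a few lines of Python (a sketch of the convention only: each draft token owns `num_bytes_per_item` bytes, and bit `j`, little-endian within each byte, marks that tree token `j` is visible to it):

```python
def set_visible(tree_mask: bytearray, token: int, j: int, num_bytes_per_item: int) -> None:
    base = token * num_bytes_per_item
    tree_mask[base + j // 8] |= 1 << (j % 8)

# Example: 10 draft tokens -> 2 bytes per token, as in the launcher above for
# draft_token_num > 8. Token 9 sees the root (token 0), a hypothetical parent
# (token 8), and itself.
nb = 2
mask = bytearray(10 * nb)
for j in (0, 8, 9):
    set_visible(mask, 9, j, nb)
assert mask[9 * nb] == 0b00000001      # bit 0: the root is always visible
assert mask[9 * nb + 1] == 0b00000011  # bits 8 and 9 live in the second byte
```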
@@ -130,6 +130,7 @@ int64_t cutlass_mla_get_workspace_size(
int64_t num_batches,
int64_t sm_count = 0,
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
/*
* From csrc/elementwise
*/
@@ -156,9 +157,22 @@ void apply_rope_pos_ids_cos_sin_cache(
const std::optional<at::Tensor>& v_buffer,
const std::optional<at::Tensor>& kv_cache_loc);
void downcast_fp8(
at::Tensor& k,
at::Tensor& v,
at::Tensor& k_out,
at::Tensor& v_out,
at::Tensor& k_scale,
at::Tensor& v_scale,
at::Tensor& loc,
int64_t mult,
int64_t offset,
int64_t cuda_stream);
#ifdef USE_ROCM
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif
/*
* From csrc/gemm
*/
@@ -221,7 +235,6 @@ void bmm_fp8(
int64_t cublas_handle,
int64_t cuda_stream);
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
torch::Tensor gptq_marlin_gemm(
@@ -258,6 +271,7 @@ torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
/*
* From csrc/moe
*/
@@ -374,6 +388,61 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
torch::Tensor& sorted_token_ids,
torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded,
torch::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);
/*
* From csrc/speculative
*/
@@ -527,35 +596,6 @@ void transfer_kv_direct(
const at::Tensor dst_indices,
int64_t page_size);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
/*
* From FlashInfer
*/
@@ -597,32 +637,6 @@ void top_p_sampling_from_probs(
void top_k_mask_logits(
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
torch::Tensor& sorted_token_ids,
torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded,
torch::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);
namespace flash {
/*
* From fa2 sparse
...
@@ -31,11 +31,11 @@ from sgl_kernel.elementwise import (
    rmsnorm,
    silu_and_mul,
)
from sgl_kernel.fused_moe import fused_marlin_moe
if torch.version.hip is not None:
    from sgl_kernel.elementwise import gelu_quick
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
@@ -114,7 +114,3 @@ from sgl_kernel.speculative import (
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.version import __version__
build_tree_kernel = (
None # TODO(ying): remove this after updating the sglang python code.
)
from dataclasses import dataclass
from typing import Any, Optional
from typing import Optional
import torch
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
@@ -345,3 +345,19 @@ def apply_rope_with_cos_sin_cache_inplace(
else None
),
)
def downcast_fp8(
k: torch.Tensor,
v: torch.Tensor,
k_out: torch.Tensor,
v_out: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
loc: torch.Tensor,
mult: int = 1,
offset: int = 0,
) -> None:
torch.ops.sgl_kernel.downcast_fp8(
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
)
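A usage sketch for the new wrapper (the shapes, dtypes, and module path are inferred from the kernel above and may not match the production caller exactly; a uint8 view of the fp8 caches may also be what real callers pass):

```python
import torch
from sgl_kernel.elementwise import downcast_fp8  # module path assumed from this diff

tokens, heads, dim, slots = 4, 8, 128, 64
k = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)
k_out = torch.empty(slots, heads, dim, device="cuda", dtype=torch.float8_e4m3fn)
v_out = torch.empty_like(k_out)
k_scale = torch.tensor([0.05], device="cuda")  # 1-element fp32 scale
v_scale = torch.tensor([0.05], device="cuda")
loc = torch.tensor([0, 1, 5, 7], device="cuda", dtype=torch.int64)

# Quantizes each k/v row to fp8 e4m3 and scatters it to slot loc * mult + offset.
downcast_fp8(k, v, k_out, v_out, k_scale, v_scale, loc, mult=1, offset=0)
```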
@@ -160,7 +160,7 @@ def fused_marlin_moe(
    size_m=M,
    size_n=2 * N,
    size_k=K,
    is_full_k=is_k_full,
    is_k_full=is_k_full,
    use_atomic_add=use_atomic_add,
    use_fp32_reduce=True,
    is_zp_float=False,
@@ -192,7 +192,7 @@ def fused_marlin_moe(
    size_m=M * topk,
    size_n=K,
    size_k=N,
    is_full_k=is_k_full,
    is_k_full=is_k_full,
    use_atomic_add=use_atomic_add,
    use_fp32_reduce=True,
    is_zp_float=False,
...
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
from sgl_kernel.utils import _to_tensor_scalar_tuple
...
@@ -14,7 +14,6 @@
# ==============================================================================
import functools
import subprocess
from typing import Dict, Tuple
import torch
...