Unverified Commit c480a3f6 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Minor style fixes for sgl-kernel (#9289)

parent 6e316588
......@@ -73,6 +73,20 @@ If you modify files protected by code owners, their approval is required to merg
- Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code.
- Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files.
- Prioritize extreme efficiency. SGLang is a runtime, and most of your code runs on the critical path for every request. Optimize every minor overhead as much as possible.
- Try to make functions as pure as possible. Avoid in-place modification of arguments.
## How to update sgl-kernel
Since sglang and sgl-kernel are separate Python packages, our current GitHub CI infrastructure does not support updating a kernel and using it immediately within the same pull request (PR). To add a new kernel or modify an existing one in the sgl-kernel package, you must use multiple PRs.
Follow these steps:
1. Submit a PR to update the sgl-kernel source code without using it (e.g., [#8884](https://github.com/sgl-project/sglang/pull/8884/files)).
2. Bump the version of sgl-kernel (e.g., [#9220](https://github.com/sgl-project/sglang/pull/9220/files)).
- Once merged, this will trigger an automatic release of the sgl-kernel wheel to PyPI.
- If not urgent, you can wait for other people to release the wheel. A new version will typically be released within one week.
3. Apply the changes:
- Update the sgl-kernel version in `sglang/python/pyproject.toml` to use the modified kernels.
- Update the related caller code in the sglang to use the new kernel.
## Tips for newcomers
......
......@@ -39,9 +39,9 @@ runtime_common = [
"pillow",
"prometheus-client>=0.20.0",
"psutil",
"pybase64",
"pydantic",
"pynvml",
"pybase64",
"python-multipart",
"pyzmq>=25.1.2",
"sentencepiece",
......
......@@ -12,7 +12,6 @@ from dataclasses import dataclass
import httpx
import numpy as np
import openai
import transformers
from datasets import load_dataset
from openai import AsyncOpenAI
from tqdm import tqdm
......
......@@ -9,7 +9,6 @@ import argparse
import json
import os
import time
import urllib.parse
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional
......
......@@ -5,7 +5,6 @@ import json
import logging
import os
import random
import signal
import socket
import subprocess
import sys
......
......@@ -36,7 +36,7 @@ def read_records(files):
def run_one_request_internal(record):
(req, output, replay_init_time, start_time, end_time, idx) = record
time.sleep(max(0, start_time - (time.time() - replay_init_time)))
time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))
if "completion_tokens" in output.get("meta_info", {}):
recorded_completion_tokens = output["meta_info"]["completion_tokens"]
......@@ -121,6 +121,7 @@ if __name__ == "__main__":
parser.add_argument("--parallel", type=int, default=512)
parser.add_argument("--idx", type=int, default=None)
parser.add_argument("--ignore-eos", action="store_true")
parser.add_argument("--speed", type=float, default=1)
args = parser.parse_args()
set_ulimit()
......
......@@ -223,17 +223,19 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
set(SOURCES
"csrc/allreduce/mscclpp_allreduce.cu"
"csrc/allreduce/custom_all_reduce.cu"
"csrc/allreduce/mscclpp_allreduce.cu"
"csrc/attention/cascade.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/cutlass_mla_kernel.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/attention/lightning_attention_decode_kernel.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/elementwise/activation.cu"
"csrc/elementwise/cast.cu"
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
"csrc/elementwise/rope.cu"
"csrc/common_extension.cc"
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu"
......@@ -257,7 +259,9 @@ set(SOURCES
"csrc/gemm/marlin/gptq_marlin_repack.cu"
"csrc/gemm/marlin/awq_marlin_repack.cu"
"csrc/gemm/gptq/gptq_kernel.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
......@@ -276,14 +280,18 @@ set(SOURCES
"csrc/moe/prepare_moe_input.cu"
"csrc/moe/ep_moe_reorder_kernel.cu"
"csrc/moe/ep_moe_silu_and_mul_kernel.cu"
"csrc/memory/store.cu"
"csrc/kvcacheio/transfer.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/packbit.cu"
"csrc/speculative/speculative_sampling.cu"
"csrc/memory/store.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
......
......@@ -17,6 +17,7 @@ limitations under the License.
#include <torch/library.h>
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/allreduce
......@@ -93,6 +94,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
m.def(
"downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
"mult, int offset, int cuda_stream) -> ()");
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
/*
* From csrc/gemm
*/
......@@ -161,7 +167,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
// GPTQ related method
/*
* From csrc/gemm/gptq
*/
m.def(
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
......@@ -183,6 +191,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
/*
* From csrc/moe
*/
......@@ -229,6 +238,41 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
/*
* From csrc/moe/marlin_moe_wna16
*/
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
m.def(
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()");
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
m.def(
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From csrc/speculative
*/
......@@ -306,25 +350,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
m.impl("store_kv_cache", &store_kv_cache);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
m.def(
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()");
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
m.def(
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From FlashInfer
*/
......@@ -358,19 +383,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From Sparse Flash Attention
*/
......
......@@ -33,6 +33,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
/*
* From csrc/allreduce
*/
......
#include "pytorch_extension_utils.h"
template <typename T>
struct ConvertToFP8 {
static __device__ __nv_fp8_storage_t convert_to_fp8(T value) {
return 0;
}
};
template <>
struct ConvertToFP8<__nv_bfloat16> {
static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) {
return __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
}
};
template <>
struct ConvertToFP8<half> {
static __device__ __nv_fp8_storage_t convert_to_fp8(half value) {
return __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
}
};
template <typename T>
struct ConvertFromFloat {
static __device__ T convert_from_float(float value) {
return 0;
}
};
template <>
struct ConvertFromFloat<__nv_bfloat16> {
static __device__ __nv_bfloat16 convert_from_float(float value) {
return __float2bfloat16(value);
}
};
template <>
struct ConvertFromFloat<half> {
static __device__ half convert_from_float(float value) {
return __float2half(value);
}
};
template <typename T>
__global__ void fused_downcast_kernel(
const T* cache_k,
const T* cache_v,
const float* k_scale,
const float* v_scale,
__nv_fp8_storage_t* output_k,
__nv_fp8_storage_t* output_v,
const int input_sl,
const int head,
const int dim,
const T max_fp8,
const T min_fp8,
const int64_t mult,
const int64_t offset,
const int64_t* loc) {
// TODO: change name
int token_idx = blockIdx.x;
int thread_idx = threadIdx.x;
int total_threads = blockDim.x;
T k_scale_val = ConvertFromFloat<T>::convert_from_float(k_scale[0]);
T v_scale_val = ConvertFromFloat<T>::convert_from_float(v_scale[0]);
T k_scale_inv = static_cast<T>(1.f) / k_scale_val;
T v_scale_inv = static_cast<T>(1.f) / v_scale_val;
auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); };
if (token_idx < input_sl) {
int out_seq_idx = loc[token_idx];
#pragma unroll
for (int i = thread_idx; i < head * dim; i += total_threads) {
int in_idx = token_idx * head * dim + i;
int out_idx = (out_seq_idx * mult + offset) * head * dim + i;
T k_val = cache_k[in_idx] * k_scale_inv;
k_val = clamp(k_val);
output_k[out_idx] = ConvertToFP8<T>::convert_to_fp8(k_val);
T v_val = cache_v[in_idx] * v_scale_inv;
v_val = clamp(v_val);
output_v[out_idx] = ConvertToFP8<T>::convert_to_fp8(v_val);
}
}
}
template <typename T>
void downcast_fp8_impl(
at::Tensor& k,
at::Tensor& v,
at::Tensor& k_out,
at::Tensor& v_out,
at::Tensor& k_scale,
at::Tensor& v_scale,
at::Tensor& loc,
int64_t mult,
int64_t offset,
cudaStream_t stream) {
CHECK_INPUT(k);
CHECK_INPUT(v);
CHECK_INPUT(k_out);
CHECK_INPUT(v_out);
CHECK_INPUT(k_scale);
CHECK_INPUT(v_scale);
CHECK_INPUT(loc);
int64_t input_sl = k.size(0);
int64_t head = k.size(1);
int64_t dim = k.size(2);
dim3 grid(input_sl * head);
int vec_size = 8;
dim3 block(std::min(int(dim) / vec_size, 1024));
const T max_fp8 = static_cast<T>(448.0f);
const T min_fp8 = static_cast<T>(-448.0f);
fused_downcast_kernel<T><<<grid, block, 0, stream>>>(
static_cast<const T*>(k.data_ptr()),
static_cast<const T*>(v.data_ptr()),
static_cast<const float*>(k_scale.data_ptr()),
static_cast<const float*>(v_scale.data_ptr()),
static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()),
static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()),
input_sl,
head,
dim,
max_fp8,
min_fp8,
mult,
offset,
static_cast<const int64_t*>(loc.data_ptr()));
cudaError_t status = cudaGetLastError();
TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status)));
}
void downcast_fp8(
at::Tensor& k,
at::Tensor& v,
at::Tensor& k_out,
at::Tensor& v_out,
at::Tensor& k_scale,
at::Tensor& v_scale,
at::Tensor& loc,
int64_t mult,
int64_t offset,
int64_t cuda_stream) {
CHECK_INPUT(k);
CHECK_INPUT(v);
CHECK_INPUT(k_out);
CHECK_INPUT(v_out);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
switch (k.scalar_type()) {
case at::ScalarType::BFloat16:
downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
break;
case at::ScalarType::Half:
downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
break;
default:
TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16.");
}
}
......@@ -122,6 +122,95 @@ __global__ void build_tree_efficient(
}
}
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask: [draft_token*num_bytes_per_item | .. ] = [bs*draft_token*num_bytes_per_item]
// positions [bs * draft_token]
// retrive_index [bs, draft_token]
// retrive_next_token [bs, draft_token]
// retrive_next_sibling [bs, draft_token]
__global__ void build_tree_efficient_partial_packed(
int64_t* parent_list,
int64_t* selected_index,
int64_t* verified_seq_len,
uint8_t* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num,
size_t num_bytes_per_item) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num) {
return;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item;
tree_mask[token_tree_idx] = 1; // little endian
int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;
int retrive_index_offset = bid * draft_token_num;
for (int i = draft_token_num - 1; i > 0; --i) {
int current_token_idx = retrive_index_offset + i;
retrive_index[bid * draft_token_num + i] = current_token_idx;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
int parent_position = 0;
if (parent_tb_idx > 0) {
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (; parent_position < draft_token_num; ++parent_position) {
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
++parent_position;
break;
}
}
}
if (parent_position == draft_token_num) {
printf(
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
continue;
}
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
retrive_next_token[bid * draft_token_num + parent_position] = i;
} else {
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
retrive_next_token[bid * draft_token_num + parent_position] = i;
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
}
}
retrive_index[bid * draft_token_num] = bid * draft_token_num;
} else {
int cur_position = tid - 1;
while (true) {
position += 1;
int byte_idx = (cur_position + 1) / 8;
int bit_idx = (cur_position + 1) % 8;
tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx);
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
}
positions[bid * draft_token_num + tid] = position + seq_len;
}
}
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
......@@ -149,7 +238,19 @@ void build_tree_kernel_efficient(
} else if (draft_token_num > 8) {
num_bytes_per_item = 2;
}
throw std::runtime_error("Not implemented");
build_tree_efficient_partial_packed<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<uint8_t*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num),
num_bytes_per_item);
} else {
build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
......
......@@ -130,6 +130,7 @@ int64_t cutlass_mla_get_workspace_size(
int64_t num_batches,
int64_t sm_count = 0,
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
/*
* From csrc/elementwise
*/
......@@ -156,9 +157,22 @@ void apply_rope_pos_ids_cos_sin_cache(
const std::optional<at::Tensor>& v_buffer,
const std::optional<at::Tensor>& kv_cache_loc);
void downcast_fp8(
at::Tensor& k,
at::Tensor& v,
at::Tensor& k_out,
at::Tensor& v_out,
at::Tensor& k_scale,
at::Tensor& v_scale,
at::Tensor& loc,
int64_t mult,
int64_t offset,
int64_t cuda_stream);
#ifdef USE_ROCM
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif
/*
* From csrc/gemm
*/
......@@ -221,7 +235,6 @@ void bmm_fp8(
int64_t cublas_handle,
int64_t cuda_stream);
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
torch::Tensor gptq_marlin_gemm(
......@@ -258,6 +271,7 @@ torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
/*
* From csrc/moe
*/
......@@ -374,6 +388,61 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
torch::Tensor& sorted_token_ids,
torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded,
torch::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);
/*
* From csrc/speculative
*/
......@@ -527,35 +596,6 @@ void transfer_kv_direct(
const at::Tensor dst_indices,
int64_t page_size);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
/*
* From FlashInfer
*/
......@@ -597,32 +637,6 @@ void top_p_sampling_from_probs(
void top_k_mask_logits(
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
torch::Tensor& sorted_token_ids,
torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded,
torch::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);
namespace flash {
/*
* From fa2 sparse
......
......@@ -31,11 +31,11 @@ from sgl_kernel.elementwise import (
rmsnorm,
silu_and_mul,
)
from sgl_kernel.fused_moe import fused_marlin_moe
if torch.version.hip is not None:
from sgl_kernel.elementwise import gelu_quick
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
awq_dequantize,
bmm_fp8,
......@@ -114,7 +114,3 @@ from sgl_kernel.speculative import (
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.version import __version__
build_tree_kernel = (
None # TODO(ying): remove this after updating the sglang python code.
)
from dataclasses import dataclass
from typing import Any, Optional
from typing import Optional
import torch
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
......@@ -345,3 +345,19 @@ def apply_rope_with_cos_sin_cache_inplace(
else None
),
)
def downcast_fp8(
k: torch.Tensor,
v: torch.Tensor,
k_out: torch.Tensor,
v_out: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
loc: torch.Tensor,
mult: int = 1,
offset: int = 0,
) -> None:
torch.ops.sgl_kernel.downcast_fp8(
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
)
......@@ -160,7 +160,7 @@ def fused_marlin_moe(
size_m=M,
size_n=2 * N,
size_k=K,
is_full_k=is_k_full,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
......@@ -192,7 +192,7 @@ def fused_marlin_moe(
size_m=M * topk,
size_n=K,
size_k=N,
is_full_k=is_k_full,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
......
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
from sgl_kernel.utils import _to_tensor_scalar_tuple
......
......@@ -14,7 +14,6 @@
# ==============================================================================
import functools
import subprocess
from typing import Dict, Tuple
import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment