"vscode:/vscode.git/clone" did not exist on "af48bf200860d8b83fe3be92b2d7ae556a3b4111"
Unverified commit d052f4c8 authored by Lianmin Zheng, committed by GitHub

New clang format for sgl kernel (#4194)

parent e1aaa79a
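
For context on the formatting change itself, here is a minimal .clang-format sketch that would produce the wrapping style visible in the new lines of this diff (break after the opening parenthesis, one argument or parameter per line whenever a declaration or call exceeds the column limit, closing parenthesis kept on the last argument line). The options below are standard clang-format settings, but they are an assumption; the exact configuration added to sgl-kernel may differ.

BasedOnStyle: Google
ColumnLimit: 120
BinPackArguments: false
BinPackParameters: false
AlignAfterOpenBracket: AlwaysBreak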
@@ -29,12 +29,19 @@ using namespace flashinfer;
// retrive_next_sibling: [bs, num_draft_tokens]
// uniform_samples: [bs, num_draft_tokens]
// target_probs: [bs, num_draft_tokens, vocab_size]
void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates, at::Tensor retrive_index,
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
bool deterministic, int64_t cuda_stream = 0) {
void tree_speculative_sampling_target_only(
at::Tensor predicts,
at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor uniform_samples,
at::Tensor target_probs,
at::Tensor draft_probs,
bool deterministic,
int64_t cuda_stream = 0) {
CHECK_INPUT(candidates);
CHECK_INPUT(retrive_index);
CHECK_INPUT(retrive_next_token);
@@ -108,13 +115,24 @@ void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accep
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
static_cast<int*>(predicts.data_ptr()), static_cast<int*>(accept_index.data_ptr()),
static_cast<int*>(accept_token_num.data_ptr()), static_cast<int*>(candidates.data_ptr()),
static_cast<int*>(retrive_index.data_ptr()), static_cast<int*>(retrive_next_token.data_ptr()),
static_cast<int*>(retrive_next_sibling.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
static_cast<float*>(target_probs.data_ptr()), static_cast<float*>(draft_probs.data_ptr()), batch_size,
num_spec_step, num_draft_tokens, vocab_size, deterministic, stream);
static_cast<int*>(predicts.data_ptr()),
static_cast<int*>(accept_index.data_ptr()),
static_cast<int*>(accept_token_num.data_ptr()),
static_cast<int*>(candidates.data_ptr()),
static_cast<int*>(retrive_index.data_ptr()),
static_cast<int*>(retrive_next_token.data_ptr()),
static_cast<int*>(retrive_next_sibling.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()),
static_cast<float*>(target_probs.data_ptr()),
static_cast<float*>(draft_probs.data_ptr()),
batch_size,
num_spec_step,
num_draft_tokens,
vocab_size,
deterministic,
stream);
TORCH_CHECK(status == cudaSuccess,
"TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
TORCH_CHECK(
status == cudaSuccess,
"TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
}
@@ -27,15 +27,29 @@ namespace sampling {
using namespace cub;
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM, BlockReduceAlgorithm REDUCE_ALGORITHM,
uint32_t VEC_SIZE, bool DETERMINISTIC, typename DType, typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* accept_index,
IdType* accept_token_num, // mutable
IdType* candidates, IdType* retrive_index, IdType* retrive_next_token,
IdType* retrive_next_sibling, DType* uniform_samples,
DType* target_probs, DType* draft_probs, uint32_t batch_size,
uint32_t num_speculative_tokens, uint32_t num_draft_tokens,
uint32_t d) {
template <
uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM,
uint32_t VEC_SIZE,
bool DETERMINISTIC,
typename DType,
typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(
IdType* predicts,
IdType* accept_index,
IdType* accept_token_num, // mutable
IdType* candidates,
IdType* retrive_index,
IdType* retrive_next_token,
IdType* retrive_next_sibling,
DType* uniform_samples,
DType* target_probs,
DType* draft_probs,
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens,
uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
@@ -140,37 +154,54 @@ __global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* acce
}
template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* output_token_ids,
IdType* output_accepted_token_num, // mutable
IdType* candidates, IdType* retrive_index, IdType* retrive_next_token,
IdType* retrive_next_sibling, DType* uniform_samples, DType* target_probs,
DType* draft_probs, uint32_t batch_size, uint32_t num_speculative_tokens,
uint32_t num_draft_tokens, uint32_t d, bool deterministic,
cudaStream_t stream = 0) {
cudaError_t TreeSpeculativeSamplingTargetOnly(
IdType* predicts,
IdType* output_token_ids,
IdType* output_accepted_token_num, // mutable
IdType* candidates,
IdType* retrive_index,
IdType* retrive_next_token,
IdType* retrive_next_sibling,
DType* uniform_samples,
DType* target_probs,
DType* draft_probs,
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens,
uint32_t d,
bool deterministic,
cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&predicts,
&output_token_ids,
&output_accepted_token_num,
&candidates,
&retrive_index,
&retrive_next_token,
&retrive_next_sibling,
&uniform_samples,
&target_probs,
&draft_probs,
&batch_size,
&num_speculative_tokens,
&num_draft_tokens,
&d};
void* args[] = {
&predicts,
&output_token_ids,
&output_accepted_token_num,
&candidates,
&retrive_index,
&retrive_next_token,
&retrive_next_sibling,
&uniform_samples,
&target_probs,
&draft_probs,
&batch_size,
&num_speculative_tokens,
&num_draft_tokens,
&d};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel = TreeSpeculativeSamplingTargetOnly<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, DETERMINISTIC,
DType, IdType>;
auto kernel = TreeSpeculativeSamplingTargetOnly<
BLOCK_THREADS,
SCAN_ALGO,
REDUCE_ALGO,
VEC_SIZE,
DETERMINISTIC,
DType,
IdType>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
})});
@@ -42,8 +42,8 @@ using fptr_t = int64_t;
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
int64_t cuda_stream);
void gemma_fused_add_rmsnorm(
at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
@@ -53,113 +53,219 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
*/
#ifdef USE_ROCM
// ROCM custom allreduce
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
fptr_t init_custom_ar(
torch::Tensor& meta,
torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets,
int64_t rank,
bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets);
void register_buffer(
fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles, const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
void register_graph_buffers(
fptr_t _fa, const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#else
// TRTLLM custom allreduce
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out);
fptr_t init_custom_ar(
int64_t rank_id,
int64_t world_size,
torch::Tensor& rank_data,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
#endif
/*
* From csrc/gemm
*/
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const torch::Dtype& out_dtype);
void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
double eps, double fp8_min, double fp8_max);
torch::Tensor int8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_blockwise_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype);
void sgl_per_token_group_quant_fp8(
at::Tensor input,
at::Tensor output_q,
at::Tensor output_s,
int64_t group_size,
double eps,
double fp8_min,
double fp8_max);
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype,
int64_t cublas_handle, int64_t cuda_stream);
void cublas_grouped_gemm(
const std::vector<torch::Tensor>& inputs,
const std::vector<torch::Tensor>& weights,
const std::vector<torch::Tensor>& outputs,
const torch::Dtype& out_dtype,
int64_t cublas_handle,
int64_t cuda_stream);
/*
* From csrc/moe
*/
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
void moe_align_block_size(
torch::Tensor topk_ids,
int64_t num_experts,
int64_t block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer,
torch::Tensor cumsum_buffer);
/*
* From csrc/speculative
*/
void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates, at::Tensor retrive_index,
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
bool deterministic = true, int64_t cuda_stream = 0);
void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
int64_t depth, int64_t draft_token_num);
void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
int64_t depth, int64_t draft_token_num);
void tree_speculative_sampling_target_only(
at::Tensor predicts,
at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor uniform_samples,
at::Tensor target_probs,
at::Tensor draft_probs,
bool deterministic = true,
int64_t cuda_stream = 0);
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num);
void build_tree_kernel(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
int64_t topk,
int64_t depth,
int64_t draft_token_num);
/*
* From FlashInfer
*/
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic,
int64_t cuda_stream);
void bmm_fp8(
at::Tensor A,
at::Tensor B,
at::Tensor D,
at::Tensor A_scale,
at::Tensor B_scale,
at::Tensor workspace_buffer,
int64_t cublas_handle,
int64_t cuda_stream);
void min_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
at::Tensor samples,
std::optional<at::Tensor> maybe_min_p_arr,
double min_p_val,
bool deterministic,
int64_t cuda_stream);
// top k renorm probs
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
unsigned int top_k_val, int64_t cuda_stream);
void top_k_renorm_probs(
at::Tensor probs,
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr,
unsigned int top_k_val,
int64_t cuda_stream);
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
int64_t cuda_stream) {
inline void top_k_renorm_probs_wrapper(
at::Tensor probs,
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr,
int64_t top_k_val,
int64_t cuda_stream) {
top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
}
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val, int64_t cuda_stream);
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
int64_t cuda_stream);
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
int64_t cuda_stream);
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
int64_t cuda_stream);
void top_p_renorm_probs(
at::Tensor probs,
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
int64_t cuda_stream);
void top_k_top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
at::Tensor samples,
at::Tensor success,
std::optional<at::Tensor> maybe_top_k_arr,
double top_k_val,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
bool deterministic,
int64_t cuda_stream);
void top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
at::Tensor samples,
at::Tensor success,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
bool deterministic,
int64_t cuda_stream);
void apply_rope_pos_ids_cos_sin_cache(
at::Tensor q,
at::Tensor k,
at::Tensor q_rope,
at::Tensor k_rope,
at::Tensor cos_sin_cache,
at::Tensor pos_ids,
bool interleave,
int64_t cuda_stream);
/*
* Other
*/
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
torch::Tensor new_kv);
void lightning_attention_decode(
const torch::Tensor& q,
const torch::Tensor& k,
const torch::Tensor& v,
const torch::Tensor& past_kv,
const torch::Tensor& slope,
torch::Tensor output,
torch::Tensor new_kv);
// sgl_per_token_quant_fp8
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
@@ -103,7 +103,7 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size, int world
return AllReduceStrategyType::TWOSHOT;
}
void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
cudaStream_t stream);
void trtCustomAllReduce(
AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream);
} // namespace trt_llm
@@ -95,7 +95,6 @@ inline int getSMVersion() {
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
#define WARP_SIZE 32
#ifndef USE_ROCM