Unverified commit d052f4c8, authored by Lianmin Zheng and committed by GitHub

New clang format for sgl kernel (#4194)

parent e1aaa79a
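
The hunks below are a pure reformatting pass: wherever a signature, call, or template parameter list exceeds the line limit, it is now broken immediately after the opening parenthesis, with one parameter or argument per line. The .clang-format file driving this is not part of the hunks shown here, so the options below are only a sketch of the kind of configuration that produces this layout, not the committed file:

# Hypothetical .clang-format sketch; assumed values, the actual committed file is not visible in this diff.
BasedOnStyle: Google
ColumnLimit: 120
AlignAfterOpenBracket: AlwaysBreak           # break right after '(' instead of aligning to it
BinPackParameters: false                     # one parameter per line once a declaration wraps
BinPackArguments: false                      # one argument per line once a call wraps
AllowAllParametersOfDeclarationOnNextLine: false
AllowAllArgumentsOnNextLine: false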
@@ -29,12 +29,19 @@ using namespace flashinfer;
// retrive_next_sibling: [bs, num_draft_tokens]
// uniform_samples: [bs, num_draft_tokens]
// target_probs: [bs, num_draft_tokens, vocab_size]
void tree_speculative_sampling_target_only(
    at::Tensor predicts,
    at::Tensor accept_index,
    at::Tensor accept_token_num,  // mutable
    at::Tensor candidates,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    at::Tensor uniform_samples,
    at::Tensor target_probs,
    at::Tensor draft_probs,
    bool deterministic,
    int64_t cuda_stream = 0) {
  CHECK_INPUT(candidates);
  CHECK_INPUT(retrive_index);
  CHECK_INPUT(retrive_next_token);
@@ -108,13 +115,24 @@ void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accep
  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
      static_cast<int*>(predicts.data_ptr()),
      static_cast<int*>(accept_index.data_ptr()),
      static_cast<int*>(accept_token_num.data_ptr()),
      static_cast<int*>(candidates.data_ptr()),
      static_cast<int*>(retrive_index.data_ptr()),
      static_cast<int*>(retrive_next_token.data_ptr()),
      static_cast<int*>(retrive_next_sibling.data_ptr()),
      static_cast<float*>(uniform_samples.data_ptr()),
      static_cast<float*>(target_probs.data_ptr()),
      static_cast<float*>(draft_probs.data_ptr()),
      batch_size,
      num_spec_step,
      num_draft_tokens,
      vocab_size,
      deterministic,
      stream);
  TORCH_CHECK(
      status == cudaSuccess,
      "TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
}
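
For orientation, a host-side call into this wrapper would look roughly like the sketch below. The dtypes (int32 index tensors, float32 probability tensors) follow the static_casts above; only the shapes of retrive_next_sibling, uniform_samples, and target_probs are spelled out in the comments, so every other shape here is an assumption, and the helper name is invented for the sketch.

// Hypothetical caller sketch; shapes marked "assumed" are not confirmed by this diff.
#include <torch/torch.h>

void run_tree_verification_example(int64_t bs, int64_t num_draft_tokens, int64_t num_spec_step, int64_t vocab_size) {
  auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);
  auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);

  auto predicts = torch::zeros({bs * num_draft_tokens}, i32);               // assumed flat layout
  auto accept_index = torch::full({bs, num_spec_step}, -1, i32);            // assumed shape
  auto accept_token_num = torch::zeros({bs}, i32);                          // assumed shape
  auto candidates = torch::zeros({bs, num_draft_tokens}, i32);              // assumed shape
  auto retrive_index = torch::zeros({bs, num_draft_tokens}, i32);           // assumed shape
  auto retrive_next_token = torch::full({bs, num_draft_tokens}, -1, i32);   // assumed shape
  auto retrive_next_sibling = torch::full({bs, num_draft_tokens}, -1, i32); // [bs, num_draft_tokens] per comment
  auto uniform_samples = torch::rand({bs, num_draft_tokens}, f32);          // [bs, num_draft_tokens] per comment
  auto target_probs =                                                       // [bs, num_draft_tokens, vocab_size]
      torch::softmax(torch::randn({bs, num_draft_tokens, vocab_size}, f32), /*dim=*/-1);
  auto draft_probs = torch::zeros_like(target_probs);

  tree_speculative_sampling_target_only(
      predicts, accept_index, accept_token_num, candidates, retrive_index, retrive_next_token,
      retrive_next_sibling, uniform_samples, target_probs, draft_probs, /*deterministic=*/true,
      /*cuda_stream=*/0);
}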
@@ -27,15 +27,29 @@ namespace sampling {
using namespace cub;

template <
    uint32_t BLOCK_THREADS,
    BlockScanAlgorithm SCAN_ALGORITHM,
    BlockReduceAlgorithm REDUCE_ALGORITHM,
    uint32_t VEC_SIZE,
    bool DETERMINISTIC,
    typename DType,
    typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(
    IdType* predicts,
    IdType* accept_index,
    IdType* accept_token_num,  // mutable
    IdType* candidates,
    IdType* retrive_index,
    IdType* retrive_next_token,
    IdType* retrive_next_sibling,
    DType* uniform_samples,
    DType* target_probs,
    DType* draft_probs,
    uint32_t batch_size,
    uint32_t num_speculative_tokens,
    uint32_t num_draft_tokens,
    uint32_t d) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
@@ -140,37 +154,54 @@ __global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* acce
}

template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(
    IdType* predicts,
    IdType* output_token_ids,
    IdType* output_accepted_token_num,  // mutable
    IdType* candidates,
    IdType* retrive_index,
    IdType* retrive_next_token,
    IdType* retrive_next_sibling,
    DType* uniform_samples,
    DType* target_probs,
    DType* draft_probs,
    uint32_t batch_size,
    uint32_t num_speculative_tokens,
    uint32_t num_draft_tokens,
    uint32_t d,
    bool deterministic,
    cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
  const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  void* args[] = {
      &predicts,
      &output_token_ids,
      &output_accepted_token_num,
      &candidates,
      &retrive_index,
      &retrive_next_token,
      &retrive_next_sibling,
      &uniform_samples,
      &target_probs,
      &draft_probs,
      &batch_size,
      &num_speculative_tokens,
      &num_draft_tokens,
      &d};

  DISPATCH_ALIGNED_VEC_SIZE(
      vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
        auto kernel = TreeSpeculativeSamplingTargetOnly<
            BLOCK_THREADS,
            SCAN_ALGO,
            REDUCE_ALGO,
            VEC_SIZE,
            DETERMINISTIC,
            DType,
            IdType>;
        FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
      })});
...
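
The dispatcher above uses the standard driver-style launch pattern: collect the address of every kernel argument in a void* array, raise the dynamic shared-memory limit with cudaFuncSetAttribute, then hand everything to cudaLaunchKernel. A minimal self-contained sketch of that pattern, with a toy kernel standing in for the real sampling kernel:

// Minimal illustration of the args-array launch pattern used above; the kernel is a toy stand-in.
#include <cstdint>
#include <cuda_runtime.h>

__global__ void scale_kernel(float* data, float factor, uint32_t n) {
  extern __shared__ float smem[];          // dynamic shared memory, sized at launch time
  uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    smem[threadIdx.x] = data[i] * factor;  // stage the value through shared memory
    data[i] = smem[threadIdx.x];
  }
}

cudaError_t launch_scale(float* d_data, float factor, uint32_t n, cudaStream_t stream) {
  constexpr uint32_t kBlockThreads = 256;
  dim3 nblks((n + kBlockThreads - 1) / kBlockThreads);
  dim3 nthrs(kBlockThreads);
  size_t smem_size = kBlockThreads * sizeof(float);

  // cudaLaunchKernel takes the *addresses* of the kernel arguments, in declaration order.
  void* args[] = {&d_data, &factor, &n};

  auto kernel = scale_kernel;
  cudaError_t status =
      cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem_size);
  if (status != cudaSuccess) {
    return status;
  }
  return cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream);
}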
@@ -42,8 +42,8 @@ using fptr_t = int64_t;
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void gemma_fused_add_rmsnorm(
    at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
@@ -53,113 +53,219 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 */
#ifdef USE_ROCM
// ROCM custom allreduce
fptr_t init_custom_ar(
    torch::Tensor& meta,
    torch::Tensor& rank_data,
    const std::vector<std::string>& handles,
    const std::vector<int64_t>& offsets,
    int64_t rank,
    bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(
    fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles, const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#else
// TRTLLM custom allreduce
fptr_t init_custom_ar(
    int64_t rank_id,
    int64_t world_size,
    torch::Tensor& rank_data,
    const std::vector<fptr_t>& buffers,
    const std::vector<fptr_t>& tmp_result_buffers,
    const std::vector<fptr_t>& barrier_in,
    const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
#endif

/*
 * From csrc/gemm
 */
torch::Tensor int8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype);
void sgl_per_token_group_quant_fp8(
    at::Tensor input,
    at::Tensor output_q,
    at::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max);
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
void cublas_grouped_gemm(
    const std::vector<torch::Tensor>& inputs,
    const std::vector<torch::Tensor>& weights,
    const std::vector<torch::Tensor>& outputs,
    const torch::Dtype& out_dtype,
    int64_t cublas_handle,
    int64_t cuda_stream);

/*
 * From csrc/moe
 */
void moe_align_block_size(
    torch::Tensor topk_ids,
    int64_t num_experts,
    int64_t block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor token_cnts_buffer,
    torch::Tensor cumsum_buffer);

/*
 * From csrc/speculative
 */
void tree_speculative_sampling_target_only(
    at::Tensor predicts,
    at::Tensor accept_index,
    at::Tensor accept_token_num,  // mutable
    at::Tensor candidates,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    at::Tensor uniform_samples,
    at::Tensor target_probs,
    at::Tensor draft_probs,
    bool deterministic = true,
    int64_t cuda_stream = 0);

void build_tree_kernel_efficient(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num);

void build_tree_kernel(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num);

/*
 * From FlashInfer
 */
void bmm_fp8(
    at::Tensor A,
    at::Tensor B,
    at::Tensor D,
    at::Tensor A_scale,
    at::Tensor B_scale,
    at::Tensor workspace_buffer,
    int64_t cublas_handle,
    int64_t cuda_stream);
void min_p_sampling_from_probs(
    at::Tensor probs,
    at::Tensor uniform_samples,
    at::Tensor samples,
    std::optional<at::Tensor> maybe_min_p_arr,
    double min_p_val,
    bool deterministic,
    int64_t cuda_stream);
// top k renorm probs
// Patch: flashinfer uses unsigned int here, but torch extensions must take int64_t.
void top_k_renorm_probs(
    at::Tensor probs,
    at::Tensor renorm_probs,
    std::optional<at::Tensor> maybe_top_k_arr,
    unsigned int top_k_val,
    int64_t cuda_stream);

// Patch: flashinfer uses unsigned int here, but torch extensions must take int64_t.
inline void top_k_renorm_probs_wrapper(
    at::Tensor probs,
    at::Tensor renorm_probs,
    std::optional<at::Tensor> maybe_top_k_arr,
    int64_t top_k_val,
    int64_t cuda_stream) {
  top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
}
void top_p_renorm_probs(
    at::Tensor probs,
    at::Tensor renorm_probs,
    std::optional<at::Tensor> maybe_top_p_arr,
    double top_p_val,
    int64_t cuda_stream);
void top_k_top_p_sampling_from_probs(
    at::Tensor probs,
    at::Tensor uniform_samples,
    at::Tensor samples,
    at::Tensor success,
    std::optional<at::Tensor> maybe_top_k_arr,
    double top_k_val,
    std::optional<at::Tensor> maybe_top_p_arr,
    double top_p_val,
    bool deterministic,
    int64_t cuda_stream);
void top_p_sampling_from_probs(
    at::Tensor probs,
    at::Tensor uniform_samples,
    at::Tensor samples,
    at::Tensor success,
    std::optional<at::Tensor> maybe_top_p_arr,
    double top_p_val,
    bool deterministic,
    int64_t cuda_stream);
void apply_rope_pos_ids_cos_sin_cache(
    at::Tensor q,
    at::Tensor k,
    at::Tensor q_rope,
    at::Tensor k_rope,
    at::Tensor cos_sin_cache,
    at::Tensor pos_ids,
    bool interleave,
    int64_t cuda_stream);

/*
 * Other
 */
void lightning_attention_decode(
    const torch::Tensor& q,
    const torch::Tensor& k,
    const torch::Tensor& v,
    const torch::Tensor& past_kv,
    const torch::Tensor& slope,
    torch::Tensor output,
    torch::Tensor new_kv);

// sgl_per_token_quant_fp8
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
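
The int64_t-taking wrapper above (top_k_renorm_probs_wrapper) exists because the torch extension schema has only a 64-bit integer type, so every integer crossing the Python/C++ boundary must arrive as int64_t and be narrowed before reaching the unsigned-int flashinfer-style function. How sgl-kernel actually binds these declarations is not visible in this diff; the fragment below is purely an illustrative registration using the standard TORCH_LIBRARY machinery, with a made-up namespace and the alias annotations omitted for brevity:

// Illustrative registration only; sgl-kernel's real binding code is not shown in this diff.
#include <torch/library.h>

TORCH_LIBRARY_FRAGMENT(sgl_kernel_example, m) {
  // "int" in a torch schema is always int64_t on the C++ side, which is why the wrapper
  // accepts int64_t and static_casts to unsigned int before forwarding.
  m.def(
      "top_k_renorm_probs(Tensor probs, Tensor renorm_probs, Tensor? maybe_top_k_arr, "
      "int top_k_val, int cuda_stream) -> ()");
  m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs_wrapper);
}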
@@ -103,7 +103,7 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size, int world
  return AllReduceStrategyType::TWOSHOT;
}

void trtCustomAllReduce(
    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream);

}  // namespace trt_llm
@@ -95,7 +95,6 @@ inline int getSMVersion() {
  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
#define WARP_SIZE 32
#ifndef USE_ROCM
...