sglang / Commits / d052f4c8
"vscode:/vscode.git/clone" did not exist on "af48bf200860d8b83fe3be92b2d7ae556a3b4111"
Unverified commit d052f4c8, authored Mar 07, 2025 by Lianmin Zheng; committed via GitHub on Mar 07, 2025.
New clang format for sgl kernel (#4194)
Parent: e1aaa79a
Changes: 25 in total; showing 5 changed files with 273 additions and 119 deletions (+273 −119).
sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu    +32 −14
sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh   +63 −32
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h                   +176 −70
sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh             +2 −2
sgl-kernel/src/sgl-kernel/include/utils.h                             +0 −1
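Since this is a formatting-only commit, the removed and added sides of each hunk contain the same tokens and differ only in line wrapping. As a rough illustration of the kind of change a new .clang-format profile produces on these signatures (the concrete options and the exact old layout are assumptions, not shown in this view):

// Hypothetical before/after of a formatting-only change; the real .clang-format
// options (ColumnLimit, AlignAfterOpenBracket, ...) are not part of this page.
#include <cstdint>

// Assumed old style: arguments aligned with the opening parenthesis, shorter column limit.
void old_style_example(int64_t batch_size, int64_t num_draft_tokens,
                       int64_t vocab_size, bool deterministic,
                       int64_t cuda_stream);

// Assumed new style: break after the opening parenthesis, plain continuation indent, longer lines.
void new_style_example(
    int64_t batch_size, int64_t num_draft_tokens, int64_t vocab_size, bool deterministic, int64_t cuda_stream);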
sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu (view file @ d052f4c8)
...
...
@@ -29,12 +29,19 @@ using namespace flashinfer;
// retrive_next_sibling: [bs, num_draft_tokens]
// uniform_samples: [bs, num_draft_tokens]
// target_probs: [bs, num_draft_tokens, vocab_size]
void tree_speculative_sampling_target_only(
    at::Tensor predicts, at::Tensor accept_index, at::Tensor accept_token_num,  // mutable
    at::Tensor candidates, at::Tensor retrive_index, at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling, at::Tensor uniform_samples, at::Tensor target_probs,
    at::Tensor draft_probs, bool deterministic, int64_t cuda_stream = 0) {
  CHECK_INPUT(candidates);
  CHECK_INPUT(retrive_index);
  CHECK_INPUT(retrive_next_token);
...
...
@@ -108,13 +115,24 @@ void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accep
  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
      static_cast<int*>(predicts.data_ptr()),
      static_cast<int*>(accept_index.data_ptr()),
      static_cast<int*>(accept_token_num.data_ptr()),
      static_cast<int*>(candidates.data_ptr()),
      static_cast<int*>(retrive_index.data_ptr()),
      static_cast<int*>(retrive_next_token.data_ptr()),
      static_cast<int*>(retrive_next_sibling.data_ptr()),
      static_cast<float*>(uniform_samples.data_ptr()),
      static_cast<float*>(target_probs.data_ptr()),
      static_cast<float*>(draft_probs.data_ptr()),
      batch_size, num_spec_step, num_draft_tokens, vocab_size, deterministic, stream);

  TORCH_CHECK(
      status == cudaSuccess,
      "TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
}
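The CHECK_INPUT calls above follow the usual PyTorch C++ extension convention of validating that each tensor is a contiguous CUDA tensor. A minimal sketch of such macros, assuming the common pattern (the actual definitions live elsewhere in sgl-kernel and are not part of this diff):

// Hedged sketch of typical PyTorch-extension input checks; names follow the
// common convention, but the real sgl-kernel macros may differ.
#include <torch/extension.h>

#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)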
sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh (view file @ d052f4c8)
...
...
@@ -27,15 +27,29 @@ namespace sampling {
using namespace cub;
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM, BlockReduceAlgorithm REDUCE_ALGORITHM,
          uint32_t VEC_SIZE, bool DETERMINISTIC, typename DType, typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(
    IdType* predicts, IdType* accept_index, IdType* accept_token_num,  // mutable
    IdType* candidates, IdType* retrive_index, IdType* retrive_next_token, IdType* retrive_next_sibling,
    DType* uniform_samples, DType* target_probs, DType* draft_probs, uint32_t batch_size,
    uint32_t num_speculative_tokens, uint32_t num_draft_tokens, uint32_t d) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
...
...
@@ -140,37 +154,54 @@ __global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* acce
}
template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(
    IdType* predicts, IdType* output_token_ids, IdType* output_accepted_token_num,  // mutable
    IdType* candidates, IdType* retrive_index, IdType* retrive_next_token, IdType* retrive_next_sibling,
    DType* uniform_samples, DType* target_probs, DType* draft_probs, uint32_t batch_size,
    uint32_t num_speculative_tokens, uint32_t num_draft_tokens, uint32_t d, bool deterministic,
    cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);

  const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  void* args[] = {&predicts, &output_token_ids, &output_accepted_token_num, &candidates,
                  &retrive_index, &retrive_next_token, &retrive_next_sibling, &uniform_samples,
                  &target_probs, &draft_probs, &batch_size, &num_speculative_tokens,
                  &num_draft_tokens, &d};

  DISPATCH_ALIGNED_VEC_SIZE(
      vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
        auto kernel = TreeSpeculativeSamplingTargetOnly<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE,
                                                        DETERMINISTIC, DType, IdType>;
        FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
      })});
...
...
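The host dispatcher above launches the kernel through cudaLaunchKernel with a void* argument array rather than the <<<...>>> syntax, which lets the grid, block, shared-memory size, and stream be chosen at runtime inside the dispatch macros. A self-contained sketch of the same pattern, with illustrative names that are not from this diff:

// Minimal example of the void* args[] + cudaLaunchKernel launch pattern.
#include <cstdint>
#include <cuda_runtime.h>

__global__ void scale_kernel(float* data, float factor, uint32_t n) {
  uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

cudaError_t launch_scale(float* data, float factor, uint32_t n, cudaStream_t stream) {
  constexpr uint32_t BLOCK_THREADS = 256;
  dim3 nblks((n + BLOCK_THREADS - 1) / BLOCK_THREADS);
  dim3 nthrs(BLOCK_THREADS);
  void* args[] = {&data, &factor, &n};  // one pointer per kernel parameter, in declaration order
  return cudaLaunchKernel((void*)scale_kernel, nblks, nthrs, args, /*sharedMem=*/0, stream);
}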
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h (view file @ d052f4c8)
...
...
@@ -42,8 +42,8 @@ using fptr_t = int64_t;
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
                             int64_t cuda_stream);
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
...
...
@@ -53,113 +53,219 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
*/
#ifdef USE_ROCM
// ROCM custom allreduce
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#else
// TRTLLM custom allreduce
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data,
                      const std::vector<fptr_t>& buffers, const std::vector<fptr_t>& tmp_result_buffers,
                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif
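The fptr_t (= int64_t) handles used in both branches above are the usual opaque-pointer idiom for custom-allreduce state: init_custom_ar returns an integer that encodes a C++ object pointer, and dispose casts it back and frees it. A hypothetical sketch of that idiom, with a made-up state type that is not from sgl-kernel:

// Illustrative opaque-handle lifecycle; class name and fields are invented.
#include <cstdint>

using fptr_t = int64_t;

struct CustomAllreduceState {
  int rank;
  int world_size;
  // ... IPC buffers, barrier flags, etc.
};

fptr_t make_handle(int rank, int world_size) {
  auto* state = new CustomAllreduceState{rank, world_size};
  return reinterpret_cast<fptr_t>(state);  // handed to Python as a plain integer
}

void release_handle(fptr_t handle) {
  delete reinterpret_cast<CustomAllreduceState*>(handle);  // what a dispose()-style call typically does
}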
/*
* From csrc/gemm
*/
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                             const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                            const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                            const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                                      const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                                      const torch::Dtype& out_dtype);
void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
                                   double eps, double fp8_min, double fp8_max);
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
                         const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype,
                         int64_t cublas_handle, int64_t cuda_stream);
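Several of the GEMM entry points above take a c10::optional<torch::Tensor> bias, which maps to an optional/None argument on the Python side. A hedged sketch of how such an optional is usually consumed (illustrative only; the real int8_scaled_mm/fp8_scaled_mm bodies are not in this diff):

// Illustrative handling of an optional bias tensor in a PyTorch C++ op.
#include <torch/extension.h>

torch::Tensor add_optional_bias(const torch::Tensor& out_before_bias, const c10::optional<torch::Tensor>& bias) {
  if (bias.has_value()) {
    return out_before_bias + bias.value();  // broadcast add when a bias tensor was supplied
  }
  return out_before_bias;  // c10::nullopt (None from Python) means no bias
}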
/*
* From csrc/moe
*/
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer,
                          torch::Tensor cumsum_buffer);
/*
* From csrc/speculative
*/
void tree_speculative_sampling_target_only(
    at::Tensor predicts, at::Tensor accept_index, at::Tensor accept_token_num,  // mutable
    at::Tensor candidates, at::Tensor retrive_index, at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling, at::Tensor uniform_samples, at::Tensor target_probs,
    at::Tensor draft_probs, bool deterministic = true, int64_t cuda_stream = 0);
void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
                                 at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
                                 at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
                                 int64_t depth, int64_t draft_token_num);
void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
                       at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
                       int64_t depth, int64_t draft_token_num);
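These declarations, including the default arguments on tree_speculative_sampling_target_only, are what the Python-facing bindings point at. A hedged sketch of how a couple of them might be registered (the actual sgl-kernel binding file, module name, and docstrings are not part of this diff and may differ):

// Illustrative pybind11-style registration for two of the ops declared above.
#include <torch/extension.h>
#include "sgl_kernels_ops.h"  // assumes this header is on the include path

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
  m.def("tree_speculative_sampling_target_only", &tree_speculative_sampling_target_only,
        "Tree speculative sampling against target probabilities",
        py::arg("predicts"), py::arg("accept_index"), py::arg("accept_token_num"), py::arg("candidates"),
        py::arg("retrive_index"), py::arg("retrive_next_token"), py::arg("retrive_next_sibling"),
        py::arg("uniform_samples"), py::arg("target_probs"), py::arg("draft_probs"),
        py::arg("deterministic") = true, py::arg("cuda_stream") = 0);
}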
/*
* From FlashInfer
*/
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
             at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
                               std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic,
                               int64_t cuda_stream);
// top k renorm probs
// Patched here because flashinfer uses unsigned int, while torch extensions must use int64_t.
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
                        unsigned int top_k_val, int64_t cuda_stream);
// Patched here because flashinfer uses unsigned int, while torch extensions must use int64_t.
inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
                                       std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
                                       int64_t cuda_stream) {
  top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
}
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
                        double top_p_val, int64_t cuda_stream);
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
                                     at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
                                     std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
                                     int64_t cuda_stream);
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
                               std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
                               int64_t cuda_stream);
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
                                      at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
                                      int64_t cuda_stream);
/*
* Other
*/
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                torch::Tensor new_kv);
// sgl_per_token_quant_fp8
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh (view file @ d052f4c8)
...
...
@@ -103,7 +103,7 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size, int world
  return AllReduceStrategyType::TWOSHOT;
}

void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
                        cudaStream_t stream);

}  // namespace trt_llm
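The hunk above only shows the fallback return of SelectImplementation, which chooses between one-shot and two-shot allreduce variants based on message size. A hedged sketch of such a heuristic (the threshold and surrounding logic are assumptions; only the trailing TWOSHOT return is visible here):

// Illustrative size-based strategy selection; real thresholds and enum live in trt_reduce_internal.cuh.
#include <cstddef>

enum class AllReduceStrategyKind { ONESHOT, TWOSHOT };

inline AllReduceStrategyKind SelectImplementationSketch(size_t message_size, int world_size) {
  // One-shot: each rank pushes its full input to peers and reduces locally; cheap for small messages.
  // Two-shot: reduce-scatter followed by all-gather; better bandwidth utilization for large messages.
  const size_t one_shot_limit = (512 * 1024) / static_cast<size_t>(world_size);  // hypothetical cutoff
  if (message_size <= one_shot_limit) {
    return AllReduceStrategyKind::ONESHOT;
  }
  return AllReduceStrategyKind::TWOSHOT;
}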
sgl-kernel/src/sgl-kernel/include/utils.h (view file @ d052f4c8)
...
...
@@ -95,7 +95,6 @@ inline int getSMVersion() {
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
#define WARP_SIZE 32
#ifndef USE_ROCM
...
...
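The utils.h hunk keeps the CEILDIV and WARP_SIZE helpers that the kernels use for launch sizing; per the file summary only a single line was removed (+0 −1). A small usage sketch (the function and sizes are illustrative, not from this diff):

// Illustrative use of CEILDIV/WARP_SIZE-style helpers to size a launch.
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
#define WARP_SIZE 32

void example_launch_config(int num_tokens, int* out_blocks, int* out_threads_per_block) {
  const int threads_per_block = 4 * WARP_SIZE;                // 128 threads per block
  const int blocks = CEILDIV(num_tokens, threads_per_block);  // round up so every token is covered
  *out_blocks = blocks;
  *out_threads_per_block = threads_per_block;
}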