change / sglang / Commits / d052f4c8

Commit d052f4c8 (unverified), authored Mar 07, 2025 by Lianmin Zheng, committed via GitHub on Mar 07, 2025.
Parent: e1aaa79a

New clang format for sgl kernel (#4194)

The commit changes 25 files; the diff shown on this page covers 5 of them, with 273 additions and 119 deletions (+273, -119).
Files changed on this page:

  sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu    +32   -14
  sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh   +63   -32
  sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h                   +176  -70
  sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh             +2    -2
  sgl-kernel/src/sgl-kernel/include/utils.h                             +0    -1
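The layout used throughout this diff (every parameter or argument on its own line once a declaration or call exceeds the column limit) is the kind of output clang-format produces when argument bin-packing is disabled. The sketch below is an illustrative guess at such a configuration, not the repository's actual .clang-format, which may set these options differently:

  # Hypothetical .clang-format sketch; the real sgl-kernel configuration may differ.
  BasedOnStyle: Google
  ColumnLimit: 120
  IndentWidth: 2
  BinPackParameters: false                           # wrap declarations with one parameter per line
  BinPackArguments: false                            # wrap calls with one argument per line
  AllowAllParametersOfDeclarationOnNextLine: false
  AllowAllArgumentsOnNextLine: false
  AlignAfterOpenBracket: AlwaysBreak                 # break right after the opening parenthesis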
sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu  (+32, -14)

@@ -29,12 +29,19 @@ using namespace flashinfer;

Unchanged context:

// retrive_next_sibling: [bs, num_draft_tokens]
// uniform_samples: [bs, num_draft_tokens]
// target_probs: [bs, num_draft_tokens, vocab_size]

The signature of tree_speculative_sampling_target_only, previously wrapped with several parameters per line, now places one parameter per line:

void tree_speculative_sampling_target_only(
    at::Tensor predicts,
    at::Tensor accept_index,
    at::Tensor accept_token_num,  // mutable
    at::Tensor candidates,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    at::Tensor uniform_samples,
    at::Tensor target_probs,
    at::Tensor draft_probs,
    bool deterministic,
    int64_t cuda_stream = 0) {
  CHECK_INPUT(candidates);
  CHECK_INPUT(retrive_index);
  CHECK_INPUT(retrive_next_token);
...

@@ -108,13 +115,24 @@ void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accep

The kernel invocation and the error check are rewrapped the same way, with the argument order unchanged:

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
      static_cast<int*>(predicts.data_ptr()),
      static_cast<int*>(accept_index.data_ptr()),
      static_cast<int*>(accept_token_num.data_ptr()),
      static_cast<int*>(candidates.data_ptr()),
      static_cast<int*>(retrive_index.data_ptr()),
      static_cast<int*>(retrive_next_token.data_ptr()),
      static_cast<int*>(retrive_next_sibling.data_ptr()),
      static_cast<float*>(uniform_samples.data_ptr()),
      static_cast<float*>(target_probs.data_ptr()),
      static_cast<float*>(draft_probs.data_ptr()),
      batch_size,
      num_spec_step,
      num_draft_tokens,
      vocab_size,
      deterministic,
      stream);
  TORCH_CHECK(
      status == cudaSuccess,
      "TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
}
sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh  (+63, -32)

@@ -27,15 +27,29 @@ namespace sampling {

using namespace cub;

The kernel's template parameter list and signature are reformatted to one item per line:

template <
    uint32_t BLOCK_THREADS,
    BlockScanAlgorithm SCAN_ALGORITHM,
    BlockReduceAlgorithm REDUCE_ALGORITHM,
    uint32_t VEC_SIZE,
    bool DETERMINISTIC,
    typename DType,
    typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(
    IdType* predicts,
    IdType* accept_index,
    IdType* accept_token_num,  // mutable
    IdType* candidates,
    IdType* retrive_index,
    IdType* retrive_next_token,
    IdType* retrive_next_sibling,
    DType* uniform_samples,
    DType* target_probs,
    DType* draft_probs,
    uint32_t batch_size,
    uint32_t num_speculative_tokens,
    uint32_t num_draft_tokens,
    uint32_t d) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
...

@@ -140,37 +154,54 @@ __global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* acce
}

The host-side launcher is rewrapped the same way, including its argument array and the dispatch block:

template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(
    IdType* predicts,
    IdType* output_token_ids,
    IdType* output_accepted_token_num,  // mutable
    IdType* candidates,
    IdType* retrive_index,
    IdType* retrive_next_token,
    IdType* retrive_next_sibling,
    DType* uniform_samples,
    DType* target_probs,
    DType* draft_probs,
    uint32_t batch_size,
    uint32_t num_speculative_tokens,
    uint32_t num_draft_tokens,
    uint32_t d,
    bool deterministic,
    cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
  const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  void* args[] = {
      &predicts,
      &output_token_ids,
      &output_accepted_token_num,
      &candidates,
      &retrive_index,
      &retrive_next_token,
      &retrive_next_sibling,
      &uniform_samples,
      &target_probs,
      &draft_probs,
      &batch_size,
      &num_speculative_tokens,
      &num_draft_tokens,
      &d};
  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
      auto kernel = TreeSpeculativeSamplingTargetOnly<
          BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, DETERMINISTIC, DType, IdType>;
      FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  })});
...
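A note on the launch pattern rewrapped above: cudaLaunchKernel receives the kernel arguments as an array of pointers, and the entries must appear in exactly the same order as the kernel's parameters, which is why a whitespace-only reformat like this one leaves the order of the args array untouched. Below is a minimal, self-contained sketch of that pattern with a hypothetical add_one kernel; it is illustrative only and not part of this commit:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel used only to illustrate the void* args[] launch pattern.
__global__ void add_one(int* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1;
}

int main() {
  int n = 256;
  int* d_data;
  cudaMalloc(&d_data, n * sizeof(int));
  cudaMemset(d_data, 0, n * sizeof(int));

  dim3 nblks(1);
  dim3 nthrs(256);
  // Entries must appear in the same order as the kernel parameters (data, n).
  void* args[] = {&d_data, &n};
  cudaLaunchKernel((void*)add_one, nblks, nthrs, args, /*sharedMem=*/0, /*stream=*/0);
  cudaDeviceSynchronize();

  int h0 = 0;
  cudaMemcpy(&h0, d_data, sizeof(int), cudaMemcpyDeviceToHost);
  printf("first element after launch: %d\n", h0);  // expect 1
  cudaFree(d_data);
  return 0;
}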
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h  (+176, -70)

This header only gains the new wrapping: every declaration that exceeds the column limit now lists one parameter per line, exactly as in the two files above. For brevity, the declarations in the two hunks are reproduced here in compact, one-declaration-per-line form.

@@ -42,8 +42,8 @@ using fptr_t = int64_t;

void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
...

@@ -53,113 +53,219 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);

 */

#ifdef USE_ROCM
// ROCM custom allreduce
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles, const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#else
// TRTLLM custom allreduce
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers, const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
#endif

/*
 * From csrc/gemm
 */
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const torch::Dtype& out_dtype);
void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, double eps, double fp8_min, double fp8_max);
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights, const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype, int64_t cublas_handle, int64_t cuda_stream);

/*
 * From csrc/moe
 */
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);

/*
 * From csrc/speculative
 */
void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index, at::Tensor accept_token_num /* mutable */, at::Tensor candidates, at::Tensor retrive_index, at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs, bool deterministic = true, int64_t cuda_stream = 0);
void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk, int64_t depth, int64_t draft_token_num);
void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk, int64_t depth, int64_t draft_token_num);

/*
 * From FlashInfer
 */
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic, int64_t cuda_stream);

// top k renorm probs
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val, int64_t cuda_stream);

// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val, int64_t cuda_stream) {
  top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
}

void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, int64_t cuda_stream);
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic, int64_t cuda_stream);
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic, int64_t cuda_stream);
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, int64_t cuda_stream);

/*
 * Other
 */
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, torch::Tensor new_kv);

// sgl_per_token_quant_fp8
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh  (+2, -2)

@@ -103,7 +103,7 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size, int world

  return AllReduceStrategyType::TWOSHOT;
}

The declaration of trtCustomAllReduce is rewrapped under the new style; its parameters are unchanged:

void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream);

}  // namespace trt_llm
sgl-kernel/src/sgl-kernel/include/utils.h  (+0, -1)

@@ -95,7 +95,6 @@ inline int getSMVersion() {

This hunk deletes a single line; the visible context is unchanged:

  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

#define CEILDIV(x, y) (((x) + (y)-1) / (y))
#define WARP_SIZE 32

#ifndef USE_ROCM
...
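The CEILDIV macro in the context above is the usual ceiling-division helper for grid sizing. A small illustrative example of how such a macro is typically applied follows; it is an assumption for illustration, not code from this commit:

#include <cstdio>

#define CEILDIV(x, y) (((x) + (y)-1) / (y))
#define WARP_SIZE 32

int main() {
  // Number of thread blocks needed to cover 1000 elements with 256 threads each.
  int num_elements = 1000;
  int block_threads = 256;
  int num_blocks = CEILDIV(num_elements, block_threads);
  // Prints: blocks = 4, warps per block = 8
  printf("blocks = %d, warps per block = %d\n", num_blocks, block_threads / WARP_SIZE);
  return 0;
}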