sglang · commit d052f4c8 (unverified)

New clang format for sgl kernel (#4194)

Authored Mar 07, 2025 by Lianmin Zheng; committed by GitHub on Mar 07, 2025
Parent: e1aaa79a

Showing 20 changed files with 1217 additions and 567 deletions (+1217, −567)
python/upload_pypi.sh  +0 −6
sgl-kernel/.clang-format  +7 −0
sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu  +9 −4
sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh  +72 −44
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu  +29 −14
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu  +20 −10
sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu  +31 −15
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h  +41 −22
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp  +12 −8
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h  +12 −10
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h  +48 −24
sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu  +44 −20
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu  +55 −20
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu  +409 −189
sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu  +280 −124
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu  +11 −6
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu  +25 −8
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu  +11 −5
sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu  +37 −16
sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu  +64 −22
python/upload_pypi.sh  (deleted, 100644 → 0)

cp ../README.md ../LICENSE .
rm -rf dist
python3 -m build
python3 -m twine upload dist/*
rm -rf README.md LICENSE
sgl-kernel/.clang-format
@@ -6,3 +6,10 @@ DerivePointerAlignment: false
PointerAlignment: Left
NamespaceIndentation: None
SortIncludes: true
AllowShortLoopsOnASingleLine: false
BinPackParameters: false # Prevents packing parameters in declarations
BinPackArguments: false # Prevents packing arguments in function calls
AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
AlignOperands: Align # Aligns arguments vertically
PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
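Taken together, these options tell clang-format to stop bin-packing parameters and arguments: when a call or declaration no longer fits on one line, it breaks immediately after the opening parenthesis and puts one argument per line, which is the shape visible throughout the diffs below. A minimal sketch of the effect, using a hypothetical helper name chosen only to illustrate the style (not a function from this repository):

    #include <cstdio>

    // Hypothetical helper, present only so the snippet compiles and runs.
    static int launch_example(int input_size, int hidden_size, int block_size, int grid_size, const char* tag) {
      std::printf("%s: %d %d %d %d\n", tag, input_size, hidden_size, block_size, grid_size);
      return 0;
    }

    int main() {
      // With BinPackParameters/BinPackArguments set to false and AlignAfterOpenBracket set to AlwaysBreak,
      // a call that must wrap breaks after '(' and places one argument per line (wrapped here for illustration):
      int status = launch_example(
          1024,
          4096,
          256,
          36,
          "new-style call");
      return status;
    }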
sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu

@@ -41,10 +41,15 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
  // support float16, bfloat16 and float32
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status = norm::FusedAddRMSNorm(
        static_cast<c_type*>(input.data_ptr()),
        static_cast<c_type*>(residual.data_ptr()),
        static_cast<c_type*>(weight.data_ptr()),
        batch_size,
        hidden_size,
        eps,
        torch_current_stream);
    TORCH_CHECK(
        status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}
sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh

@@ -153,19 +153,20 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(
    const RankSignals& sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    int rank) {
#ifdef USE_ROCM
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(
        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
    // wait until we got true from all ranks
    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
           flag)

@@ -193,12 +194,13 @@ DINLINE void start_sync(const RankSignals& sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(
    const RankSignals& sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    int rank) {
#ifdef USE_ROCM
  __syncthreads();
  // eliminate the case that prior writes are not visible after signals become

@@ -209,11 +211,16 @@ DINLINE void end_sync(const RankSignals& sg,
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(
        &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
        flag,
        final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
        __MEMORY_SCOPE_SYSTEM);
    // wait until we got true from all ranks
    while (__scoped_atomic_load_n(
               &self_sg->end[blockIdx.x][threadIdx.x],
               final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
               __MEMORY_SCOPE_DEVICE) < flag)
      ;
  }
  __syncthreads();

@@ -251,12 +258,16 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
    RankData* _dp,
    RankSignals sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    T* __restrict__ result,
    int rank,
    int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same

@@ -280,12 +291,16 @@ DINLINE P* get_tmp_buf(volatile Signal* sg) {
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
    RankData* _dp,
    RankSignals sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    T* __restrict__ result,
    int rank,
    int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;

@@ -357,8 +372,14 @@ class CustomAllreduce {
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
  CustomAllreduce(
      Signal* meta,
      void* rank_data,
      size_t rank_data_sz,
      const hipIpcMemHandle_t* handles,
      const std::vector<int64_t>& offsets,
      int rank,
      bool full_nvlink = true)
      : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink),

@@ -382,8 +403,8 @@ class CustomAllreduce {
    auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(hipIpcOpenMemHandle(
          (void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;

@@ -399,13 +420,14 @@ class CustomAllreduce {
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (hipPointerGetAttribute(
              &base_ptr,
#ifdef USE_ROCM
              HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
              CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
              (hipDeviceptr_t)ptr) != hipSuccess)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);

@@ -415,8 +437,8 @@ class CustomAllreduce {
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }
  void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {

@@ -443,8 +465,8 @@ class CustomAllreduce {
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);

@@ -474,11 +496,17 @@ class CustomAllreduce {
   * will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(
      hipStream_t stream,
      T* input,
      T* output,
      int size,
#ifndef USE_ROCM
      int threads = 512,
      int block_limit = 36) {
#else
      int threads = 512,
      int block_limit = 16) {
#endif
    auto d = packed_t<T>::P::size;
    if (size % d != 0)

@@ -487,8 +515,8 @@ class CustomAllreduce {
          "of " + std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error(
          "max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
    RankData* ptrs;
    hipStreamCaptureStatus status;

@@ -499,17 +527,17 @@ class CustomAllreduce {
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
      ptrs = it->second;
    }
    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = ::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
  hipLaunchKernelGGL(   \
      (name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
#define REDUCE_CASE(ngpus)  \
  case ngpus: {             \
    if (world_size_ == 2) { \
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu

@@ -118,8 +118,13 @@ inline __device__ int4 add128b(T& a, T& b) {
  return c.packed;
}
__inline__ __device__ void multi_gpu_barrier(
    uint32_t** signals,
    uint32_t const flag,
    size_t const local_rank,
    size_t const world_size,
    int const tidx,
    int const bidx) {
  // After this function, at least one block in each GPU has reached the barrier
  if (tidx < world_size) {
    // we can think of signals having the shape [world_size, world_size]

@@ -143,8 +148,14 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
}
template <bool start, bool need_fence = false>
__inline__ __device__ void block_barrier(
    uint32_t** signals,
    uint32_t const flag,
    size_t const local_rank,
    size_t const world_size,
    int const tidx,
    int const bidx,
    int const grid_size) {
  if constexpr (!start) {
    __syncthreads();
  }

@@ -227,8 +238,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
    }
  }
  // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
  block_barrier<true>(
      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {

@@ -341,8 +352,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
      }
    }
  }
  block_barrier<true>(
      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {

@@ -372,8 +383,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
    }
  }
  block_barrier<false, true>(
      params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
  // Gather all needed elts from other intra-node ranks
  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {

@@ -459,8 +470,12 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
void dispatchARKernels(
    AllReduceStrategyType algo,
    AllReduceParams& param,
    int blocks_per_grid,
    int threads_per_block,
    cudaStream_t stream) {
  switch (algo) {
    case AllReduceStrategyType::ONESHOT: {
      oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);

@@ -505,8 +520,8 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
  CHECK_CUDA_SUCCESS(cudaGetLastError());
}
void trtCustomAllReduce(
    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
  if (params.elts_total == 0) {
    return;
  }
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu

@@ -29,9 +29,14 @@ using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
class AllReduceMeta {
 public:
  AllReduceMeta(
      int64_t rank_id,
      int64_t world_size,
      torch::Tensor& rank_data,
      const std::vector<fptr_t>& buffers,
      const std::vector<fptr_t>& tmp_result_buffers,
      const std::vector<fptr_t>& barrier_in,
      const std::vector<fptr_t>& barrier_out) {
    this->rank_id = (int)rank_id;
    this->world_size = (int)world_size;
    this->barrier_in = barrier_in;

@@ -86,9 +91,14 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
  return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
fptr_t init_custom_ar(
    int64_t rank_id,
    int64_t world_size,
    torch::Tensor& rank_data,
    const std::vector<fptr_t>& buffers,
    const std::vector<fptr_t>& tmp_result_buffers,
    const std::vector<fptr_t>& barrier_in,
    const std::vector<fptr_t>& barrier_out) {
  auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
  return (fptr_t)m;
}

@@ -124,8 +134,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
  auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
  if (new_handle) {
    char* ipc_ptr;
    CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
        (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
    it->second = ipc_ptr;
  }
  return it->second;

@@ -138,8 +148,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
  std::vector<std::string> handle_bytes;
  handle_bytes.reserve(handles.size());
sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu

@@ -23,15 +23,18 @@ limitations under the License.
#define THREADS_PER_BLOCK 128
template <typename T>
__global__ void lightning_attention_decode_kernel(
    const T* __restrict__ q,            // [b, h, 1, d]
    const T* __restrict__ k,            // [b, h, 1, d]
    const T* __restrict__ v,            // [b, h, 1, e]
    const float* __restrict__ past_kv,  // [b, h, d, e]
    const float* __restrict__ slope,    // [h, 1, 1]
    T* __restrict__ output,             // [b, h, 1, e]
    float* __restrict__ new_kv,         // [b, h, d, e]
    const int batch_size,
    const int num_heads,
    const int qk_dim,
    const int v_dim) {
  extern __shared__ char smem[];
  T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
  T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));

@@ -109,9 +112,14 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
  }
}
void lightning_attention_decode(
    const torch::Tensor& q,
    const torch::Tensor& k,
    const torch::Tensor& v,
    const torch::Tensor& past_kv,
    const torch::Tensor& slope,
    torch::Tensor output,
    torch::Tensor new_kv) {
  TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
  TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
  TORCH_CHECK(v.is_contiguous(), "v must be contiguous");

@@ -131,8 +139,16 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
      at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
        size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
        lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
            q.data_ptr<scalar_t>(),
            k.data_ptr<scalar_t>(),
            v.data_ptr<scalar_t>(),
            past_kv.data_ptr<float>(),
            slope.data_ptr<float>(),
            output.data_ptr<scalar_t>(),
            new_kv.data_ptr<float>(),
            batch_size,
            num_heads,
            qk_dim,
            v_dim);
      }));
}
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h

@@ -25,9 +25,15 @@ namespace cutlass {
namespace epilogue {
namespace threadblock {
template <
    typename ThreadblockShape_,
    int ThreadCount,
    typename ScaleTileIterator_,
    typename OutputTileIterator_,
    typename ElementAccumulator_,
    typename ElementCompute_,
    typename ElementwiseFunctor_,
    bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
 public:
  using ThreadblockShape = ThreadblockShape_;

@@ -69,8 +75,11 @@ class EpilogueVisitorPerRowPerCol {
    Arguments(typename ElementwiseFunctor::Params elementwise_)
        : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
    Arguments(
        typename ElementwiseFunctor::Params elementwise_,
        int64_t batch_stride_alpha_,
        int64_t batch_stride_C_,
        int64_t batch_stride_D_)
        : elementwise(elementwise_),
          batch_stride_alpha(batch_stride_alpha_),
          batch_stride_C(batch_stride_C_),

@@ -131,17 +140,26 @@ class EpilogueVisitorPerRowPerCol {
 public:
  CUTLASS_DEVICE
  EpilogueVisitorPerRowPerCol(
      Params const& params,
      SharedStorage& shared_storage,
      cutlass::MatrixCoord const& problem_size,
      int thread_idx,
      int warp_idx,
      int lane_idx,
      typename ScaleTileIterator::Params params_alpha_col,
      typename OutputTileIterator::Params params_C,
      typename OutputTileIterator::Params params_D,
      bool with_bias,
      bool per_token_quant,
      bool per_channel_quant,
      AlphaScaleElementType* ptr_alpha_row,
      AlphaScaleElementType* ptr_alpha_col,
      typename OutputTileIterator::Element* ptr_C,
      typename OutputTileIterator::Element* ptr_D,
      cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
      int column_offset = 0,
      cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
      : params_(params), shared_storage_(shared_storage), extent_(problem_size),

@@ -166,8 +184,9 @@ class EpilogueVisitorPerRowPerCol {
  /// Helper to indicate split-K behavior
  CUTLASS_DEVICE
  void set_k_partition(
      int split_k_index,    ///< Index of this threadblock within split-K partitioned scheme
      int split_k_slices) { ///< Total number of split-K slices
  }
  /// Called to set the batch index

@@ -251,8 +270,8 @@ class EpilogueVisitorPerRowPerCol {
 private:
  CUTLASS_DEVICE
  ComputeFragment per_token_channel_scale_accumulator_(
      ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ComputeFragment::kElements; ++i) {

@@ -263,8 +282,8 @@ class EpilogueVisitorPerRowPerCol {
  }
  CUTLASS_DEVICE
  ComputeFragment per_token_scale_accumulator_(
      ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ComputeFragment::kElements; ++i) {
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp

@@ -16,16 +16,20 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelT
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <
    int Stages_,
    class ClusterShape_ = Shape<_1, _1, _1>,
    class KernelSchedule = KernelTmaWarpSpecialized,
    int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
                               // while zero-value `ScaleGranularityM` indicates that scaling
                               // granularity is `size<0>(TileShape_MNK{})` along M.
    >
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
  static_assert(
      cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
      "KernelSchedule must be one of the warp specialized policies");
};
//////////////////////////////////////////////////////////////////////////////
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h

@@ -159,8 +159,9 @@ class GemmUniversalBaseCompat {
    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
    dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
    CUTLASS_TRACE_HOST(
        " grid_tiled_shape: " << grid_tiled_shape << "\n"
                              << " result = {" << result << "}");
    return result;
  }

@@ -175,8 +176,8 @@ class GemmUniversalBaseCompat {
    CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
    if (smem_size <= (48 << 10)) {
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
      if (result == cudaSuccess) {
        CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);

@@ -184,12 +185,12 @@ class GemmUniversalBaseCompat {
      }
    } else {
      // Query assuming zero shared memory then compute occupancy limit based on SMEM
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
      if (result != cudaSuccess) {
        CUTLASS_TRACE_HOST(
            " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
        return -1;
      }

@@ -226,8 +227,9 @@ class GemmUniversalBaseCompat {
  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST(
        "GemmUniversalBaseCompat::initialize() - workspace " << workspace << ", stream: "
                                                             << (stream ? "non-null" : "null"));
    size_t workspace_bytes = get_workspace_size(args);
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h

@@ -32,10 +32,11 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
    typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,           ///! Epilogue
    typename ThreadblockSwizzle_  ///! Threadblock swizzling function
    >
struct GemmWithEpilogueVisitor {
 public:
  using Mma = Mma_;

@@ -119,9 +120,15 @@ struct GemmWithEpilogueVisitor {
    Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
    /// constructs an arguments structure
    Arguments(
        GemmCoord problem_size_,
        TensorRefA ref_A_,
        TensorRefB ref_B_,
        TensorRefAlphaCol ref_alpha_col_,
        TensorRefAlphaRow ref_alpha_row_,
        TensorRefC ref_C_,
        TensorRefC ref_D_,
        typename EpilogueVisitor::Arguments epilogue_visitor_)
        : mode(GemmUniversalMode::kGemm), problem_size(problem_size_), batch_count(1),

@@ -269,8 +276,9 @@ struct GemmWithEpilogueVisitor {
      isAMisaligned = problem_size.k() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
      isAMisaligned = problem_size.m() % kAlignmentA;
    } else if (
        platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
        platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    }

@@ -278,8 +286,9 @@ struct GemmWithEpilogueVisitor {
      isBMisaligned = problem_size.n() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    } else if (
        platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
        platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    }

@@ -287,8 +296,9 @@ struct GemmWithEpilogueVisitor {
      isCMisaligned = problem_size.n() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
      isCMisaligned = problem_size.m() % kAlignmentC;
    } else if (
        platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
        platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    }

@@ -373,11 +383,11 @@ struct GemmWithEpilogueVisitor {
    int thread_idx = threadIdx.x;
    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
        params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
    typename Mma::IteratorB iterator_B(
        params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.

@@ -409,8 +419,8 @@ struct GemmWithEpilogueVisitor {
    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
    // assume identity swizzle
    MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

@@ -423,11 +433,25 @@ struct GemmWithEpilogueVisitor {
      with_bias = false;
    }
    EpilogueVisitor epilogue_visitor(
        params.epilogue_visitor,
        shared_storage.epilogue.visitor,
        params.problem_size.mn(),
        thread_idx,
        warp_idx,
        lane_idx,
        params.params_alpha_col,
        params.params_C,
        params.params_D,
        with_bias,
        true,
        true,
        params.ptr_alpha_row,
        params.ptr_alpha_col,
        params.ptr_C,
        params.ptr_D,
        threadblock_offset,
        blockIdx.y * params.problem_size.m());
    if (params.mode == GemmUniversalMode::kGemm) {
      // Indicate which position in a serial reduction the output operator is currently updating
sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu

@@ -21,10 +21,13 @@
#include "utils.h"
static void check_group_count(
    const std::vector<torch::Tensor>& inputs,
    const std::vector<torch::Tensor>& weights,
    const std::vector<torch::Tensor>& outputs) {
  TORCH_CHECK(
      ((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
      "The group count of inputs, weights and outputs should be the same.");
}
static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {

@@ -68,21 +71,26 @@ static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tens
static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
  auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
  torch::Tensor gpu_ptrs = torch::empty({static_cast<int>(ptrs.size())}, options);
  TORCH_CHECK(
      cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
      CUBLAS_STATUS_SUCCESS);
  return gpu_ptrs;
}
// We want compute input @ weight^T in row major
// This is equivalent to computing weight @ input^T in col major
// Cublas only accepts matrix in column major, so this arrangement is needed
void cublas_grouped_gemm(
    const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
    const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
    const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
    const torch::Dtype& out_dtype,
    int64_t cublas_handle,
    int64_t cuda_stream) {
  TORCH_CHECK(
      out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
      "cublas grouped_gemm can"
      "only be applied to float16 and bfloat16 dtype");
  int group_count = inputs.size();
  check_group_count(inputs, weights, outputs);

@@ -133,16 +141,32 @@ void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, // b: (m, k
  torch::Tensor d_c = create_ptr_pointer(c_array, stream);
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  auto status = cublasGemmGroupedBatchedEx(
      handle,
      transa_array.data(),
      transb_array.data(),
      m_array.data(),
      n_array.data(),
      k_array.data(),
      alpha_array.data(),
      (void**)d_a.data_ptr(),
      cuda_data_type,
      lda_array.data(),
      (void**)d_b.data_ptr(),
      cuda_data_type,
      ldb_array.data(),
      beta_array.data(),
      (void**)d_c.data_ptr(),
      cuda_data_type,
      ldc_array.data(),
      group_count,
      group_size.data(),
      CUBLAS_COMPUTE_32F);
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
  TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
  return;
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
}
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu

@@ -35,8 +35,12 @@
using namespace cute;
template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

@@ -66,19 +70,43 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      EpilogueTileType,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      EpilogueSchedule,
      StoreEpilogueCompute>::CollectiveOp;
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutA,
      AlignmentA,
      ElementB,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  Gemm gemm_op;

@@ -127,16 +155,23 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
}
template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _1, _1>;
  launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
}
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");

@@ -145,10 +180,10 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

@@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
#endif
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu  (diff collapsed; not shown)
sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu  (diff collapsed; not shown)
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu

@@ -8,8 +8,8 @@
#include "utils.h"
template <typename T>
__global__ void per_tensor_absmax_kernel(
    const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
  float max_value = 0.0f;
  unsigned int tid = threadIdx.x;
  unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;

@@ -56,8 +56,11 @@ __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __r
}
template <typename T>
__global__ void per_tensor_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output,
    const float* __restrict__ scale,
    const int64_t num_elements) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;
  const float scale_val = 1.0f / (*scale);

@@ -124,8 +127,10 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
    }
    per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        num_elements);
    return true;
  });
}
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu

@@ -17,10 +17,15 @@ __device__ __forceinline__ float GroupReduce(float val, const int tid) {
}
template <typename T>
__global__ void per_token_group_quant_fp8_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const float eps,
    const float fp8_min,
    const float fp8_max) {
  const int groups_per_block = 16;
  const int local_group_id = threadIdx.x / 16;
  const int lane_id = threadIdx.x % 16;

@@ -80,8 +85,14 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
  }
}
void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

@@ -97,8 +108,14 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q,
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        output_q.data_ptr(),
        static_cast<float*>(output_s.data_ptr()),
        group_size,
        num_groups,
        (float)eps,
        (float)fp8_min,
        (float)fp8_max);
    return true;
  });
}
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu

@@ -7,9 +7,12 @@
#include "utils.h"
template <typename T>
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int token_idx = blockIdx.x;
  if (token_idx >= num_tokens) return;

@@ -110,8 +113,11 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        hidden_dim,
        num_tokens);
    return true;
  });
}
sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu
View file @
d052f4c8
...
...
@@ -25,9 +25,11 @@ limitations under the License.
#define WARP_SIZE 32
template
<
typename
scalar_t
>
__global__
void
count_and_sort_expert_tokens_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
cumsum_buffer
,
size_t
numel
)
{
__global__
void
count_and_sort_expert_tokens_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
cumsum_buffer
,
size_t
numel
)
{
const
size_t
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
stride
=
blockDim
.
x
*
gridDim
.
x
;
...
...
@@ -39,10 +41,15 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__
}
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
__restrict__
cumsum
)
{
__global__
void
moe_align_block_size_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
__restrict__
cumsum
)
{
__shared__
int32_t
shared_counts
[
WARP_SIZE
][
8
];
const
int
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
...
...
@@ -91,17 +98,29 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
}
}
void moe_align_block_size(
    torch::Tensor topk_ids,
    int64_t num_experts,
    int64_t block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor token_cnts_buffer,
    torch::Tensor cumsum_buffer) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");

  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
    align_kernel<<<1, 1024, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
        block_size,
        topk_ids.numel(),
        cumsum_buffer.data_ptr<int32_t>());

    const int block_threads = 256;
    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
...
...
@@ -109,8 +128,10 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
    const int actual_blocks = std::min(num_blocks, max_blocks);
    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        cumsum_buffer.data_ptr<int32_t>(),
        topk_ids.numel());
  });
}
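The two launches above split the work into an alignment pass and a count-and-sort pass. As a rough illustration of the second pass, the grid-stride sketch below scatters each (token, expert) entry into its expert's segment using an atomic cursor. This is an assumed simplification of what count_and_sort_expert_tokens_kernel does; the real kernel's block padding and bookkeeping are omitted.

// Illustrative sketch only; the atomic-cursor strategy is an assumption.
template <typename scalar_t>
__global__ void count_and_sort_sketch(
    const scalar_t* __restrict__ topk_ids,   // expert id per (token, k) entry
    int32_t* __restrict__ sorted_token_ids,  // output slots grouped by expert
    int32_t* __restrict__ cumsum_buffer,     // per-expert write cursor (prefix sums)
    size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t i = tid; i < numel; i += stride) {
    const int32_t expert = static_cast<int32_t>(topk_ids[i]);
    // Reserve the next free slot in this expert's segment.
    const int32_t slot = atomicAdd(&cumsum_buffer[expert], 1);
    sorted_token_ids[slot] = static_cast<int32_t>(i);
  }
}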
sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu
View file @ d052f4c8
...
...
@@ -23,10 +23,18 @@
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
__global__ void build_tree_efficient(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int64_t* retrive_next_token,
    int64_t* retrive_next_sibling,
    int topk,
    int depth,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;
...
...
@@ -99,10 +107,18 @@ __global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_ind
}
}
void build_tree_kernel_efficient(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // TODO (ying) check shape
  // TODO (ying) check type
  int bs = parent_list.size(0);
...
...
@@ -111,11 +127,17 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  build_tree_efficient<<<grid, block, 0, stream>>>(
      static_cast<int64_t*>(parent_list.data_ptr()),
      static_cast<int64_t*>(selected_index.data_ptr()),
      static_cast<int32_t*>(verified_seq_len.data_ptr()),
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      static_cast<int64_t*>(retrive_next_token.data_ptr()),
      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));
}
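A hypothetical caller of build_tree_kernel_efficient, pieced together only from the shape comments above: the dtypes, the zero-initialized mask, and the wrapper name are assumptions, and the contents of parent_list and selected_index (the draft model's top-k expansion) are not shown here.

// Sketch of how a caller might allocate the outputs described in the comments
// above; shapes follow those comments, everything else is assumed.
#include <ATen/ATen.h>

void example_build_tree_inputs(
    at::Tensor parent_list,       // [bs, ...], int64, CUDA
    at::Tensor selected_index,    // [bs, draft_token_num - 1], int64, CUDA
    at::Tensor verified_seq_len,  // [bs], int32, CUDA
    int64_t topk, int64_t depth, int64_t draft_token_num) {
  const int64_t bs = parent_list.size(0);
  const int64_t seq_sum = verified_seq_len.sum().item<int64_t>();
  auto opts_i64 = parent_list.options().dtype(at::kLong);

  // tree_mask: [sum(verified_seq_len)*draft_token + bs*draft_token*draft_token]
  auto tree_mask = at::zeros(
      {seq_sum * draft_token_num + bs * draft_token_num * draft_token_num},
      parent_list.options().dtype(at::kBool));
  auto positions = at::empty({bs * draft_token_num}, opts_i64);
  auto retrive_index = at::empty({bs, draft_token_num}, opts_i64);
  auto retrive_next_token = at::empty({bs, draft_token_num}, opts_i64);
  auto retrive_next_sibling = at::empty({bs, draft_token_num}, opts_i64);

  build_tree_kernel_efficient(
      parent_list, selected_index, verified_seq_len, tree_mask, positions,
      retrive_index, retrive_next_token, retrive_next_sibling,
      topk, depth, draft_token_num);
}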
// parent_list [bs, topk * (depth - 1) + 1)]
...
...
@@ -124,8 +146,16 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
__global__ void build_tree(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int topk,
    int depth,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;
...
...
@@ -191,9 +221,16 @@ __global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_
}
}
void build_tree_kernel(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // TODO (ying) check shape
  // TODO (ying) check type
  int bs = parent_list.size(0);
...
...
@@ -202,8 +239,13 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  build_tree<<<grid, block, 0, stream>>>(
      static_cast<int64_t*>(parent_list.data_ptr()),
      static_cast<int64_t*>(selected_index.data_ptr()),
      static_cast<int32_t*>(verified_seq_len.data_ptr()),
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));
}