Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
d052f4c8
"vscode:/vscode.git/clone" did not exist on "d9dd29f322ba98bb69ffd0e36451aecd0bc917b1"
Unverified
Commit
d052f4c8
authored
Mar 07, 2025
by
Lianmin Zheng
Committed by
GitHub
Mar 07, 2025
Browse files
New clang format for sgl kernel (#4194)
parent
e1aaa79a
Changes
25
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1217 additions
and
567 deletions
+1217
-567
python/upload_pypi.sh
python/upload_pypi.sh
+0
-6
sgl-kernel/.clang-format
sgl-kernel/.clang-format
+7
-0
sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu
...c/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu
+9
-4
sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
...l/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
+72
-44
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
...rnel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
+29
-14
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
...kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
+20
-10
sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
...ernel/csrc/attention/lightning_attention_decode_kernel.cu
+31
-15
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
...lass_extensions/epilogue/epilogue_per_row_per_col_scale.h
+41
-22
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
...l-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
+12
-8
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
...csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
+12
-10
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
...csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
+48
-24
sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
+44
-20
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
...nel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
+55
-20
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
+409
-189
sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
+280
-124
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
+11
-6
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
...nel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
+25
-8
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
+11
-5
sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu
sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu
+37
-16
sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu
sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu
+64
-22
No files found.
python/upload_pypi.sh
deleted
100644 → 0
View file @
e1aaa79a
cp
../README.md ../LICENSE
.
rm
-rf
dist
python3
-m
build
python3
-m
twine upload dist/
*
rm
-rf
README.md LICENSE
sgl-kernel/.clang-format
View file @
d052f4c8
...
...
@@ -6,3 +6,10 @@ DerivePointerAlignment: false
PointerAlignment: Left
NamespaceIndentation: None
SortIncludes: true
AllowShortLoopsOnASingleLine: false
BinPackParameters: false # Prevents packing parameters in declarations
BinPackArguments: false # Prevents packing arguments in function calls
AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
AlignOperands: Align # Aligns arguments vertically
PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu
View file @
d052f4c8
...
...
@@ -41,10 +41,15 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
// support float16, bfloat16 and float32
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16
(
input
.
scalar_type
(),
c_type
,
[
&
]
{
cudaError_t
status
=
norm
::
FusedAddRMSNorm
(
static_cast
<
c_type
*>
(
input
.
data_ptr
()),
static_cast
<
c_type
*>
(
residual
.
data_ptr
()),
static_cast
<
c_type
*>
(
weight
.
data_ptr
()),
batch_size
,
hidden_size
,
eps
,
torch_current_stream
);
TORCH_CHECK
(
status
==
cudaSuccess
,
"FusedAddRMSNorm failed with error code "
+
std
::
string
(
cudaGetErrorString
(
status
)));
static_cast
<
c_type
*>
(
input
.
data_ptr
()),
static_cast
<
c_type
*>
(
residual
.
data_ptr
()),
static_cast
<
c_type
*>
(
weight
.
data_ptr
()),
batch_size
,
hidden_size
,
eps
,
torch_current_stream
);
TORCH_CHECK
(
status
==
cudaSuccess
,
"FusedAddRMSNorm failed with error code "
+
std
::
string
(
cudaGetErrorString
(
status
)));
return
true
;
});
}
sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
View file @
d052f4c8
...
...
@@ -153,19 +153,20 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
#ifndef USE_ROCM
volatile
volatile
#endif
Signal
*
self_sg
,
int
rank
)
{
Signal
*
self_sg
,
int
rank
)
{
#ifdef USE_ROCM
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
],
flag
,
__ATOMIC_RELAXED
,
__MEMORY_SCOPE_SYSTEM
);
__scoped_atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
],
flag
,
__ATOMIC_RELAXED
,
__MEMORY_SCOPE_SYSTEM
);
// wait until we got true from all ranks
while
(
__scoped_atomic_load_n
(
&
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
],
__ATOMIC_RELAXED
,
__MEMORY_SCOPE_DEVICE
)
<
flag
)
...
...
@@ -193,12 +194,13 @@ DINLINE void start_sync(const RankSignals& sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
#ifndef USE_ROCM
volatile
volatile
#endif
Signal
*
self_sg
,
int
rank
)
{
Signal
*
self_sg
,
int
rank
)
{
#ifdef USE_ROCM
__syncthreads
();
// eliminate the case that prior writes are not visible after signals become
...
...
@@ -209,11 +211,16 @@ DINLINE void end_sync(const RankSignals& sg,
if
(
threadIdx
.
x
<
ngpus
)
{
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
],
flag
,
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_RELEASE
,
__MEMORY_SCOPE_SYSTEM
);
__scoped_atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
],
flag
,
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_RELEASE
,
__MEMORY_SCOPE_SYSTEM
);
// wait until we got true from all ranks
while
(
__scoped_atomic_load_n
(
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
],
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_ACQUIRE
,
__MEMORY_SCOPE_DEVICE
)
<
flag
)
while
(
__scoped_atomic_load_n
(
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
],
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_ACQUIRE
,
__MEMORY_SCOPE_DEVICE
)
<
flag
)
;
}
__syncthreads
();
...
...
@@ -251,12 +258,16 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
}
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
#ifndef USE_ROCM
volatile
volatile
#endif
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
// note: we don't reorder the address so the accumulation order is the same
...
...
@@ -280,12 +291,16 @@ DINLINE P* get_tmp_buf(volatile Signal* sg) {
}
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
#ifndef USE_ROCM
volatile
volatile
#endif
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
using
P
=
typename
packed_t
<
T
>::
P
;
...
...
@@ -357,8 +372,14 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce
(
Signal
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
const
hipIpcMemHandle_t
*
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
=
true
)
CustomAllreduce
(
Signal
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
const
hipIpcMemHandle_t
*
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
=
true
)
:
rank_
(
rank
),
world_size_
(
offsets
.
size
()),
full_nvlink_
(
full_nvlink
),
...
...
@@ -382,8 +403,8 @@ class CustomAllreduce {
auto
[
it
,
new_handle
]
=
ipc_handles_
.
insert
({
*
((
IPC_KEY
*
)
ipc_handle
),
nullptr
});
if
(
new_handle
)
{
char
*
ipc_ptr
;
CUDACHECK
(
hipIpcOpenMemHandle
(
(
void
**
)
&
ipc_ptr
,
*
((
const
hipIpcMemHandle_t
*
)
ipc_handle
),
hipIpcMemLazyEnablePeerAccess
));
CUDACHECK
(
hipIpcOpenMemHandle
(
(
void
**
)
&
ipc_ptr
,
*
((
const
hipIpcMemHandle_t
*
)
ipc_handle
),
hipIpcMemLazyEnablePeerAccess
));
it
->
second
=
ipc_ptr
;
}
return
it
->
second
;
...
...
@@ -399,13 +420,14 @@ class CustomAllreduce {
void
*
base_ptr
;
// note: must share the base address of each allocation, or we get wrong
// address
if
(
hipPointerGetAttribute
(
&
base_ptr
,
if
(
hipPointerGetAttribute
(
&
base_ptr
,
#ifdef USE_ROCM
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR
,
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR
,
#else
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR
,
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR
,
#endif
(
hipDeviceptr_t
)
ptr
)
!=
hipSuccess
)
(
hipDeviceptr_t
)
ptr
)
!=
hipSuccess
)
throw
std
::
runtime_error
(
"failed to get pointer attr"
);
CUDACHECK
(
hipIpcGetMemHandle
((
hipIpcMemHandle_t
*
)
&
handles
[
i
*
handle_sz
],
base_ptr
));
offsets
[
i
]
=
((
char
*
)
ptr
)
-
((
char
*
)
base_ptr
);
...
...
@@ -415,8 +437,8 @@ class CustomAllreduce {
void
check_rank_data_capacity
(
size_t
num
=
1
)
{
if
(
d_rank_data_base_
+
num
>
d_rank_data_end_
)
throw
std
::
runtime_error
(
"Rank data buffer is overflowed by "
+
std
::
to_string
(
d_rank_data_base_
+
num
-
d_rank_data_end_
));
throw
std
::
runtime_error
(
"Rank data buffer is overflowed by "
+
std
::
to_string
(
d_rank_data_base_
+
num
-
d_rank_data_end_
));
}
void
register_buffer
(
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
void
*
self
)
{
...
...
@@ -443,8 +465,8 @@ class CustomAllreduce {
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void
register_graph_buffers
(
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
void
register_graph_buffers
(
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
auto
num_buffers
=
graph_unreg_buffers_
.
size
();
check_rank_data_capacity
(
num_buffers
);
std
::
vector
<
RankData
>
rank_data
(
num_buffers
);
...
...
@@ -474,11 +496,17 @@ class CustomAllreduce {
* will cause contention on NVLink bus.
*/
template
<
typename
T
>
void
allreduce
(
hipStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
void
allreduce
(
hipStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
#ifndef USE_ROCM
int
threads
=
512
,
int
block_limit
=
36
){
int
threads
=
512
,
int
block_limit
=
36
){
#else
int
threads
=
512
,
int
block_limit
=
16
)
{
int
threads
=
512
,
int
block_limit
=
16
)
{
#endif
auto
d
=
packed_t
<
T
>::
P
::
size
;
if
(
size
%
d
!=
0
)
...
...
@@ -487,8 +515,8 @@ class CustomAllreduce {
"of "
+
std
::
to_string
(
d
));
if
(
block_limit
>
kMaxBlocks
)
throw
std
::
runtime_error
(
"max supported block limit is "
+
std
::
to_string
(
kMaxBlocks
)
+
". Got "
+
std
::
to_string
(
block_limit
));
throw
std
::
runtime_error
(
"max supported block limit is "
+
std
::
to_string
(
kMaxBlocks
)
+
". Got "
+
std
::
to_string
(
block_limit
));
RankData
*
ptrs
;
hipStreamCaptureStatus
status
;
...
...
@@ -499,17 +527,17 @@ class CustomAllreduce {
}
else
{
auto
it
=
buffers_
.
find
(
input
);
if
(
it
==
buffers_
.
end
())
throw
std
::
runtime_error
(
"buffer address "
+
std
::
to_string
(
reinterpret_cast
<
uint64_t
>
(
input
))
+
" is not registered!"
);
throw
std
::
runtime_error
(
"buffer address "
+
std
::
to_string
(
reinterpret_cast
<
uint64_t
>
(
input
))
+
" is not registered!"
);
ptrs
=
it
->
second
;
}
size
/=
d
;
auto
bytes
=
size
*
sizeof
(
typename
packed_t
<
T
>::
P
);
int
blocks
=
::
min
(
block_limit
,
(
size
+
threads
-
1
)
/
threads
);
#define KL(ngpus, name)
\
hipLaunchKernelGGL(
(name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_,
\
size);
#define KL(ngpus, name) \
hipLaunchKernelGGL(
\
(name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_,
size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
...
...
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
View file @
d052f4c8
...
...
@@ -118,8 +118,13 @@ inline __device__ int4 add128b(T& a, T& b) {
return
c
.
packed
;
}
__inline__
__device__
void
multi_gpu_barrier
(
uint32_t
**
signals
,
uint32_t
const
flag
,
size_t
const
local_rank
,
size_t
const
world_size
,
int
const
tidx
,
int
const
bidx
)
{
__inline__
__device__
void
multi_gpu_barrier
(
uint32_t
**
signals
,
uint32_t
const
flag
,
size_t
const
local_rank
,
size_t
const
world_size
,
int
const
tidx
,
int
const
bidx
)
{
// After this function, at least one block in each GPU has reached the barrier
if
(
tidx
<
world_size
)
{
// we can think of signals having the shape [world_size, world_size]
...
...
@@ -143,8 +148,14 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
}
template
<
bool
start
,
bool
need_fence
=
false
>
__inline__
__device__
void
block_barrier
(
uint32_t
**
signals
,
uint32_t
const
flag
,
size_t
const
local_rank
,
size_t
const
world_size
,
int
const
tidx
,
int
const
bidx
,
int
const
grid_size
)
{
__inline__
__device__
void
block_barrier
(
uint32_t
**
signals
,
uint32_t
const
flag
,
size_t
const
local_rank
,
size_t
const
world_size
,
int
const
tidx
,
int
const
bidx
,
int
const
grid_size
)
{
if
constexpr
(
!
start
)
{
__syncthreads
();
}
...
...
@@ -227,8 +238,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
}
}
// wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
block_barrier
<
true
>
(
params
.
peer_barrier_ptrs_in
,
params
.
barrier_flag
,
params
.
local_rank
,
RANKS_PER_NODE
,
tidx
,
bidx
,
grid_size
);
block_barrier
<
true
>
(
params
.
peer_barrier_ptrs_in
,
params
.
barrier_flag
,
params
.
local_rank
,
RANKS_PER_NODE
,
tidx
,
bidx
,
grid_size
);
// Each block accumulates the values from the different GPUs on the same node.
for
(
size_t
iter_offset
=
chunk_start
;
iter_offset
<
chunk_end
;
iter_offset
+=
blockDim
.
x
*
NUM_ELTS
)
{
...
...
@@ -341,8 +352,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
}
}
}
block_barrier
<
true
>
(
params
.
peer_barrier_ptrs_in
,
params
.
barrier_flag
,
params
.
local_rank
,
RANKS_PER_NODE
,
tidx
,
bidx
,
grid_size
);
block_barrier
<
true
>
(
params
.
peer_barrier_ptrs_in
,
params
.
barrier_flag
,
params
.
local_rank
,
RANKS_PER_NODE
,
tidx
,
bidx
,
grid_size
);
// Each block accumulates the values from the different GPUs on the same node.
for
(
size_t
local_offset
=
chunk_start
;
local_offset
<
chunk_end
;
local_offset
+=
blockDim
.
x
*
PACKED_ELTS
)
{
...
...
@@ -372,8 +383,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
}
}
block_barrier
<
false
,
true
>
(
params
.
peer_barrier_ptrs_out
,
params
.
barrier_flag
,
params
.
local_rank
,
RANKS_PER_NODE
,
tidx
,
bidx
,
grid_size
);
block_barrier
<
false
,
true
>
(
params
.
peer_barrier_ptrs_out
,
params
.
barrier_flag
,
params
.
local_rank
,
RANKS_PER_NODE
,
tidx
,
bidx
,
grid_size
);
// Gather all needed elts from other intra-node ranks
for
(
size_t
local_offset
=
chunk_start
;
local_offset
<
chunk_end
;
local_offset
+=
blockDim
.
x
*
PACKED_ELTS
)
{
...
...
@@ -459,8 +470,12 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
int
RANKS_PER_NODE
,
bool
COPY_INPUT
>
void
dispatchARKernels
(
AllReduceStrategyType
algo
,
AllReduceParams
&
param
,
int
blocks_per_grid
,
int
threads_per_block
,
cudaStream_t
stream
)
{
void
dispatchARKernels
(
AllReduceStrategyType
algo
,
AllReduceParams
&
param
,
int
blocks_per_grid
,
int
threads_per_block
,
cudaStream_t
stream
)
{
switch
(
algo
)
{
case
AllReduceStrategyType
::
ONESHOT
:
{
oneShotAllReduceKernel
<
T
,
RANKS_PER_NODE
,
COPY_INPUT
><<<
blocks_per_grid
,
threads_per_block
,
0
,
stream
>>>
(
param
);
...
...
@@ -505,8 +520,8 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
CHECK_CUDA_SUCCESS
(
cudaGetLastError
());
}
void
trtCustomAllReduce
(
AllReduceParams
&
params
,
at
::
ScalarType
data_type
,
AllReduceStrategyType
strat
,
cudaStream_t
stream
)
{
void
trtCustomAllReduce
(
AllReduceParams
&
params
,
at
::
ScalarType
data_type
,
AllReduceStrategyType
strat
,
cudaStream_t
stream
)
{
if
(
params
.
elts_total
==
0
)
{
return
;
}
...
...
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
View file @
d052f4c8
...
...
@@ -29,9 +29,14 @@ using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
class
AllReduceMeta
{
public:
AllReduceMeta
(
int64_t
rank_id
,
int64_t
world_size
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
fptr_t
>&
buffers
,
const
std
::
vector
<
fptr_t
>&
tmp_result_buffers
,
const
std
::
vector
<
fptr_t
>&
barrier_in
,
const
std
::
vector
<
fptr_t
>&
barrier_out
)
{
AllReduceMeta
(
int64_t
rank_id
,
int64_t
world_size
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
fptr_t
>&
buffers
,
const
std
::
vector
<
fptr_t
>&
tmp_result_buffers
,
const
std
::
vector
<
fptr_t
>&
barrier_in
,
const
std
::
vector
<
fptr_t
>&
barrier_out
)
{
this
->
rank_id
=
(
int
)
rank_id
;
this
->
world_size
=
(
int
)
world_size
;
this
->
barrier_in
=
barrier_in
;
...
...
@@ -86,9 +91,14 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
return
num_elements
%
(
16
/
((
get_bits
(
dtype
)
+
7
)
/
8
))
==
0
;
}
fptr_t
init_custom_ar
(
int64_t
rank_id
,
int64_t
world_size
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
fptr_t
>&
buffers
,
const
std
::
vector
<
fptr_t
>&
tmp_result_buffers
,
const
std
::
vector
<
fptr_t
>&
barrier_in
,
const
std
::
vector
<
fptr_t
>&
barrier_out
)
{
fptr_t
init_custom_ar
(
int64_t
rank_id
,
int64_t
world_size
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
fptr_t
>&
buffers
,
const
std
::
vector
<
fptr_t
>&
tmp_result_buffers
,
const
std
::
vector
<
fptr_t
>&
barrier_in
,
const
std
::
vector
<
fptr_t
>&
barrier_out
)
{
auto
m
=
new
AllReduceMeta
(
rank_id
,
world_size
,
rank_data
,
buffers
,
tmp_result_buffers
,
barrier_in
,
barrier_out
);
return
(
fptr_t
)
m
;
}
...
...
@@ -124,8 +134,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
auto
[
it
,
new_handle
]
=
meta
->
ipc_handles_
.
insert
({
*
((
IPC_KEY
*
)
ipc_handle
),
nullptr
});
if
(
new_handle
)
{
char
*
ipc_ptr
;
CHECK_CUDA_SUCCESS
(
cudaIpcOpenMemHandle
(
(
void
**
)
&
ipc_ptr
,
*
((
const
cudaIpcMemHandle_t
*
)
ipc_handle
),
cudaIpcMemLazyEnablePeerAccess
));
CHECK_CUDA_SUCCESS
(
cudaIpcOpenMemHandle
(
(
void
**
)
&
ipc_ptr
,
*
((
const
cudaIpcMemHandle_t
*
)
ipc_handle
),
cudaIpcMemLazyEnablePeerAccess
));
it
->
second
=
ipc_ptr
;
}
return
it
->
second
;
...
...
@@ -138,8 +148,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
AllReduceMeta
*
m
=
reinterpret_cast
<
AllReduceMeta
*>
(
_fa
);
std
::
vector
<
std
::
string
>
handle_bytes
;
handle_bytes
.
reserve
(
handles
.
size
());
...
...
sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
View file @
d052f4c8
...
...
@@ -23,15 +23,18 @@ limitations under the License.
#define THREADS_PER_BLOCK 128
template
<
typename
T
>
__global__
void
lightning_attention_decode_kernel
(
const
T
*
__restrict__
q
,
// [b, h, 1, d]
const
T
*
__restrict__
k
,
// [b, h, 1, d]
const
T
*
__restrict__
v
,
// [b, h, 1, e]
const
float
*
__restrict__
past_kv
,
// [b, h, d, e]
const
float
*
__restrict__
slope
,
// [h, 1, 1]
T
*
__restrict__
output
,
// [b, h, 1, e]
float
*
__restrict__
new_kv
,
// [b, h, d, e]
const
int
batch_size
,
const
int
num_heads
,
const
int
qk_dim
,
const
int
v_dim
)
{
__global__
void
lightning_attention_decode_kernel
(
const
T
*
__restrict__
q
,
// [b, h, 1, d]
const
T
*
__restrict__
k
,
// [b, h, 1, d]
const
T
*
__restrict__
v
,
// [b, h, 1, e]
const
float
*
__restrict__
past_kv
,
// [b, h, d, e]
const
float
*
__restrict__
slope
,
// [h, 1, 1]
T
*
__restrict__
output
,
// [b, h, 1, e]
float
*
__restrict__
new_kv
,
// [b, h, d, e]
const
int
batch_size
,
const
int
num_heads
,
const
int
qk_dim
,
const
int
v_dim
)
{
extern
__shared__
char
smem
[];
T
*
__restrict__
q_shared
=
reinterpret_cast
<
T
*>
(
smem
);
T
*
__restrict__
k_shared
=
reinterpret_cast
<
T
*>
(
smem
+
qk_dim
*
sizeof
(
T
));
...
...
@@ -109,9 +112,14 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
}
}
void
lightning_attention_decode
(
const
torch
::
Tensor
&
q
,
const
torch
::
Tensor
&
k
,
const
torch
::
Tensor
&
v
,
const
torch
::
Tensor
&
past_kv
,
const
torch
::
Tensor
&
slope
,
torch
::
Tensor
output
,
torch
::
Tensor
new_kv
)
{
void
lightning_attention_decode
(
const
torch
::
Tensor
&
q
,
const
torch
::
Tensor
&
k
,
const
torch
::
Tensor
&
v
,
const
torch
::
Tensor
&
past_kv
,
const
torch
::
Tensor
&
slope
,
torch
::
Tensor
output
,
torch
::
Tensor
new_kv
)
{
TORCH_CHECK
(
q
.
is_contiguous
(),
"q must be contiguous"
);
TORCH_CHECK
(
k
.
is_contiguous
(),
"k must be contiguous"
);
TORCH_CHECK
(
v
.
is_contiguous
(),
"v must be contiguous"
);
...
...
@@ -131,8 +139,16 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
q
.
scalar_type
(),
"lightning_attention_decode_kernel"
,
([
&
]
{
size_t
smem_size
=
(
2
*
qk_dim
+
2
*
v_dim
)
*
sizeof
(
scalar_t
)
+
qk_dim
*
(
v_dim
+
1
)
*
sizeof
(
float
);
lightning_attention_decode_kernel
<
scalar_t
><<<
grid
,
block
,
smem_size
,
stream
>>>
(
q
.
data_ptr
<
scalar_t
>
(),
k
.
data_ptr
<
scalar_t
>
(),
v
.
data_ptr
<
scalar_t
>
(),
past_kv
.
data_ptr
<
float
>
(),
slope
.
data_ptr
<
float
>
(),
output
.
data_ptr
<
scalar_t
>
(),
new_kv
.
data_ptr
<
float
>
(),
batch_size
,
num_heads
,
qk_dim
,
v_dim
);
q
.
data_ptr
<
scalar_t
>
(),
k
.
data_ptr
<
scalar_t
>
(),
v
.
data_ptr
<
scalar_t
>
(),
past_kv
.
data_ptr
<
float
>
(),
slope
.
data_ptr
<
float
>
(),
output
.
data_ptr
<
scalar_t
>
(),
new_kv
.
data_ptr
<
float
>
(),
batch_size
,
num_heads
,
qk_dim
,
v_dim
);
}));
}
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
View file @
d052f4c8
...
...
@@ -25,9 +25,15 @@ namespace cutlass {
namespace
epilogue
{
namespace
threadblock
{
template
<
typename
ThreadblockShape_
,
int
ThreadCount
,
typename
ScaleTileIterator_
,
typename
OutputTileIterator_
,
typename
ElementAccumulator_
,
typename
ElementCompute_
,
typename
ElementwiseFunctor_
,
bool
UseMasking_
=
false
>
template
<
typename
ThreadblockShape_
,
int
ThreadCount
,
typename
ScaleTileIterator_
,
typename
OutputTileIterator_
,
typename
ElementAccumulator_
,
typename
ElementCompute_
,
typename
ElementwiseFunctor_
,
bool
UseMasking_
=
false
>
class
EpilogueVisitorPerRowPerCol
{
public:
using
ThreadblockShape
=
ThreadblockShape_
;
...
...
@@ -69,8 +75,11 @@ class EpilogueVisitorPerRowPerCol {
Arguments
(
typename
ElementwiseFunctor
::
Params
elementwise_
)
:
elementwise
(
elementwise_
),
batch_stride_alpha
(
0
),
batch_stride_C
(
0
),
batch_stride_D
(
0
)
{}
Arguments
(
typename
ElementwiseFunctor
::
Params
elementwise_
,
int64_t
batch_stride_alpha_
,
int64_t
batch_stride_C_
,
int64_t
batch_stride_D_
)
Arguments
(
typename
ElementwiseFunctor
::
Params
elementwise_
,
int64_t
batch_stride_alpha_
,
int64_t
batch_stride_C_
,
int64_t
batch_stride_D_
)
:
elementwise
(
elementwise_
),
batch_stride_alpha
(
batch_stride_alpha_
),
batch_stride_C
(
batch_stride_C_
),
...
...
@@ -131,17 +140,26 @@ class EpilogueVisitorPerRowPerCol {
public:
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol
(
Params
const
&
params
,
SharedStorage
&
shared_storage
,
cutlass
::
MatrixCoord
const
&
problem_size
,
int
thread_idx
,
int
warp_idx
,
int
lane_idx
,
typename
ScaleTileIterator
::
Params
params_alpha_col
,
typename
OutputTileIterator
::
Params
params_C
,
typename
OutputTileIterator
::
Params
params_D
,
bool
with_bias
,
bool
per_token_quant
,
bool
per_channel_quant
,
AlphaScaleElementType
*
ptr_alpha_row
,
AlphaScaleElementType
*
ptr_alpha_col
,
typename
OutputTileIterator
::
Element
*
ptr_C
,
typename
OutputTileIterator
::
Element
*
ptr_D
,
cutlass
::
MatrixCoord
const
&
threadblock_offset
=
cutlass
::
MatrixCoord
(
0
,
0
),
int
column_offset
=
0
,
cutlass
::
MatrixCoord
const
&
problem_size_real
=
cutlass
::
MatrixCoord
(
0
,
0
))
EpilogueVisitorPerRowPerCol
(
Params
const
&
params
,
SharedStorage
&
shared_storage
,
cutlass
::
MatrixCoord
const
&
problem_size
,
int
thread_idx
,
int
warp_idx
,
int
lane_idx
,
typename
ScaleTileIterator
::
Params
params_alpha_col
,
typename
OutputTileIterator
::
Params
params_C
,
typename
OutputTileIterator
::
Params
params_D
,
bool
with_bias
,
bool
per_token_quant
,
bool
per_channel_quant
,
AlphaScaleElementType
*
ptr_alpha_row
,
AlphaScaleElementType
*
ptr_alpha_col
,
typename
OutputTileIterator
::
Element
*
ptr_C
,
typename
OutputTileIterator
::
Element
*
ptr_D
,
cutlass
::
MatrixCoord
const
&
threadblock_offset
=
cutlass
::
MatrixCoord
(
0
,
0
),
int
column_offset
=
0
,
cutlass
::
MatrixCoord
const
&
problem_size_real
=
cutlass
::
MatrixCoord
(
0
,
0
))
:
params_
(
params
),
shared_storage_
(
shared_storage
),
extent_
(
problem_size
),
...
...
@@ -166,8 +184,9 @@ class EpilogueVisitorPerRowPerCol {
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void
set_k_partition
(
int
split_k_index
,
///< Index of this threadblock within split-K partitioned scheme
int
split_k_slices
)
{
///< Total number of split-K slices
void
set_k_partition
(
int
split_k_index
,
///< Index of this threadblock within split-K partitioned scheme
int
split_k_slices
)
{
///< Total number of split-K slices
}
/// Called to set the batch index
...
...
@@ -251,8 +270,8 @@ class EpilogueVisitorPerRowPerCol {
private:
CUTLASS_DEVICE
ComputeFragment
per_token_channel_scale_accumulator_
(
ComputeFragment
const
&
accum
,
ComputeFragment
const
&
scale_col
,
AlphaScaleElementType
const
&
scale_row
)
{
ComputeFragment
per_token_channel_scale_accumulator_
(
ComputeFragment
const
&
accum
,
ComputeFragment
const
&
scale_col
,
AlphaScaleElementType
const
&
scale_row
)
{
ComputeFragment
result
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
ComputeFragment
::
kElements
;
++
i
)
{
...
...
@@ -263,8 +282,8 @@ class EpilogueVisitorPerRowPerCol {
}
CUTLASS_DEVICE
ComputeFragment
per_token_scale_accumulator_
(
ComputeFragment
const
&
accum
,
AlphaScaleElementType
const
&
scale_col
,
AlphaScaleElementType
const
&
scale_row
)
{
ComputeFragment
per_token_scale_accumulator_
(
ComputeFragment
const
&
accum
,
AlphaScaleElementType
const
&
scale_col
,
AlphaScaleElementType
const
&
scale_row
)
{
ComputeFragment
result
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
ComputeFragment
::
kElements
;
++
i
)
{
...
...
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
View file @
d052f4c8
...
...
@@ -16,16 +16,20 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelT
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template
<
int
Stages_
,
class
ClusterShape_
=
Shape
<
_1
,
_1
,
_1
>,
class
KernelSchedule
=
KernelTmaWarpSpecialized
,
int
ScaleGranularityM
=
0
// `ScaleGranularityM` specifies scaling granularity along M,
// while zero-value `ScaleGranularityM` indicates that scaling
// granularity is `size<0>(TileShape_MNK{})` along M.
>
template
<
int
Stages_
,
class
ClusterShape_
=
Shape
<
_1
,
_1
,
_1
>,
class
KernelSchedule
=
KernelTmaWarpSpecialized
,
int
ScaleGranularityM
=
0
// `ScaleGranularityM` specifies scaling granularity along M,
// while zero-value `ScaleGranularityM` indicates that scaling
// granularity is `size<0>(TileShape_MNK{})` along M.
>
struct
MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
:
MainloopSm90TmaGmmaWarpSpecialized
<
Stages_
,
ClusterShape_
,
KernelSchedule
>
{
static_assert
(
cute
::
is_same_v
<
KernelSchedule
,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
ScaleGranularityM
>>
,
"KernelSchedule must be one of the warp specialized policies"
);
static_assert
(
cute
::
is_same_v
<
KernelSchedule
,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
ScaleGranularityM
>>
,
"KernelSchedule must be one of the warp specialized policies"
);
};
//////////////////////////////////////////////////////////////////////////////
...
...
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
View file @
d052f4c8
...
...
@@ -159,8 +159,9 @@ class GemmUniversalBaseCompat {
get_grid_shape_
(
grid_tiled_shape
,
gemm_k_size
,
args
);
dim3
result
=
threadblock_swizzle
.
get_grid_shape
(
grid_tiled_shape
);
CUTLASS_TRACE_HOST
(
" grid_tiled_shape: "
<<
grid_tiled_shape
<<
"
\n
"
<<
" result = {"
<<
result
<<
"}"
);
CUTLASS_TRACE_HOST
(
" grid_tiled_shape: "
<<
grid_tiled_shape
<<
"
\n
"
<<
" result = {"
<<
result
<<
"}"
);
return
result
;
}
...
...
@@ -175,8 +176,8 @@ class GemmUniversalBaseCompat {
CUTLASS_TRACE_HOST
(
" smem_size: "
<<
smem_size
<<
" bytes"
);
if
(
smem_size
<=
(
48
<<
10
))
{
cudaError_t
result
=
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
max_active_blocks
,
Kernel
<
GemmKernel
>
,
GemmKernel
::
kThreadCount
,
smem_size
);
cudaError_t
result
=
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
max_active_blocks
,
Kernel
<
GemmKernel
>
,
GemmKernel
::
kThreadCount
,
smem_size
);
if
(
result
==
cudaSuccess
)
{
CUTLASS_TRACE_HOST
(
" max_active_blocks: "
<<
max_active_blocks
);
...
...
@@ -184,12 +185,12 @@ class GemmUniversalBaseCompat {
}
}
else
{
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t
result
=
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
max_active_blocks
,
Kernel
<
GemmKernel
>
,
GemmKernel
::
kThreadCount
,
0
);
cudaError_t
result
=
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
max_active_blocks
,
Kernel
<
GemmKernel
>
,
GemmKernel
::
kThreadCount
,
0
);
if
(
result
!=
cudaSuccess
)
{
CUTLASS_TRACE_HOST
(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<<
cudaGetErrorString
(
result
));
CUTLASS_TRACE_HOST
(
"
cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<<
cudaGetErrorString
(
result
));
return
-
1
;
}
...
...
@@ -226,8 +227,9 @@ class GemmUniversalBaseCompat {
/// Initializes GEMM state from arguments.
Status
initialize
(
Arguments
const
&
args
,
void
*
workspace
=
nullptr
,
cudaStream_t
stream
=
nullptr
)
{
CUTLASS_TRACE_HOST
(
"GemmUniversalBaseCompat::initialize() - workspace "
<<
workspace
<<
", stream: "
<<
(
stream
?
"non-null"
:
"null"
));
CUTLASS_TRACE_HOST
(
"GemmUniversalBaseCompat::initialize() - workspace "
<<
workspace
<<
", stream: "
<<
(
stream
?
"non-null"
:
"null"
));
size_t
workspace_bytes
=
get_workspace_size
(
args
);
...
...
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
View file @
d052f4c8
...
...
@@ -32,10 +32,11 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Mma_
,
///! Threadblock-scoped matrix multiply-accumulate
typename
Epilogue_
,
///! Epilogue
typename
ThreadblockSwizzle_
///! Threadblock swizzling function
>
template
<
typename
Mma_
,
///! Threadblock-scoped matrix multiply-accumulate
typename
Epilogue_
,
///! Epilogue
typename
ThreadblockSwizzle_
///! Threadblock swizzling function
>
struct
GemmWithEpilogueVisitor
{
public:
using
Mma
=
Mma_
;
...
...
@@ -119,9 +120,15 @@ struct GemmWithEpilogueVisitor {
Arguments
()
:
mode
(
GemmUniversalMode
::
kGemm
),
batch_count
(
1
)
{}
/// constructs an arguments structure
Arguments
(
GemmCoord
problem_size_
,
TensorRefA
ref_A_
,
TensorRefB
ref_B_
,
TensorRefAlphaCol
ref_alpha_col_
,
TensorRefAlphaRow
ref_alpha_row_
,
TensorRefC
ref_C_
,
TensorRefC
ref_D_
,
typename
EpilogueVisitor
::
Arguments
epilogue_visitor_
)
Arguments
(
GemmCoord
problem_size_
,
TensorRefA
ref_A_
,
TensorRefB
ref_B_
,
TensorRefAlphaCol
ref_alpha_col_
,
TensorRefAlphaRow
ref_alpha_row_
,
TensorRefC
ref_C_
,
TensorRefC
ref_D_
,
typename
EpilogueVisitor
::
Arguments
epilogue_visitor_
)
:
mode
(
GemmUniversalMode
::
kGemm
),
problem_size
(
problem_size_
),
batch_count
(
1
),
...
...
@@ -269,8 +276,9 @@ struct GemmWithEpilogueVisitor {
isAMisaligned
=
problem_size
.
k
()
%
kAlignmentA
;
}
else
if
(
platform
::
is_same
<
LayoutA
,
layout
::
ColumnMajor
>::
value
)
{
isAMisaligned
=
problem_size
.
m
()
%
kAlignmentA
;
}
else
if
(
platform
::
is_same
<
LayoutA
,
layout
::
ColumnMajorInterleaved
<
32
>>::
value
||
platform
::
is_same
<
LayoutA
,
layout
::
ColumnMajorInterleaved
<
64
>>::
value
)
{
}
else
if
(
platform
::
is_same
<
LayoutA
,
layout
::
ColumnMajorInterleaved
<
32
>>::
value
||
platform
::
is_same
<
LayoutA
,
layout
::
ColumnMajorInterleaved
<
64
>>::
value
)
{
isAMisaligned
=
problem_size
.
k
()
%
kAlignmentA
;
}
...
...
@@ -278,8 +286,9 @@ struct GemmWithEpilogueVisitor {
isBMisaligned
=
problem_size
.
n
()
%
kAlignmentB
;
}
else
if
(
platform
::
is_same
<
LayoutB
,
layout
::
ColumnMajor
>::
value
)
{
isBMisaligned
=
problem_size
.
k
()
%
kAlignmentB
;
}
else
if
(
platform
::
is_same
<
LayoutB
,
layout
::
RowMajorInterleaved
<
32
>>::
value
||
platform
::
is_same
<
LayoutB
,
layout
::
RowMajorInterleaved
<
64
>>::
value
)
{
}
else
if
(
platform
::
is_same
<
LayoutB
,
layout
::
RowMajorInterleaved
<
32
>>::
value
||
platform
::
is_same
<
LayoutB
,
layout
::
RowMajorInterleaved
<
64
>>::
value
)
{
isBMisaligned
=
problem_size
.
k
()
%
kAlignmentB
;
}
...
...
@@ -287,8 +296,9 @@ struct GemmWithEpilogueVisitor {
isCMisaligned
=
problem_size
.
n
()
%
kAlignmentC
;
}
else
if
(
platform
::
is_same
<
LayoutC
,
layout
::
ColumnMajor
>::
value
)
{
isCMisaligned
=
problem_size
.
m
()
%
kAlignmentC
;
}
else
if
(
platform
::
is_same
<
LayoutC
,
layout
::
ColumnMajorInterleaved
<
32
>>::
value
||
platform
::
is_same
<
LayoutC
,
layout
::
ColumnMajorInterleaved
<
64
>>::
value
)
{
}
else
if
(
platform
::
is_same
<
LayoutC
,
layout
::
ColumnMajorInterleaved
<
32
>>::
value
||
platform
::
is_same
<
LayoutC
,
layout
::
ColumnMajorInterleaved
<
64
>>::
value
)
{
isCMisaligned
=
problem_size
.
n
()
%
kAlignmentC
;
}
...
...
@@ -373,11 +383,11 @@ struct GemmWithEpilogueVisitor {
int
thread_idx
=
threadIdx
.
x
;
// Construct iterators to A and B operands
typename
Mma
::
IteratorA
iterator_A
(
params
.
params_A
,
ptr_A
,
{
params
.
problem_size
.
m
(),
problem_size_k
},
thread_idx
,
tb_offset_A
);
typename
Mma
::
IteratorA
iterator_A
(
params
.
params_A
,
ptr_A
,
{
params
.
problem_size
.
m
(),
problem_size_k
},
thread_idx
,
tb_offset_A
);
typename
Mma
::
IteratorB
iterator_B
(
params
.
params_B
,
ptr_B
,
{
problem_size_k
,
params
.
problem_size
.
n
()},
thread_idx
,
tb_offset_B
);
typename
Mma
::
IteratorB
iterator_B
(
params
.
params_B
,
ptr_B
,
{
problem_size_k
,
params
.
problem_size
.
n
()},
thread_idx
,
tb_offset_B
);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
...
...
@@ -409,8 +419,8 @@ struct GemmWithEpilogueVisitor {
threadblock_tile_offset
=
threadblock_swizzle
.
get_tile_offset
(
params
.
swizzle_log_tile
);
// assume identity swizzle
MatrixCoord
threadblock_offset
(
threadblock_tile_offset
.
m
()
*
Mma
::
Shape
::
kM
,
threadblock_tile_offset
.
n
()
*
Mma
::
Shape
::
kN
);
MatrixCoord
threadblock_offset
(
threadblock_tile_offset
.
m
()
*
Mma
::
Shape
::
kM
,
threadblock_tile_offset
.
n
()
*
Mma
::
Shape
::
kN
);
int
block_idx
=
threadblock_tile_offset
.
m
()
+
threadblock_tile_offset
.
n
()
*
params
.
grid_tiled_shape
.
m
();
...
...
@@ -423,11 +433,25 @@ struct GemmWithEpilogueVisitor {
with_bias
=
false
;
}
EpilogueVisitor
epilogue_visitor
(
params
.
epilogue_visitor
,
shared_storage
.
epilogue
.
visitor
,
params
.
problem_size
.
mn
(),
thread_idx
,
warp_idx
,
lane_idx
,
params
.
params_alpha_col
,
params
.
params_C
,
params
.
params_D
,
with_bias
,
true
,
true
,
params
.
ptr_alpha_row
,
params
.
ptr_alpha_col
,
params
.
ptr_C
,
params
.
ptr_D
,
threadblock_offset
,
blockIdx
.
y
*
params
.
problem_size
.
m
());
EpilogueVisitor
epilogue_visitor
(
params
.
epilogue_visitor
,
shared_storage
.
epilogue
.
visitor
,
params
.
problem_size
.
mn
(),
thread_idx
,
warp_idx
,
lane_idx
,
params
.
params_alpha_col
,
params
.
params_C
,
params
.
params_D
,
with_bias
,
true
,
true
,
params
.
ptr_alpha_row
,
params
.
ptr_alpha_col
,
params
.
ptr_C
,
params
.
ptr_D
,
threadblock_offset
,
blockIdx
.
y
*
params
.
problem_size
.
m
());
if
(
params
.
mode
==
GemmUniversalMode
::
kGemm
)
{
// Indicate which position in a serial reduction the output operator is currently updating
...
...
sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
View file @
d052f4c8
...
...
@@ -21,10 +21,13 @@
#include "utils.h"
static
void
check_group_count
(
const
std
::
vector
<
torch
::
Tensor
>&
inputs
,
const
std
::
vector
<
torch
::
Tensor
>&
weights
,
const
std
::
vector
<
torch
::
Tensor
>&
outputs
)
{
TORCH_CHECK
(((
inputs
.
size
()
==
weights
.
size
())
&&
(
inputs
.
size
()
==
outputs
.
size
())),
"The group count of inputs, weights and outputs should be the same."
);
static
void
check_group_count
(
const
std
::
vector
<
torch
::
Tensor
>&
inputs
,
const
std
::
vector
<
torch
::
Tensor
>&
weights
,
const
std
::
vector
<
torch
::
Tensor
>&
outputs
)
{
TORCH_CHECK
(
((
inputs
.
size
()
==
weights
.
size
())
&&
(
inputs
.
size
()
==
outputs
.
size
())),
"The group count of inputs, weights and outputs should be the same."
);
}
static
void
check_device_dtype
(
const
torch
::
Dtype
&
dtype
,
const
std
::
vector
<
torch
::
Tensor
>&
tensors
)
{
...
...
@@ -68,21 +71,26 @@ static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tens
static
torch
::
Tensor
create_ptr_pointer
(
const
std
::
vector
<
void
*>&
ptrs
,
cudaStream_t
stream
)
{
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kDouble
).
device
(
torch
::
kCUDA
);
torch
::
Tensor
gpu_ptrs
=
torch
::
empty
({
static_cast
<
int
>
(
ptrs
.
size
())},
options
);
TORCH_CHECK
(
cudaMemcpyAsync
(
gpu_ptrs
.
data_ptr
(),
ptrs
.
data
(),
sizeof
(
void
*
)
*
ptrs
.
size
(),
cudaMemcpyHostToDevice
,
stream
)
==
CUBLAS_STATUS_SUCCESS
);
TORCH_CHECK
(
cudaMemcpyAsync
(
gpu_ptrs
.
data_ptr
(),
ptrs
.
data
(),
sizeof
(
void
*
)
*
ptrs
.
size
(),
cudaMemcpyHostToDevice
,
stream
)
==
CUBLAS_STATUS_SUCCESS
);
return
gpu_ptrs
;
}
// We want compute input @ weight^T in row major
// This is equivalent to computing weight @ input^T in col major
// Cublas only accepts matrix in column major, so this arrangement is needed
void
cublas_grouped_gemm
(
const
std
::
vector
<
torch
::
Tensor
>&
inputs
,
// b: (m, k) row major = (k, m) col major
const
std
::
vector
<
torch
::
Tensor
>&
weights
,
// a: (n, k) row major = (n, k)^T col major
const
std
::
vector
<
torch
::
Tensor
>&
outputs
,
// c: (m, n) row major = (n, m) col major
const
torch
::
Dtype
&
out_dtype
,
int64_t
cublas_handle
,
int64_t
cuda_stream
)
{
TORCH_CHECK
(
out_dtype
==
torch
::
kHalf
||
out_dtype
==
torch
::
kBFloat16
,
"cublas grouped_gemm can"
"only be applied to float16 and bfloat16 dtype"
);
void
cublas_grouped_gemm
(
const
std
::
vector
<
torch
::
Tensor
>&
inputs
,
// b: (m, k) row major = (k, m) col major
const
std
::
vector
<
torch
::
Tensor
>&
weights
,
// a: (n, k) row major = (n, k)^T col major
const
std
::
vector
<
torch
::
Tensor
>&
outputs
,
// c: (m, n) row major = (n, m) col major
const
torch
::
Dtype
&
out_dtype
,
int64_t
cublas_handle
,
int64_t
cuda_stream
)
{
TORCH_CHECK
(
out_dtype
==
torch
::
kHalf
||
out_dtype
==
torch
::
kBFloat16
,
"cublas grouped_gemm can"
"only be applied to float16 and bfloat16 dtype"
);
int
group_count
=
inputs
.
size
();
check_group_count
(
inputs
,
weights
,
outputs
);
...
...
@@ -133,16 +141,32 @@ void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, // b: (m, k
torch
::
Tensor
d_c
=
create_ptr_pointer
(
c_array
,
stream
);
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
auto
status
=
cublasGemmGroupedBatchedEx
(
handle
,
transa_array
.
data
(),
transb_array
.
data
(),
m_array
.
data
(),
n_array
.
data
(),
k_array
.
data
(),
alpha_array
.
data
(),
(
void
**
)
d_a
.
data_ptr
(),
cuda_data_type
,
lda_array
.
data
(),
(
void
**
)
d_b
.
data_ptr
(),
cuda_data_type
,
ldb_array
.
data
(),
beta_array
.
data
(),
(
void
**
)
d_c
.
data_ptr
(),
cuda_data_type
,
ldc_array
.
data
(),
group_count
,
group_size
.
data
(),
CUBLAS_COMPUTE_32F
);
auto
status
=
cublasGemmGroupedBatchedEx
(
handle
,
transa_array
.
data
(),
transb_array
.
data
(),
m_array
.
data
(),
n_array
.
data
(),
k_array
.
data
(),
alpha_array
.
data
(),
(
void
**
)
d_a
.
data_ptr
(),
cuda_data_type
,
lda_array
.
data
(),
(
void
**
)
d_b
.
data_ptr
(),
cuda_data_type
,
ldb_array
.
data
(),
beta_array
.
data
(),
(
void
**
)
d_c
.
data_ptr
(),
cuda_data_type
,
ldc_array
.
data
(),
group_count
,
group_size
.
data
(),
CUBLAS_COMPUTE_32F
);
TORCH_CHECK
(
status
==
CUBLAS_STATUS_SUCCESS
,
"cublas grouped gemm failed: "
,
cublasGetStatusString
(
status
));
TORCH_CHECK
(
cudaStreamSynchronize
(
stream
)
==
cudaSuccess
,
"Failed when stream synchronization"
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"Cublas GroupGemm is not implemented with current compute capability: "
,
getSMVersion
());
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"Cublas GroupGemm is not implemented with current compute capability: "
,
getSMVersion
());
}
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
View file @
d052f4c8
...
...
@@ -35,8 +35,12 @@
using
namespace
cute
;
template
<
typename
OutType
,
typename
TileShape
,
typename
ClusterShape
,
int
ScaleGranularityM
=
1
>
void
launch_sm90_fp8_blockwise_scaled_mm
(
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
)
{
void
launch_sm90_fp8_blockwise_scaled_mm
(
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
)
{
using
ElementAccumulator
=
float
;
using
ElementCompute
=
float
;
using
ElementBlockScale
=
float
;
...
...
@@ -66,19 +70,43 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
ScaleGranularityM
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
TileShape
,
ClusterShape
,
EpilogueTileType
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
LayoutC
,
AlignmentC
,
ElementD
,
LayoutD
,
AlignmentD
,
EpilogueSchedule
,
StoreEpilogueCompute
>::
CollectiveOp
;
ArchTag
,
OperatorClass
,
TileShape
,
ClusterShape
,
EpilogueTileType
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
LayoutC
,
AlignmentC
,
ElementD
,
LayoutD
,
AlignmentD
,
EpilogueSchedule
,
StoreEpilogueCompute
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
LayoutA
,
AlignmentA
,
ElementB
,
LayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
ArchTag
,
OperatorClass
,
ElementA
,
LayoutA
,
AlignmentA
,
ElementB
,
LayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
KernelSchedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
// Indicates ProblemShape
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
// Indicates ProblemShape
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>
;
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
Gemm
gemm_op
;
...
...
@@ -127,16 +155,23 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
}
template
<
typename
OutType
>
void
sm90_fp8_blockwise_dispatch_shape
(
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
)
{
void
sm90_fp8_blockwise_dispatch_shape
(
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
)
{
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
launch_sm90_fp8_blockwise_scaled_mm
<
OutType
,
TileShape
,
ClusterShape
>
(
out
,
a
,
b
,
scales_a
,
scales_b
);
}
torch
::
Tensor
fp8_blockwise_scaled_mm
(
const
torch
::
Tensor
&
mat_a
,
const
torch
::
Tensor
&
mat_b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
,
const
torch
::
Dtype
&
out_dtype
)
{
torch
::
Tensor
fp8_blockwise_scaled_mm
(
const
torch
::
Tensor
&
mat_a
,
const
torch
::
Tensor
&
mat_b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
,
const
torch
::
Dtype
&
out_dtype
)
{
TORCH_CHECK
(
mat_a
.
is_cuda
(),
"mat_a must be a CUDA tensor"
);
TORCH_CHECK
(
mat_b
.
is_cuda
(),
"mat_b must be a CUDA tensor"
);
TORCH_CHECK
(
mat_a
.
dim
()
==
2
,
"mat_a must be a 2D tensor"
);
...
...
@@ -145,10 +180,10 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
TORCH_CHECK
(
mat_b
.
stride
(
0
)
==
1
,
"mat_a must be a column major tensor"
);
TORCH_CHECK
(
mat_a
.
size
(
1
)
==
mat_b
.
size
(
0
),
"mat_a and mat_b shapes cannot be multiplied"
);
TORCH_CHECK
(
(
mat_a
.
size
(
1
)
*
mat_a
.
element_size
())
%
16
==
0
,
"mat_a must be multiple of 16 bytes for memory alignment"
);
TORCH_CHECK
(
(
mat_b
.
size
(
0
)
*
mat_b
.
element_size
())
%
16
==
0
,
"mat_b must be multiple of 16 bytes for memory alignment"
);
TORCH_CHECK
(
(
mat_a
.
size
(
1
)
*
mat_a
.
element_size
())
%
16
==
0
,
"mat_a must be multiple of 16 bytes for memory alignment"
);
TORCH_CHECK
(
(
mat_b
.
size
(
0
)
*
mat_b
.
element_size
())
%
16
==
0
,
"mat_b must be multiple of 16 bytes for memory alignment"
);
TORCH_CHECK
(
mat_a
.
scalar_type
()
==
torch
::
kFloat8_e4m3fn
,
"mat_a must be Float8_e4m3fn"
);
TORCH_CHECK
(
mat_b
.
scalar_type
()
==
torch
::
kFloat8_e4m3fn
,
"mat_b must be Float8_e4m3fn"
);
TORCH_CHECK
(
out_dtype
==
torch
::
kHalf
||
out_dtype
==
torch
::
kBFloat16
,
"out_dtype must be Half or BFloat16"
);
...
...
@@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
#endif
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No implemented fp8_blockwise_scaled_mm for current compute capability: "
,
sm_version
);
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No implemented fp8_blockwise_scaled_mm for current compute capability: "
,
sm_version
);
}
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
View file @
d052f4c8
This diff is collapsed.
Click to expand it.
sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
View file @
d052f4c8
This diff is collapsed.
Click to expand it.
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
View file @
d052f4c8
...
...
@@ -8,8 +8,8 @@
#include "utils.h"
template
<
typename
T
>
__global__
void
per_tensor_absmax_kernel
(
const
T
*
__restrict__
input
,
float
*
__restrict__
output_s
,
const
int64_t
num_elements
)
{
__global__
void
per_tensor_absmax_kernel
(
const
T
*
__restrict__
input
,
float
*
__restrict__
output_s
,
const
int64_t
num_elements
)
{
float
max_value
=
0.0
f
;
unsigned
int
tid
=
threadIdx
.
x
;
unsigned
int
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
...
...
@@ -56,8 +56,11 @@ __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __r
}
template
<
typename
T
>
__global__
void
per_tensor_quant_fp8_kernel
(
const
T
*
__restrict__
input
,
FP8_TYPE
*
__restrict__
output
,
const
float
*
__restrict__
scale
,
const
int64_t
num_elements
)
{
__global__
void
per_tensor_quant_fp8_kernel
(
const
T
*
__restrict__
input
,
FP8_TYPE
*
__restrict__
output
,
const
float
*
__restrict__
scale
,
const
int64_t
num_elements
)
{
const
int
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
grid_size
=
blockDim
.
x
*
gridDim
.
x
;
const
float
scale_val
=
1.0
f
/
(
*
scale
);
...
...
@@ -124,8 +127,10 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
}
per_tensor_quant_fp8_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
scalar_t
*>
(
input
.
data_ptr
()),
static_cast
<
FP8_TYPE
*>
(
output_q
.
data_ptr
()),
static_cast
<
float
*>
(
output_s
.
data_ptr
()),
num_elements
);
static_cast
<
scalar_t
*>
(
input
.
data_ptr
()),
static_cast
<
FP8_TYPE
*>
(
output_q
.
data_ptr
()),
static_cast
<
float
*>
(
output_s
.
data_ptr
()),
num_elements
);
return
true
;
});
}
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
View file @
d052f4c8
...
...
@@ -17,10 +17,15 @@ __device__ __forceinline__ float GroupReduce(float val, const int tid) {
}
template
<
typename
T
>
__global__
void
per_token_group_quant_fp8_kernel
(
const
T
*
__restrict__
input
,
void
*
__restrict__
output_q
,
float
*
__restrict__
output_s
,
const
int
group_size
,
const
int
num_groups
,
const
float
eps
,
const
float
fp8_min
,
const
float
fp8_max
)
{
__global__
void
per_token_group_quant_fp8_kernel
(
const
T
*
__restrict__
input
,
void
*
__restrict__
output_q
,
float
*
__restrict__
output_s
,
const
int
group_size
,
const
int
num_groups
,
const
float
eps
,
const
float
fp8_min
,
const
float
fp8_max
)
{
const
int
groups_per_block
=
16
;
const
int
local_group_id
=
threadIdx
.
x
/
16
;
const
int
lane_id
=
threadIdx
.
x
%
16
;
...
...
@@ -80,8 +85,14 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
}
}
void
sgl_per_token_group_quant_fp8
(
torch
::
Tensor
input
,
torch
::
Tensor
output_q
,
torch
::
Tensor
output_s
,
int64_t
group_size
,
double
eps
,
double
fp8_min
,
double
fp8_max
)
{
void
sgl_per_token_group_quant_fp8
(
torch
::
Tensor
input
,
torch
::
Tensor
output_q
,
torch
::
Tensor
output_s
,
int64_t
group_size
,
double
eps
,
double
fp8_min
,
double
fp8_max
)
{
CHECK_INPUT
(
input
);
CHECK_INPUT
(
output_q
);
CHECK_INPUT
(
output_s
);
...
...
@@ -97,8 +108,14 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q,
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16
(
input
.
scalar_type
(),
scalar_t
,
[
&
]
{
per_token_group_quant_fp8_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
scalar_t
*>
(
input
.
data_ptr
()),
output_q
.
data_ptr
(),
static_cast
<
float
*>
(
output_s
.
data_ptr
()),
group_size
,
num_groups
,
(
float
)
eps
,
(
float
)
fp8_min
,
(
float
)
fp8_max
);
static_cast
<
scalar_t
*>
(
input
.
data_ptr
()),
output_q
.
data_ptr
(),
static_cast
<
float
*>
(
output_s
.
data_ptr
()),
group_size
,
num_groups
,
(
float
)
eps
,
(
float
)
fp8_min
,
(
float
)
fp8_max
);
return
true
;
});
}
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
View file @
d052f4c8
...
...
@@ -7,9 +7,12 @@
#include "utils.h"
template
<
typename
T
>
__global__
void
per_token_quant_fp8_kernel
(
const
T
*
__restrict__
input
,
FP8_TYPE
*
__restrict__
output_q
,
float
*
__restrict__
output_s
,
const
int64_t
hidden_dim
,
const
int64_t
num_tokens
)
{
__global__
void
per_token_quant_fp8_kernel
(
const
T
*
__restrict__
input
,
FP8_TYPE
*
__restrict__
output_q
,
float
*
__restrict__
output_s
,
const
int64_t
hidden_dim
,
const
int64_t
num_tokens
)
{
const
int
token_idx
=
blockIdx
.
x
;
if
(
token_idx
>=
num_tokens
)
return
;
...
...
@@ -110,8 +113,11 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16
(
input
.
scalar_type
(),
scalar_t
,
[
&
]
{
per_token_quant_fp8_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
scalar_t
*>
(
input
.
data_ptr
()),
static_cast
<
FP8_TYPE
*>
(
output_q
.
data_ptr
()),
static_cast
<
float
*>
(
output_s
.
data_ptr
()),
hidden_dim
,
num_tokens
);
static_cast
<
scalar_t
*>
(
input
.
data_ptr
()),
static_cast
<
FP8_TYPE
*>
(
output_q
.
data_ptr
()),
static_cast
<
float
*>
(
output_s
.
data_ptr
()),
hidden_dim
,
num_tokens
);
return
true
;
});
}
sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu
View file @
d052f4c8
...
...
@@ -25,9 +25,11 @@ limitations under the License.
#define WARP_SIZE 32
template
<
typename
scalar_t
>
__global__
void
count_and_sort_expert_tokens_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
cumsum_buffer
,
size_t
numel
)
{
__global__
void
count_and_sort_expert_tokens_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
cumsum_buffer
,
size_t
numel
)
{
const
size_t
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
stride
=
blockDim
.
x
*
gridDim
.
x
;
...
...
@@ -39,10 +41,15 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__
}
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
__restrict__
cumsum
)
{
__global__
void
moe_align_block_size_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
__restrict__
cumsum
)
{
__shared__
int32_t
shared_counts
[
WARP_SIZE
][
8
];
const
int
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
...
...
@@ -91,17 +98,29 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
}
}
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
token_cnts_buffer
,
torch
::
Tensor
cumsum_buffer
)
{
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
token_cnts_buffer
,
torch
::
Tensor
cumsum_buffer
)
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
TORCH_CHECK
(
num_experts
==
256
,
"moe_align_block_size kernel only support deepseek v3 now."
);
DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
auto
align_kernel
=
moe_align_block_size_kernel
<
scalar_t
>
;
align_kernel
<<<
1
,
1024
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
align_kernel
<<<
1
,
1024
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
const
int
block_threads
=
256
;
const
int
num_blocks
=
(
topk_ids
.
numel
()
+
block_threads
-
1
)
/
block_threads
;
...
...
@@ -109,8 +128,10 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
const
int
actual_blocks
=
std
::
min
(
num_blocks
,
max_blocks
);
auto
sort_kernel
=
count_and_sort_expert_tokens_kernel
<
scalar_t
>
;
sort_kernel
<<<
actual_blocks
,
block_threads
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
topk_ids
.
numel
());
sort_kernel
<<<
actual_blocks
,
block_threads
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
topk_ids
.
numel
());
});
}
sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu
View file @
d052f4c8
...
...
@@ -23,10 +23,18 @@
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
__global__
void
build_tree_efficient
(
int64_t
*
parent_list
,
int64_t
*
selected_index
,
int32_t
*
verified_seq_len
,
bool
*
tree_mask
,
int64_t
*
positions
,
int64_t
*
retrive_index
,
int64_t
*
retrive_next_token
,
int64_t
*
retrive_next_sibling
,
int
topk
,
int
depth
,
int
draft_token_num
)
{
__global__
void
build_tree_efficient
(
int64_t
*
parent_list
,
int64_t
*
selected_index
,
int32_t
*
verified_seq_len
,
bool
*
tree_mask
,
int64_t
*
positions
,
int64_t
*
retrive_index
,
int64_t
*
retrive_next_token
,
int64_t
*
retrive_next_sibling
,
int
topk
,
int
depth
,
int
draft_token_num
)
{
int
bid
=
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
...
...
@@ -99,10 +107,18 @@ __global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_ind
}
}
void
build_tree_kernel_efficient
(
at
::
Tensor
parent_list
,
at
::
Tensor
selected_index
,
at
::
Tensor
verified_seq_len
,
at
::
Tensor
tree_mask
,
at
::
Tensor
positions
,
at
::
Tensor
retrive_index
,
at
::
Tensor
retrive_next_token
,
at
::
Tensor
retrive_next_sibling
,
int64_t
topk
,
int64_t
depth
,
int64_t
draft_token_num
)
{
void
build_tree_kernel_efficient
(
at
::
Tensor
parent_list
,
at
::
Tensor
selected_index
,
at
::
Tensor
verified_seq_len
,
at
::
Tensor
tree_mask
,
at
::
Tensor
positions
,
at
::
Tensor
retrive_index
,
at
::
Tensor
retrive_next_token
,
at
::
Tensor
retrive_next_sibling
,
int64_t
topk
,
int64_t
depth
,
int64_t
draft_token_num
)
{
// TODO (ying) check shape
// TODO (ying) check type
int
bs
=
parent_list
.
size
(
0
);
...
...
@@ -111,11 +127,17 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
build_tree_efficient
<<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
int64_t
*>
(
parent_list
.
data_ptr
()),
static_cast
<
int64_t
*>
(
selected_index
.
data_ptr
()),
static_cast
<
int32_t
*>
(
verified_seq_len
.
data_ptr
()),
static_cast
<
bool
*>
(
tree_mask
.
data_ptr
()),
static_cast
<
int64_t
*>
(
positions
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_index
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_next_token
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_next_sibling
.
data_ptr
()),
int32_t
(
topk
),
int32_t
(
depth
),
int32_t
(
draft_token_num
));
static_cast
<
int64_t
*>
(
parent_list
.
data_ptr
()),
static_cast
<
int64_t
*>
(
selected_index
.
data_ptr
()),
static_cast
<
int32_t
*>
(
verified_seq_len
.
data_ptr
()),
static_cast
<
bool
*>
(
tree_mask
.
data_ptr
()),
static_cast
<
int64_t
*>
(
positions
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_index
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_next_token
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_next_sibling
.
data_ptr
()),
int32_t
(
topk
),
int32_t
(
depth
),
int32_t
(
draft_token_num
));
}
// parent_list [bs, topk * (depth - 1) + 1)]
...
...
@@ -124,8 +146,16 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
__global__
void
build_tree
(
int64_t
*
parent_list
,
int64_t
*
selected_index
,
int32_t
*
verified_seq_len
,
bool
*
tree_mask
,
int64_t
*
positions
,
int64_t
*
retrive_index
,
int
topk
,
int
depth
,
int
draft_token_num
)
{
__global__
void
build_tree
(
int64_t
*
parent_list
,
int64_t
*
selected_index
,
int32_t
*
verified_seq_len
,
bool
*
tree_mask
,
int64_t
*
positions
,
int64_t
*
retrive_index
,
int
topk
,
int
depth
,
int
draft_token_num
)
{
int
bid
=
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
...
...
@@ -191,9 +221,16 @@ __global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_
}
}
void
build_tree_kernel
(
at
::
Tensor
parent_list
,
at
::
Tensor
selected_index
,
at
::
Tensor
verified_seq_len
,
at
::
Tensor
tree_mask
,
at
::
Tensor
positions
,
at
::
Tensor
retrive_index
,
int64_t
topk
,
int64_t
depth
,
int64_t
draft_token_num
)
{
void
build_tree_kernel
(
at
::
Tensor
parent_list
,
at
::
Tensor
selected_index
,
at
::
Tensor
verified_seq_len
,
at
::
Tensor
tree_mask
,
at
::
Tensor
positions
,
at
::
Tensor
retrive_index
,
int64_t
topk
,
int64_t
depth
,
int64_t
draft_token_num
)
{
// TODO (ying) check shape
// TODO (ying) check type
int
bs
=
parent_list
.
size
(
0
);
...
...
@@ -202,8 +239,13 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
build_tree
<<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
int64_t
*>
(
parent_list
.
data_ptr
()),
static_cast
<
int64_t
*>
(
selected_index
.
data_ptr
()),
static_cast
<
int32_t
*>
(
verified_seq_len
.
data_ptr
()),
static_cast
<
bool
*>
(
tree_mask
.
data_ptr
()),
static_cast
<
int64_t
*>
(
positions
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_index
.
data_ptr
()),
int32_t
(
topk
),
int32_t
(
depth
),
int32_t
(
draft_token_num
));
static_cast
<
int64_t
*>
(
parent_list
.
data_ptr
()),
static_cast
<
int64_t
*>
(
selected_index
.
data_ptr
()),
static_cast
<
int32_t
*>
(
verified_seq_len
.
data_ptr
()),
static_cast
<
bool
*>
(
tree_mask
.
data_ptr
()),
static_cast
<
int64_t
*>
(
positions
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_index
.
data_ptr
()),
int32_t
(
topk
),
int32_t
(
depth
),
int32_t
(
draft_token_num
));
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment