zhaoyu6 / sglang · Commits · d052f4c8
"examples/pytorch/gxn/README.md" did not exist on "9fc5eed6fdaa10602d35e451536bcff9aeba9224"
Commit d052f4c8 (unverified), authored Mar 07, 2025 by Lianmin Zheng; committed by GitHub on Mar 07, 2025
New clang format for sgl kernel (#4194)
parent e1aaa79a
Changes: 25 (showing 20 changed files with 1217 additions and 567 deletions: +1217, -567)
python/upload_pypi.sh                                                                        +0    -6
sgl-kernel/.clang-format                                                                     +7    -0
sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu                       +9    -4
sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh                           +72   -44
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu                              +29   -14
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu                                +20   -10
sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu                +31   -15
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h  +41   -22
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp                   +12   -8
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h          +12   -10
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h          +48   -24
sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu                                   +44   -20
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu                             +55   -20
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu                                       +409  -189
sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu                                      +280  -124
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu                                  +11   -6
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu                             +25   -8
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu                                   +11   -5
sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu                                       +37   -16
sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu                                    +64   -22
python/upload_pypi.sh (deleted, 100644 → 0)

-cp ../README.md ../LICENSE .
-rm -rf dist
-python3 -m build
-python3 -m twine upload dist/*
-rm -rf README.md LICENSE

sgl-kernel/.clang-format

@@ -6,3 +6,10 @@ DerivePointerAlignment: false
 PointerAlignment: Left
 NamespaceIndentation: None
 SortIncludes: true
+AllowShortLoopsOnASingleLine: false
+BinPackParameters: false                   # Prevents packing parameters in declarations
+BinPackArguments: false                    # Prevents packing arguments in function calls
+AlignAfterOpenBracket: AlwaysBreak         # Forces a break after the opening parenthesis
+AlignOperands: Align                       # Aligns arguments vertically
+PenaltyBreakBeforeFirstCallParameter: 1    # Encourages breaking before the first argument
+PenaltyReturnTypeOnItsOwnLine: 100         # Keeps return type with function name
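
For illustration only (not part of the commit): with BinPackParameters and BinPackArguments disabled and AlignAfterOpenBracket set to AlwaysBreak, clang-format breaks long declarations and calls after the opening parenthesis and gives each parameter or argument its own line, which is the layout change visible in every hunk below. Applying the style is the usual clang-format -i run with this .clang-format on the search path. A sketch with a hypothetical signature:

  // Old style (parameters bin-packed onto as few lines as possible):
  //   void sgl_example_op(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
  //                       double eps, int64_t cuda_stream);
  //
  // New style (break after '(', one parameter per line; call sites are wrapped the same way):
  void sgl_example_op(
      torch::Tensor input,
      torch::Tensor residual,
      torch::Tensor weight,
      double eps,
      int64_t cuda_stream);
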
sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu

@@ -41,10 +41,15 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
  // support float16, bfloat16 and float32
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status = norm::FusedAddRMSNorm(
        static_cast<c_type*>(input.data_ptr()),
        static_cast<c_type*>(residual.data_ptr()),
        static_cast<c_type*>(weight.data_ptr()),
        batch_size,
        hidden_size,
        eps,
        torch_current_stream);
    TORCH_CHECK(
        status == cudaSuccess,
        "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}

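A minimal host-side sketch of what the wrapped norm::FusedAddRMSNorm call is commonly defined to compute (an assumption for illustration; the device code is not part of this hunk): the residual buffer accumulates input + residual, and input is overwritten with RMSNorm(residual) * weight.

  #include <cmath>
  #include <vector>

  // Hypothetical reference helper, float only; the real kernel handles fp16/bf16/fp32 on the GPU.
  void fused_add_rmsnorm_ref(
      std::vector<float>& input,
      std::vector<float>& residual,
      const std::vector<float>& weight,
      float eps,
      int batch_size,
      int hidden_size) {
    for (int b = 0; b < batch_size; ++b) {
      float* x = &input[b * hidden_size];
      float* r = &residual[b * hidden_size];
      double sum_sq = 0.0;
      for (int i = 0; i < hidden_size; ++i) {
        r[i] += x[i];  // residual accumulates the sum
        sum_sq += double(r[i]) * r[i];
      }
      const float inv_rms = 1.0f / std::sqrt(float(sum_sq / hidden_size) + eps);
      for (int i = 0; i < hidden_size; ++i) {
        x[i] = r[i] * inv_rms * weight[i];  // input is overwritten with the normalized output
      }
    }
  }
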
sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh

@@ -153,19 +153,20 @@ DINLINE O downcast(array_t<float, O::size> val) {
 // prior memory accesses. Note: volatile writes will not be reordered against
 // other volatile writes.
 template <int ngpus>
 DINLINE void start_sync(
     const RankSignals& sg,
 #ifndef USE_ROCM
     volatile
 #endif
     Signal* self_sg,
     int rank) {
 #ifdef USE_ROCM
   uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
     __scoped_atomic_store_n(
         &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
     while (__scoped_atomic_load_n(
                &self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) < flag)

@@ -193,12 +194,13 @@ DINLINE void start_sync(const RankSignals& sg,
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
 DINLINE void end_sync(
     const RankSignals& sg,
 #ifndef USE_ROCM
     volatile
 #endif
     Signal* self_sg,
     int rank) {
 #ifdef USE_ROCM
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become

@@ -209,11 +211,16 @@ DINLINE void end_sync(const RankSignals& sg,
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
     __scoped_atomic_store_n(
         &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
         flag,
         final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
         __MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
     while (__scoped_atomic_load_n(
                &self_sg->end[blockIdx.x][threadIdx.x],
                final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
                __MEMORY_SCOPE_DEVICE) < flag)
       ;
   }
   __syncthreads();

@@ -251,12 +258,16 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 }

 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
     RankData* _dp,
     RankSignals sg,
 #ifndef USE_ROCM
     volatile
 #endif
     Signal* self_sg,
     T* __restrict__ result,
     int rank,
     int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same

@@ -280,12 +291,16 @@ DINLINE P* get_tmp_buf(volatile Signal* sg) {
 }

 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
     RankData* _dp,
     RankSignals sg,
 #ifndef USE_ROCM
     volatile
 #endif
     Signal* self_sg,
     T* __restrict__ result,
     int rank,
     int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;

@@ -357,8 +372,14 @@ class CustomAllreduce {
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
  CustomAllreduce(
      Signal* meta,
      void* rank_data,
      size_t rank_data_sz,
      const hipIpcMemHandle_t* handles,
      const std::vector<int64_t>& offsets,
      int rank,
      bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),

@@ -382,8 +403,8 @@ class CustomAllreduce {
    auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(hipIpcOpenMemHandle(
          (void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;

@@ -399,13 +420,14 @@ class CustomAllreduce {
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (hipPointerGetAttribute(
              &base_ptr,
#ifdef USE_ROCM
              HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
              CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
              (hipDeviceptr_t)ptr) != hipSuccess)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);

@@ -415,8 +437,8 @@ class CustomAllreduce {
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {

@@ -443,8 +465,8 @@ class CustomAllreduce {
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);

@@ -474,11 +496,17 @@ class CustomAllreduce {
   * will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(
      hipStream_t stream,
      T* input,
      T* output,
      int size,
#ifndef USE_ROCM
      int threads = 512,
      int block_limit = 36) {
#else
      int threads = 512,
      int block_limit = 16) {
#endif
    auto d = packed_t<T>::P::size;
    if (size % d != 0)

@@ -487,8 +515,8 @@ class CustomAllreduce {
          "of " + std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error(
          "max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
    RankData* ptrs;
    hipStreamCaptureStatus status;

@@ -499,17 +527,17 @@ class CustomAllreduce {
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
      ptrs = it->second;
    }
    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = ::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                                                    \
  hipLaunchKernelGGL(                                                                                      \
      (name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
#define REDUCE_CASE(ngpus)  \
  case ngpus: {             \
    if (world_size_ == 2) { \

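The start_sync/end_sync hunks above implement a flag-based cross-GPU barrier: each of the first ngpus threads publishes a monotonically increasing flag into the matching slot of every peer's signal block (one p2p write), then spin-waits until its own slots from all peers reach that flag. A simplified single-flag sketch of the same idea using CUDA system-scope atomics (hypothetical names; the per-block indexing, volatile handling and ROCm scoped atomics of the real code are omitted):

  // Each rank owns one ToySignal in IPC-mapped memory that every peer can write.
  struct ToySignal {
    unsigned int start[8];  // one slot per peer rank; assume at most 8 GPUs
  };

  template <int ngpus>
  __device__ void toy_start_sync(ToySignal* peers[], ToySignal* self, int rank, unsigned int flag) {
    if (threadIdx.x < ngpus) {
      // Publish our flag into peer threadIdx.x's slot for `rank` (1 p2p write).
      atomicExch_system(&peers[threadIdx.x]->start[rank], flag);
      // Spin until peer threadIdx.x has published the same flag to us.
      while (atomicAdd_system(&self->start[threadIdx.x], 0u) < flag) {
      }
    }
    __syncthreads();  // the rest of the block waits for the syncing threads
  }
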
sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu

@@ -118,8 +118,13 @@ inline __device__ int4 add128b(T& a, T& b) {
   return c.packed;
 }

 __inline__ __device__ void multi_gpu_barrier(
     uint32_t** signals,
     uint32_t const flag,
     size_t const local_rank,
     size_t const world_size,
     int const tidx,
     int const bidx) {
   // After this function, at least one block in each GPU has reached the barrier
   if (tidx < world_size) {
     // we can think of signals having the shape [world_size, world_size]

@@ -143,8 +148,14 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
 }

 template <bool start, bool need_fence = false>
 __inline__ __device__ void block_barrier(
     uint32_t** signals,
     uint32_t const flag,
     size_t const local_rank,
     size_t const world_size,
     int const tidx,
     int const bidx,
     int const grid_size) {
   if constexpr (!start) {
     __syncthreads();
   }

@@ -227,8 +238,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
     }
   }
   // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
   block_barrier<true>(
       params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {

@@ -341,8 +352,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
     }
   }
 }
   block_barrier<true>(
       params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {

@@ -372,8 +383,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
     }
   }
   block_barrier<false, true>(
       params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
   // Gather all needed elts from other intra-node ranks
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {

@@ -459,8 +470,12 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
 void dispatchARKernels(
     AllReduceStrategyType algo,
     AllReduceParams& param,
     int blocks_per_grid,
     int threads_per_block,
     cudaStream_t stream) {
   switch (algo) {
     case AllReduceStrategyType::ONESHOT: {
       oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);

@@ -505,8 +520,8 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
   CHECK_CUDA_SUCCESS(cudaGetLastError());
 }

 void trtCustomAllReduce(
     AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
   if (params.elts_total == 0) {
     return;
   }

sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu

@@ -29,9 +29,14 @@ using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
 class AllReduceMeta {
  public:
  AllReduceMeta(
      int64_t rank_id,
      int64_t world_size,
      torch::Tensor& rank_data,
      const std::vector<fptr_t>& buffers,
      const std::vector<fptr_t>& tmp_result_buffers,
      const std::vector<fptr_t>& barrier_in,
      const std::vector<fptr_t>& barrier_out) {
    this->rank_id = (int)rank_id;
    this->world_size = (int)world_size;
    this->barrier_in = barrier_in;

@@ -86,9 +91,14 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
   return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
 }

 fptr_t init_custom_ar(
     int64_t rank_id,
     int64_t world_size,
     torch::Tensor& rank_data,
     const std::vector<fptr_t>& buffers,
     const std::vector<fptr_t>& tmp_result_buffers,
     const std::vector<fptr_t>& barrier_in,
     const std::vector<fptr_t>& barrier_out) {
   auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
   return (fptr_t)m;
 }

@@ -124,8 +134,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
   auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
   if (new_handle) {
     char* ipc_ptr;
     CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
         (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
     it->second = ipc_ptr;
   }
   return it->second;

@@ -138,8 +148,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
 // rank 1 may get the same input address for the second allreduce, but rank 2
 // got a different address. IPC handles have internal reference counting
 // mechanism so overhead should be small.
 void register_graph_buffers(
     fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
   AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
   std::vector<std::string> handle_bytes;
   handle_bytes.reserve(handles.size());

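The CanApplyCustomAllReduce check above requires the element count to fill whole 16-byte (128-bit) packed vectors. Worked out for the common dtypes (a standalone sketch; get_bits is the file's own helper, reproduced here only as arithmetic):

  // 16 / bytes_per_element elements must divide num_elements evenly:
  //   float16 / bfloat16: 16 bits -> 2 bytes -> num_elements % 8 == 0
  //   float32:            32 bits -> 4 bytes -> num_elements % 4 == 0
  constexpr int elems_per_16B(int bits) {
    return 16 / ((bits + 7) / 8);
  }
  static_assert(elems_per_16B(16) == 8 && elems_per_16B(32) == 4, "matches the modulus in CanApplyCustomAllReduce");
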
sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu

@@ -23,15 +23,18 @@ limitations under the License.
 #define THREADS_PER_BLOCK 128

 template <typename T>
 __global__ void lightning_attention_decode_kernel(
     const T* __restrict__ q,            // [b, h, 1, d]
     const T* __restrict__ k,            // [b, h, 1, d]
     const T* __restrict__ v,            // [b, h, 1, e]
     const float* __restrict__ past_kv,  // [b, h, d, e]
     const float* __restrict__ slope,    // [h, 1, 1]
     T* __restrict__ output,             // [b, h, 1, e]
     float* __restrict__ new_kv,         // [b, h, d, e]
     const int batch_size,
     const int num_heads,
     const int qk_dim,
     const int v_dim) {
   extern __shared__ char smem[];
   T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
   T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));

@@ -109,9 +112,14 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
   }
 }

 void lightning_attention_decode(
     const torch::Tensor& q,
     const torch::Tensor& k,
     const torch::Tensor& v,
     const torch::Tensor& past_kv,
     const torch::Tensor& slope,
     torch::Tensor output,
     torch::Tensor new_kv) {
   TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
   TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
   TORCH_CHECK(v.is_contiguous(), "v must be contiguous");

@@ -131,8 +139,16 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
       at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
         size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
         lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
             q.data_ptr<scalar_t>(),
             k.data_ptr<scalar_t>(),
             v.data_ptr<scalar_t>(),
             past_kv.data_ptr<float>(),
             slope.data_ptr<float>(),
             output.data_ptr<scalar_t>(),
             new_kv.data_ptr<float>(),
             batch_size,
             num_heads,
             qk_dim,
             v_dim);
       }));
 }

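The launch above sizes dynamic shared memory as (2*qk_dim + 2*v_dim)*sizeof(scalar_t) + qk_dim*(v_dim + 1)*sizeof(float), presumably the q/k/v/output staging tiles plus the float KV state. A worked example with hypothetical shapes qk_dim = 128, v_dim = 64 and scalar_t = half:

  // (2*128 + 2*64) * sizeof(half)  = 384 * 2  =   768 bytes
  // 128 * (64 + 1) * sizeof(float) = 8320 * 4 = 33280 bytes
  // total dynamic shared memory               = 34048 bytes, under the default 48 KiB per-block limit
  constexpr size_t example_smem_size = (2 * 128 + 2 * 64) * 2 + 128 * (64 + 1) * 4;
  static_assert(example_smem_size == 34048, "34048 bytes of dynamic shared memory for this configuration");
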
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h

@@ -25,9 +25,15 @@ namespace cutlass {
 namespace epilogue {
 namespace threadblock {

 template <
     typename ThreadblockShape_,
     int ThreadCount,
     typename ScaleTileIterator_,
     typename OutputTileIterator_,
     typename ElementAccumulator_,
     typename ElementCompute_,
     typename ElementwiseFunctor_,
     bool UseMasking_ = false>
 class EpilogueVisitorPerRowPerCol {
  public:
  using ThreadblockShape = ThreadblockShape_;

@@ -69,8 +75,11 @@ class EpilogueVisitorPerRowPerCol {
    Arguments(typename ElementwiseFunctor::Params elementwise_)
        : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

    Arguments(
        typename ElementwiseFunctor::Params elementwise_,
        int64_t batch_stride_alpha_,
        int64_t batch_stride_C_,
        int64_t batch_stride_D_)
        : elementwise(elementwise_),
          batch_stride_alpha(batch_stride_alpha_),
          batch_stride_C(batch_stride_C_),

@@ -131,17 +140,26 @@ class EpilogueVisitorPerRowPerCol {
  public:
  CUTLASS_DEVICE
  EpilogueVisitorPerRowPerCol(
      Params const& params,
      SharedStorage& shared_storage,
      cutlass::MatrixCoord const& problem_size,
      int thread_idx,
      int warp_idx,
      int lane_idx,
      typename ScaleTileIterator::Params params_alpha_col,
      typename OutputTileIterator::Params params_C,
      typename OutputTileIterator::Params params_D,
      bool with_bias,
      bool per_token_quant,
      bool per_channel_quant,
      AlphaScaleElementType* ptr_alpha_row,
      AlphaScaleElementType* ptr_alpha_col,
      typename OutputTileIterator::Element* ptr_C,
      typename OutputTileIterator::Element* ptr_D,
      cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
      int column_offset = 0,
      cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
      : params_(params),
        shared_storage_(shared_storage),
        extent_(problem_size),

@@ -166,8 +184,9 @@ class EpilogueVisitorPerRowPerCol {
  /// Helper to indicate split-K behavior
  CUTLASS_DEVICE
  void set_k_partition(
      int split_k_index,    ///< Index of this threadblock within split-K partitioned scheme
      int split_k_slices) { ///< Total number of split-K slices
  }

  /// Called to set the batch index

@@ -251,8 +270,8 @@ class EpilogueVisitorPerRowPerCol {
  private:
  CUTLASS_DEVICE
  ComputeFragment per_token_channel_scale_accumulator_(
      ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ComputeFragment::kElements; ++i) {

@@ -263,8 +282,8 @@ class EpilogueVisitorPerRowPerCol {
  }

  CUTLASS_DEVICE
  ComputeFragment per_token_scale_accumulator_(
      ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ComputeFragment::kElements; ++i) {

sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp

@@ -16,16 +16,20 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelT
 // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
 // specialized dynamic schedule For FP8 kernels with Block Scaling
 template <
     int Stages_,
     class ClusterShape_ = Shape<_1, _1, _1>,
     class KernelSchedule = KernelTmaWarpSpecialized,
     int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
                                // while zero-value `ScaleGranularityM` indicates that scaling
                                // granularity is `size<0>(TileShape_MNK{})` along M.
     >
 struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
     : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
   static_assert(
       cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
       "KernelSchedule must be one of the warp specialized policies");
 };

 //////////////////////////////////////////////////////////////////////////////

sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h

@@ -159,8 +159,9 @@ class GemmUniversalBaseCompat {
    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
    dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

    CUTLASS_TRACE_HOST(
        "  grid_tiled_shape: " << grid_tiled_shape << "\n"
                               << "  result = {" << result << "}");
    return result;
  }

@@ -175,8 +176,8 @@ class GemmUniversalBaseCompat {
    CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");

    if (smem_size <= (48 << 10)) {
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);

      if (result == cudaSuccess) {
        CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);

@@ -184,12 +185,12 @@ class GemmUniversalBaseCompat {
      }
    } else {
      // Query assuming zero shared memory then compute occupancy limit based on SMEM
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);

      if (result != cudaSuccess) {
        CUTLASS_TRACE_HOST(
            "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));

        return -1;
      }

@@ -226,8 +227,9 @@ class GemmUniversalBaseCompat {
  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST(
        "GemmUniversalBaseCompat::initialize() - workspace "
        << workspace << ", stream: " << (stream ? "non-null" : "null"));

    size_t workspace_bytes = get_workspace_size(args);

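The hunks above re-wrap calls to cudaOccupancyMaxActiveBlocksPerMultiprocessor, which reports how many thread blocks of a given kernel can be resident per SM for a chosen block size and dynamic shared-memory budget. A minimal standalone usage sketch (hypothetical toy kernel; the launcher passes Kernel<GemmKernel> and GemmKernel::kThreadCount instead):

  #include <cstdio>
  #include <cuda_runtime.h>

  __global__ void toy_kernel(float* x) {
    x[threadIdx.x] *= 2.0f;
  }

  int main() {
    int max_active_blocks = 0;
    // Block size 256, no dynamic shared memory; same query shape as in get_max_active_blocks above.
    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, toy_kernel, 256, 0);
    if (err != cudaSuccess) {
      std::printf("occupancy query failed: %s\n", cudaGetErrorString(err));
      return 1;
    }
    std::printf("resident blocks per SM: %d\n", max_active_blocks);
    return 0;
  }
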
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h

@@ -32,10 +32,11 @@ namespace kernel {
 /////////////////////////////////////////////////////////////////////////////////////////////////

 template <
     typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
     typename Epilogue_,           ///! Epilogue
     typename ThreadblockSwizzle_  ///! Threadblock swizzling function
     >
 struct GemmWithEpilogueVisitor {
  public:
  using Mma = Mma_;

@@ -119,9 +120,15 @@ struct GemmWithEpilogueVisitor {
    Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}

    /// constructs an arguments structure
    Arguments(
        GemmCoord problem_size_,
        TensorRefA ref_A_,
        TensorRefB ref_B_,
        TensorRefAlphaCol ref_alpha_col_,
        TensorRefAlphaRow ref_alpha_row_,
        TensorRefC ref_C_,
        TensorRefC ref_D_,
        typename EpilogueVisitor::Arguments epilogue_visitor_)
        : mode(GemmUniversalMode::kGemm),
          problem_size(problem_size_),
          batch_count(1),

@@ -269,8 +276,9 @@ struct GemmWithEpilogueVisitor {
      isAMisaligned = problem_size.k() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
      isAMisaligned = problem_size.m() % kAlignmentA;
    } else if (
        platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
        platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    }

@@ -278,8 +286,9 @@ struct GemmWithEpilogueVisitor {
      isBMisaligned = problem_size.n() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    } else if (
        platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
        platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    }

@@ -287,8 +296,9 @@ struct GemmWithEpilogueVisitor {
      isCMisaligned = problem_size.n() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
      isCMisaligned = problem_size.m() % kAlignmentC;
    } else if (
        platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
        platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    }

@@ -373,11 +383,11 @@ struct GemmWithEpilogueVisitor {
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
        params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);

    typename Mma::IteratorB iterator_B(
        params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.

@@ -409,8 +419,8 @@ struct GemmWithEpilogueVisitor {
    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // assume identity swizzle
    MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

@@ -423,11 +433,25 @@ struct GemmWithEpilogueVisitor {
      with_bias = false;
    }

    EpilogueVisitor epilogue_visitor(
        params.epilogue_visitor,
        shared_storage.epilogue.visitor,
        params.problem_size.mn(),
        thread_idx,
        warp_idx,
        lane_idx,
        params.params_alpha_col,
        params.params_C,
        params.params_D,
        with_bias,
        true,
        true,
        params.ptr_alpha_row,
        params.ptr_alpha_col,
        params.ptr_C,
        params.ptr_D,
        threadblock_offset,
        blockIdx.y * params.problem_size.m());

    if (params.mode == GemmUniversalMode::kGemm) {
      // Indicate which position in a serial reduction the output operator is currently updating

sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
View file @
d052f4c8
...
@@ -21,10 +21,13 @@
...
@@ -21,10 +21,13 @@
#include "utils.h"
#include "utils.h"
static
void
check_group_count
(
const
std
::
vector
<
torch
::
Tensor
>&
inputs
,
const
std
::
vector
<
torch
::
Tensor
>&
weights
,
static
void
check_group_count
(
const
std
::
vector
<
torch
::
Tensor
>&
outputs
)
{
const
std
::
vector
<
torch
::
Tensor
>&
inputs
,
TORCH_CHECK
(((
inputs
.
size
()
==
weights
.
size
())
&&
(
inputs
.
size
()
==
outputs
.
size
())),
const
std
::
vector
<
torch
::
Tensor
>&
weights
,
"The group count of inputs, weights and outputs should be the same."
);
const
std
::
vector
<
torch
::
Tensor
>&
outputs
)
{
TORCH_CHECK
(
((
inputs
.
size
()
==
weights
.
size
())
&&
(
inputs
.
size
()
==
outputs
.
size
())),
"The group count of inputs, weights and outputs should be the same."
);
}
}
static
void
check_device_dtype
(
const
torch
::
Dtype
&
dtype
,
const
std
::
vector
<
torch
::
Tensor
>&
tensors
)
{
static
void
check_device_dtype
(
const
torch
::
Dtype
&
dtype
,
const
std
::
vector
<
torch
::
Tensor
>&
tensors
)
{
...
@@ -68,21 +71,26 @@ static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tens
...
@@ -68,21 +71,26 @@ static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tens
static
torch
::
Tensor
create_ptr_pointer
(
const
std
::
vector
<
void
*>&
ptrs
,
cudaStream_t
stream
)
{
static
torch
::
Tensor
create_ptr_pointer
(
const
std
::
vector
<
void
*>&
ptrs
,
cudaStream_t
stream
)
{
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kDouble
).
device
(
torch
::
kCUDA
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kDouble
).
device
(
torch
::
kCUDA
);
torch
::
Tensor
gpu_ptrs
=
torch
::
empty
({
static_cast
<
int
>
(
ptrs
.
size
())},
options
);
torch
::
Tensor
gpu_ptrs
=
torch
::
empty
({
static_cast
<
int
>
(
ptrs
.
size
())},
options
);
TORCH_CHECK
(
cudaMemcpyAsync
(
gpu_ptrs
.
data_ptr
(),
ptrs
.
data
(),
sizeof
(
void
*
)
*
ptrs
.
size
(),
cudaMemcpyHostToDevice
,
TORCH_CHECK
(
stream
)
==
CUBLAS_STATUS_SUCCESS
);
cudaMemcpyAsync
(
gpu_ptrs
.
data_ptr
(),
ptrs
.
data
(),
sizeof
(
void
*
)
*
ptrs
.
size
(),
cudaMemcpyHostToDevice
,
stream
)
==
CUBLAS_STATUS_SUCCESS
);
return
gpu_ptrs
;
return
gpu_ptrs
;
}
}
// We want compute input @ weight^T in row major
// This is equivalent to computing weight @ input^T in col major
// Cublas only accepts matrix in column major, so this arrangement is needed
void cublas_grouped_gemm(
    const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
    const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
    const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
    const torch::Dtype& out_dtype,
    int64_t cublas_handle,
    int64_t cuda_stream) {
  TORCH_CHECK(
      out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
      "cublas grouped_gemm can"
      "only be applied to float16 and bfloat16 dtype");

  int group_count = inputs.size();
  check_group_count(inputs, weights, outputs);
...
@@ -133,16 +141,32 @@ void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, // b: (m, k
  torch::Tensor d_c = create_ptr_pointer(c_array, stream);

#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  auto status = cublasGemmGroupedBatchedEx(
      handle,
      transa_array.data(),
      transb_array.data(),
      m_array.data(),
      n_array.data(),
      k_array.data(),
      alpha_array.data(),
      (void**)d_a.data_ptr(),
      cuda_data_type,
      lda_array.data(),
      (void**)d_b.data_ptr(),
      cuda_data_type,
      ldb_array.data(),
      beta_array.data(),
      (void**)d_c.data_ptr(),
      cuda_data_type,
      ldc_array.data(),
      group_count,
      group_size.data(),
      CUBLAS_COMPUTE_32F);
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
  TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
  return;
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
}
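The layout comments above carry the key identity behind the argument order: a row-major (r, c) buffer read with column-major indexing is the (c, r) transpose, so out = input @ weight^T (all row major) is the same arithmetic as c = a^T * b handed to a column-major BLAS with a = weight buffer and b = input buffer. The following is a minimal CPU check of that identity with made-up sizes; it only illustrates the convention and is not code from this file.

#include <cassert>
#include <vector>

int main() {
  const int m = 2, n = 3, k = 4;
  std::vector<float> input(m * k), weight(n * k), out_row(m * n), out_blas(m * n);
  for (int i = 0; i < m * k; ++i) input[i] = float(i + 1);
  for (int i = 0; i < n * k; ++i) weight[i] = float(2 * i - 3);

  // Row-major reference: out[i][j] = sum_k input[i][k] * weight[j][k].
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int kk = 0; kk < k; ++kk) acc += input[i * k + kk] * weight[j * k + kk];
      out_row[i * n + j] = acc;
    }

  // Column-major view: c(n x m) = a^T(n x k) * b(k x m), reusing the same
  // buffers with column-major indexing (ld_a = k, ld_b = k, ld_c = n).
  for (int j = 0; j < n; ++j)      // row index of c in the column-major view
    for (int i = 0; i < m; ++i) {  // column index of c in the column-major view
      float acc = 0.f;
      for (int kk = 0; kk < k; ++kk) acc += weight[j * k + kk] * input[i * k + kk];
      out_blas[i * n + j] = acc;  // element (j, i) of the (n, m) column-major c
    }

  for (int i = 0; i < m * n; ++i) assert(out_row[i] == out_blas[i]);
  return 0;
}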
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
View file @ d052f4c8
...
@@ -35,8 +35,12 @@
using namespace cute;

template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;
...
@@ -66,19 +70,43 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      EpilogueTileType,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      EpilogueSchedule,
      StoreEpilogueCompute>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutA,
      AlignmentA,
      ElementB,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;
...
@@ -127,16 +155,23 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
}

template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _1, _1>;
  launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
}

torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
...
@@ -145,10 +180,10 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
...
@@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
#endif
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}
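The checks above require mat_a to be a row-major CUDA Float8_e4m3fn tensor and mat_b to be column major (stride(0) == 1) in the same dtype. A caller typically satisfies the layout requirement with a transposed view rather than a copy. The sketch below is a hypothetical caller-side helper, assuming a PyTorch build with Float8_e4m3fn conversion support; the helper name, shapes, and random initialization are illustrative and not part of the kernel.

#include <torch/torch.h>

// Hypothetical preparation of a (k, n) column-major fp8 weight for the
// fp8_blockwise_scaled_mm checks above.
torch::Tensor make_column_major_fp8_weight(int64_t n, int64_t k) {
  // Start from a row-major (n, k) weight, cast it to Float8_e4m3fn, then take
  // a transposed view: the result is a (k, n) tensor with stride(0) == 1,
  // i.e. column major, without copying any data.
  auto w = torch::randn({n, k}, torch::dtype(torch::kFloat).device(torch::kCUDA))
               .to(torch::kFloat8_e4m3fn);
  auto mat_b = w.t();
  TORCH_CHECK(mat_b.stride(0) == 1);                          // column major, as required
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn);  // fp8 e4m3, as required
  return mat_b;
}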
sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
View file @ d052f4c8
...
@@ -53,10 +53,17 @@ limitations under the License.
using namespace cute;

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CtaShape,
    typename WarpShape,
    int Stages,
    bool WithBias,
    typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
    template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
    typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
struct DeviceGemmFp8RowwiseSm89 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
...
@@ -85,56 +92,86 @@ struct DeviceGemmFp8RowwiseSm89 {
  // Number of epilogue stages in EVT
  static constexpr int EVTEpilogueStages = 1;

  using OutputTileThreadMap = cutlass::epilogue::threadblock::
      OutputTileThreadLayout<CtaShape, WarpShape, ElementC, AlignmentC, EVTEpilogueStages>;

  // Definition of EVT
  using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;

  using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using bScaleSrc = cutlass::epilogue::threadblock::
      VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_0, _1, _0>>;
  using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>;

  using ComputeAScale = cutlass::epilogue::threadblock::
      VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
  using aScaleSrc = cutlass::epilogue::threadblock::
      VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_1, _0, _0>>;
  using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>;

  // With bias
  using biasSrc =
      cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>;
  using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementC, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
  using EpilogueAScaleWithBias =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;

  using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap,
      ElementC,
      cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, _1, _0>>;

  using EpilogueStore = typename cutlass::platform::conditional<
      WithBias,
      cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
      cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;

  using EpilogueOp = EpilogueStore;

  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementA,
      LayoutA,
      cutlass::ComplexTransform::kNone,
      AlignmentA,
      ElementB,
      LayoutB,
      cutlass::ComplexTransform::kNone,
      AlignmentB,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementAccumulator,
      ElementComputeEpilogue,
      OperatorClass,
      ArchTag,
      CtaShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      ThreadblockSwizzle,
      Stages,
      FP8MathOperator,
      EVTEpilogueStages>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm89_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
...
@@ -158,54 +195,61 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

  typename Gemm::Arguments args(
      cutlass::gemm::GemmUniversalMode::kGemm,  // Mode
      {m, n, k},                                // Problem size
      1,                                        // Split-k factor
      {},                                       // Epilogue args
      ptr_a,                                    // a pointer
      ptr_b,                                    // b pointer
      nullptr,                                  // c pointer (unused)
      nullptr,                                  // d pointer (unused)
      m * k,                                    // batch stride a (unused)
      n * k,                                    // batch stride b (unused)
      m * n,                                    // batch stride c (unused)
      m * n,                                    // batch stride d (unused)
      lda,                                      // stride a
      ldb,                                      // stride b
      ldc,                                      // stride c (unused)
      ldc);                                     // stride d (unused)

  if constexpr (WithBias) {
    args.epilogue = {
        {
            {
                {},  // Accumulator
                {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
                {}  // Multiplies
            },
            {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
            {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
            {}  // Multiplies
        },
        {ptr_d, {n, _1{}, _0{}}}};
  } else {
    args.epilogue = {
        {
            {
                {},  // Accumulator
                {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
                {}  // Multiplies
            },
            {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
            {}  // Multiplies
        },
        {ptr_d, {n, _1{}, _0{}}}};
  }

  return args;
}

template <typename Gemm, bool WithBias>
void launch_sm89_fp8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
  Gemm gemm_op;
...
@@ -222,109 +266,187 @@ void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
}

template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
void sm89_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;
  if (bias) {
    using Gemm = typename DeviceGemmFp8RowwiseSm89<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CtaShape,
        WarpShape,
        Stages,
        true>::Gemm;
    return launch_sm89_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    using Gemm = typename DeviceGemmFp8RowwiseSm89<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CtaShape,
        WarpShape,
        Stages,
        false>::Gemm;
    return launch_sm89_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
  }
}

template <typename OutType>
void sm89_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  uint32_t const m = a.size(0);
  uint32_t const n = out.size(1);

  if (m == 1) {
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 16) {
    // M in (1, 16]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          4>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    // M in (16, 64]
    if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 128) {
    // M in (64, 128]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          4>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 256) {
    // M in (128, 256]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 64, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<64, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 64, 128>,
          cutlass::gemm::GemmShape<64, 32, 128>,
          4>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 512) {
    // M in (256, 512)
    if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          2>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          4>(out, a, b, scales_a, scales_b, bias);
    }
  } else {
    // M in (512, inf)
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          3>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          2>(out, a, b, scales_a, scales_b, bias);
    }
  }
}
#endif
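For reference, the EVT chain assembled above multiplies the accumulator by a per-column scales_b broadcast, then by a per-row scales_a broadcast, and optionally adds a per-column bias via multiply_add. The function below is a plain CPU sketch of that arithmetic with hypothetical names; it mirrors the math only and is not a replacement for the kernel.

#include <cstdint>
#include <optional>
#include <vector>

// CPU sketch of the rowwise-scaled epilogue used by the SM89/SM90 fp8 paths:
//   D[i][j] = scales_a[i] * (scales_b[j] * acc[i][j]) (+ bias[j] when present),
// where acc is the fp32 accumulator of the fp8 GEMM.
std::vector<float> rowwise_scaled_epilogue(
    const std::vector<float>& acc,       // m x n accumulator, row major
    const std::vector<float>& scales_a,  // length m: one scale per row of A
    const std::vector<float>& scales_b,  // length n: one scale per column of the output
    const std::optional<std::vector<float>>& bias,  // optional, length n
    int64_t m,
    int64_t n) {
  std::vector<float> out(m * n);
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float v = scales_a[i] * (scales_b[j] * acc[i * n + j]);
      if (bias) v += (*bias)[j];
      out[i * n + j] = v;
    }
  }
  return out;
}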
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename EpilogueScheduleType,
    typename TileSchedulerType = void,
    bool WithBias = false>
struct DeviceGemmFp8RowwiseSm90 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
...
@@ -374,44 +496,70 @@ struct DeviceGemmFp8RowwiseSm90 {
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;  // Kernel to launch based on the default
                                                                         // setting in the Collective Builder

  // Implement rowwise scaling epilogue.
  using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementOutput,
      ElementOutput,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementComputeEpilogue,  // First stage output type.
      ElementComputeEpilogue,  // First stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementOutput,
      ElementComputeEpilogue,  // Second stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

  // With bias
  using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementOutput, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90,
      cutlass::arch::OpClassTensorOp,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementComputeEpilogue,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementOutput,
      LayoutOutput,
      AlignmentOutput,
      cutlass::epilogue::TmaWarpSpecialized,
      EpilogueEVT>::CollectiveOp;

  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
...
@@ -423,22 +571,38 @@ struct DeviceGemmFp8RowwiseSm90 {
  using FastAccum = FastPongSchedule;  // Default apply Pingpong

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutA,
      AlignmentA,
      ElementB,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      TileSchedulerType>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm90_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
...
@@ -465,14 +629,15 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {ptr_a, stride_a, ptr_b, stride_b},
      {{},  // epilogue.thread
       nullptr,
       stride_c,
       ptr_d,
       stride_d}};

  if constexpr (WithBias) {
    args.epilogue.thread = {
        {ptr_scales_a},
...
@@ -500,9 +665,13 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::
}

template <typename Gemm, bool WithBias>
void launch_sm90_fp8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
  Gemm gemm_op;
...
@@ -519,66 +688,117 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
  TORCH_CHECK(status == cutlass::Status::kSuccess)
}

template <
    typename OutType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename TileSchedulerType>
void sm90_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias,
    bool fast_accum = true,
    bool use_persistent = false) {
  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;
  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
  if (bias) {
    using Gemm = typename DeviceGemmFp8RowwiseSm90<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpilogueScheduleType,
        TileSchedulerType,
        true>::Gemm;
    return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    using Gemm = typename DeviceGemmFp8RowwiseSm90<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpilogueScheduleType,
        TileSchedulerType,
        false>::Gemm;
    return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
  }
}

template <typename OutType>
void sm90_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  uint32_t const m = a.size(0);
  using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
  using BasicTileScheduler = void;
  if (m <= 1) {
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_64, _64, _128>,
        Shape<_1, _8, _1>,
        FastBasicScheduler,
        BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
  }
  if (m <= 64) {
    // m in [1, 64]
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_64, _64, _128>,
        Shape<_1, _4, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else if (m <= 256) {
    // m in (64, 256]
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_64, _64, _128>,
        Shape<_1, _1, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else if (m <= 1024) {
    // m in (256, 1024]
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_128, _128, _128>,
        Shape<_1, _1, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else {
    // m in (1024, inf)
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_128, _128, _128>,
        Shape<_2, _1, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  }
}
#endif

torch::Tensor fp8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
...
@@ -587,10 +807,10 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
...
sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
View file @ d052f4c8
...
@@ -35,11 +35,20 @@ limitations under the License.
using namespace cute;

template <
    typename ElementOutput,
    typename ArchTag,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    int NumStages>
void cutlass_int8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementAccumulator = int32_t;
  using ElementCompute = float;
  using ElementInputA = int8_t;
...
@@ -48,30 +57,51 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

  using DefaultGemmConf = cutlass::gemm::device::
      DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA, ElementInputB, ElementOutput, ElementCompute>;
  using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;

  using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
      ElementInputA,
      cutlass::layout::RowMajor,
      DefaultGemmConf::kAlignmentA,
      ElementInputB,
      cutlass::layout::ColumnMajor,
      DefaultGemmConf::kAlignmentB,
      ElementOutput,
      cutlass::layout::RowMajor,
      ElementAccumulator,
      OperatorClass,
      ArchTag,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      EpilogueOutputOp,
      ThreadblockSwizzle,
      NumStages,
      true,
      typename DefaultGemmConf::Operator>::GemmKernel;

  using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
      cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
          GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
          GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess,
          cutlass::sizeof_bits<ElementOutput>::value>,
      ElementCompute>;

  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
      ThreadblockShape,
      GemmKernel_::kThreadCount,
      AlphaColTileIterator,
      typename GemmKernel_::Epilogue::OutputTileIterator,
      ElementAccumulator,
      ElementCompute,
      EpilogueOutputOp>;

  using Epilogue = typename cutlass::epilogue::threadblock::
      EpilogueWithVisitorFromExistingEpilogue<EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;

  using GemmKernel =
      cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;
...
@@ -104,98 +134,164 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
  typename EpilogueOutputOp::Params linearScalingParams;
  typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};

  typename Gemm::Arguments args{
      {m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};

  auto workspace = torch::empty(
      gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(
      status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status));
}
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm75_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  if (m <= 32) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<32, 128, 64>,
        cutlass::gemm::GemmShape<32, 64, 64>,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 64) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<64, 128, 128>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 256) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<128, 128, 128>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<128, 128, 64>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  int n = mat_b.size(1);
  if (m <= 16) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          InstructionShape,
          6>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 32) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          6>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 128, 128>,
          cutlass::gemm::GemmShape<64, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 128 && n < 8192) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<64, 128, 128>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        5>(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<128, 128, 64>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        5>(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType, bool WithBias>
void cutlass_int8_scaled_mm_sm90(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ArchTag = cutlass::arch::Sm90;
  using ElementAccumulator = int32_t;
...
@@ -213,50 +309,75 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a,
  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
  using TileSchedulerType = cutlass::gemm::PersistentScheduler;

  using XScale = cutlass::epilogue::fusion::
      Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<1>, Int<0>, Int<0>>>;
  using WScale = cutlass::epilogue::fusion::
      Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<0>, Int<1>, Int<0>>>;
  using Bias = cutlass::epilogue::fusion::
      Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, Stride<Int<0>, Int<1>, Int<0>>>;
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  // Scale
  using Compute0 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
  using Compute1 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

  // With bias
  using ComputeWithBias = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiply_add, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementOutput,
      cutlass::layout::RowMajor,
      AlignmentC,
      ElementOutput,
      cutlass::layout::RowMajor,
      AlignmentOutput,
      EpilogueScheduleType,
      EpilogueEVT>::CollectiveOp;

  using Stages = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
      sizeof(typename CollectiveEpilogue::SharedStorage))>;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementInputA,
      cutlass::layout::RowMajor,
      AlignmentA,
      ElementInputB,
      cutlass::layout::ColumnMajor,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      Stages,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      TileSchedulerType>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
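Read as a scalar formula, the epilogue visitor tree above composes two broadcasts with the accumulator fetch. A hedged sketch of the element-wise math it encodes (symbols are illustrative, not identifiers from the commit):

// Scalar view of the SM90 epilogue visitor tree built above (assumption: this is the
// intended math; x_scale comes from scales_a, w_scale from scales_b).
//   EVTCompute0 : tmp       = w_scale[j] * acc[i][j]     (WScale is a row broadcast, per column j)
//   EVTCompute1 : out[i][j] = x_scale[i] * tmp           (XScale is a column broadcast, per row i)
//   With bias   : out[i][j] = x_scale[i] * tmp + bias[j] (multiply_add variant)
#include <cstdint>

inline float sm90_epilogue_reference(int32_t acc, float x_scale_i, float w_scale_j, float bias_j, bool with_bias) {
  float tmp = w_scale_j * static_cast<float>(acc);
  float out = x_scale_i * tmp;
  return with_bias ? out + bias_j : out;
}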
...
@@ -283,14 +404,15 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a,
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {a_ptr, stride_a, b_ptr, stride_b},
      {{},  // epilogue.thread
       nullptr,
       stride_c,
       o_ptr,
       stride_d}};

  if constexpr (WithBias) {
    ElementOutput* bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
...
@@ -308,23 +430,29 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a,
    };
  }

  auto workspace = torch::empty(
      gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status));
}
template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType>
void sm90_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  if (bias) {
    cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, true>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
...
@@ -335,45 +463,73 @@ void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const to
}
template <typename ElementOutput>
void sm90_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  int n = mat_b.size(1);
  if (m <= 32) {
    if (n < 8192) {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _128>,
          Shape<_1, _8, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _128, _128>,
          Shape<_1, _8, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    if (n < 8192) {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _128>,
          Shape<_1, _4, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _256>,
          Shape<_1, _1, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 128) {
    if (n <= 4096) {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _128>,
          Shape<_2, _1, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _128, _128>,
          Shape<_2, _1, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else {
    return sm90_dispatch_bias<
        ElementOutput,
        Shape<_128, _128, _128>,
        Shape<_2, _1, _1>,
        cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
torch::Tensor int8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
...
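For orientation, a hedged sketch of how this entry point might be called from C++ once the checks above pass. The scale shapes, the column-major storage of mat_b, and the bf16 output dtype are assumptions for illustration, not requirements read from this hunk.

// Hypothetical call site (sketch, not part of the commit): C = int8_scaled_mm(A, B, sa, sb, out_dtype, bias).
torch::Tensor run_int8_mm_example() {
  auto cuda_f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  // Placeholder quantized operands; real code would quantize bf16/fp16 activations and weights.
  auto mat_a = (torch::rand({64, 256}, cuda_f32) * 127).to(torch::kInt8);        // [m, k], row-major
  auto mat_b = (torch::rand({512, 256}, cuda_f32) * 127).to(torch::kInt8).t();   // [k, n] view backed by column-major storage (assumption)
  auto scales_a = torch::rand({64, 1}, cuda_f32);   // per-row dequant scales (shape is an assumption)
  auto scales_b = torch::rand({1, 512}, cuda_f32);  // per-column dequant scales (shape is an assumption)
  return int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, torch::kBFloat16, c10::nullopt);
}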
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
View file @ d052f4c8
...
@@ -8,8 +8,8 @@
#include "utils.h"

template <typename T>
__global__ void per_tensor_absmax_kernel(
    const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
  float max_value = 0.0f;
  unsigned int tid = threadIdx.x;
  unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
...
@@ -56,8 +56,11 @@ __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __r
}

template <typename T>
__global__ void per_tensor_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output,
    const float* __restrict__ scale,
    const int64_t num_elements) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;
  const float scale_val = 1.0f / (*scale);
...
@@ -124,8 +127,10 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
    }
    per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        num_elements);
    return true;
  });
}
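The two kernels above implement the usual two-pass per-tensor FP8 quantization: an absmax reduction first, then an element-wise scale and cast. A CPU reference of the same math for comparison; the E4M3 maximum of 448 and the names below are assumptions, not values read from this file.

#include <algorithm>
#include <cmath>
#include <vector>

// CPU reference for per-tensor FP8 quantization (sketch; kFp8Max = 448.f is an assumption).
void per_tensor_quant_fp8_reference(const std::vector<float>& x, std::vector<float>& q, float& scale) {
  constexpr float kFp8Max = 448.f;
  float amax = 0.f;
  for (float v : x) amax = std::max(amax, std::fabs(v));  // pass 1: absmax reduction
  scale = amax / kFp8Max;                                 // per-tensor scale written to output_s
  const float inv = scale > 0.f ? 1.f / scale : 0.f;
  q.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i)                   // pass 2: scale and clamp to the FP8 range
    q[i] = std::min(std::max(x[i] * inv, -kFp8Max), kFp8Max);  // the real kernel casts to FP8_TYPE here
}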
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
View file @ d052f4c8
...
@@ -17,10 +17,15 @@ __device__ __forceinline__ float GroupReduce(float val, const int tid) {
}

template <typename T>
__global__ void per_token_group_quant_fp8_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const float eps,
    const float fp8_min,
    const float fp8_max) {
  const int groups_per_block = 16;
  const int local_group_id = threadIdx.x / 16;
  const int lane_id = threadIdx.x % 16;
...
@@ -80,8 +85,14 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
  }
}

void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);
...
@@ -97,8 +108,14 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q,
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        output_q.data_ptr(),
        static_cast<float*>(output_s.data_ptr()),
        group_size,
        num_groups,
        (float)eps,
        (float)fp8_min,
        (float)fp8_max);
    return true;
  });
}
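Each run of group_size contiguous values gets its own scale here, with 16 groups handled per block. A scalar reference of the per-group math under the kernel's parameter names; this is a sketch of the intended semantics, not the kernel itself.

#include <algorithm>
#include <cmath>
#include <vector>

// CPU reference for per-token-group FP8 quantization (sketch; the eps clamp and fp8_min/fp8_max
// bounds follow the kernel's parameter list, the exact rounding is left to the FP8 cast).
void per_token_group_quant_reference(
    const std::vector<float>& x, int group_size, float eps, float fp8_min, float fp8_max,
    std::vector<float>& q, std::vector<float>& scales) {
  const int num_groups = static_cast<int>(x.size()) / group_size;
  q.resize(x.size());
  scales.resize(num_groups);
  for (int g = 0; g < num_groups; ++g) {
    float amax = 0.f;
    for (int i = 0; i < group_size; ++i) amax = std::max(amax, std::fabs(x[g * group_size + i]));
    const float scale = std::max(amax, eps) / fp8_max;  // one scale per group, clamped by eps
    scales[g] = scale;
    for (int i = 0; i < group_size; ++i) {
      float v = x[g * group_size + i] / scale;
      q[g * group_size + i] = std::min(std::max(v, fp8_min), fp8_max);  // the kernel casts to FP8 here
    }
  }
}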
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
View file @ d052f4c8
...
@@ -7,9 +7,12 @@
#include "utils.h"

template <typename T>
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int token_idx = blockIdx.x;
  if (token_idx >= num_tokens) return;
...
@@ -110,8 +113,11 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        hidden_dim,
        num_tokens);
    return true;
  });
}
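Per-token quantization is the one-scale-per-row special case of the grouped scheme above, with one block per token and the reduction running over hidden_dim. As a comment-level observation (not code from this file):

// Observation (sketch): per-token FP8 quantization behaves like the grouped variant with one
// group per token, i.e. for each token t:
//   scale[t]        = absmax(input[t, 0:hidden_dim]) / FP8_MAX
//   output_q[t, :]  = clamp(input[t, :] / scale[t]) cast to FP8_TYPE, output_s[t] = scale[t]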
sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu
View file @ d052f4c8
...
@@ -25,9 +25,11 @@ limitations under the License.
#define WARP_SIZE 32

template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ cumsum_buffer,
    size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
...
@@ -39,10 +41,15 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__
}

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad,
    int32_t num_experts,
    int32_t block_size,
    size_t numel,
    int32_t* __restrict__ cumsum) {
  __shared__ int32_t shared_counts[WARP_SIZE][8];
  const int warp_id = threadIdx.x / WARP_SIZE;
...
@@ -91,17 +98,29 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
  }
}

void moe_align_block_size(
    torch::Tensor topk_ids,
    int64_t num_experts,
    int64_t block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor token_cnts_buffer,
    torch::Tensor cumsum_buffer) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");

  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
    align_kernel<<<1, 1024, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
        block_size,
        topk_ids.numel(),
        cumsum_buffer.data_ptr<int32_t>());

    const int block_threads = 256;
    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
...
@@ -109,8 +128,10 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
    const int actual_blocks = std::min(num_blocks, max_blocks);

    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        cumsum_buffer.data_ptr<int32_t>(),
        topk_ids.numel());
  });
}
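moe_align_block_size pads each expert's token count up to a multiple of block_size and emits a flat sorted_token_ids array plus the expert id of every output block, which the grouped expert GEMM downstream consumes. A CPU sketch of that layout under the same argument names; the sentinel value used for padded slots is an assumption, and the GPU kernels fuse these passes.

#include <cstdint>
#include <vector>

// CPU sketch of the layout moe_align_block_size produces (assumption: padded slots hold
// topk_ids.size() as an "invalid token" sentinel).
void moe_align_block_size_reference(
    const std::vector<int32_t>& topk_ids,
    int num_experts,
    int block_size,
    std::vector<int32_t>& sorted_token_ids,
    std::vector<int32_t>& expert_ids,
    int32_t& total_tokens_post_pad) {
  std::vector<int32_t> counts(num_experts, 0);
  for (int32_t e : topk_ids) counts[e]++;  // tokens routed to each expert
  std::vector<int32_t> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e) {
    int padded = (counts[e] + block_size - 1) / block_size * block_size;     // pad each expert to a block multiple
    cumsum[e + 1] = cumsum[e] + padded;
    for (int b = 0; b < padded / block_size; ++b) expert_ids.push_back(e);   // expert id of every output block
  }
  total_tokens_post_pad = cumsum[num_experts];
  sorted_token_ids.assign(total_tokens_post_pad, static_cast<int32_t>(topk_ids.size()));  // sentinel fill
  std::vector<int32_t> offset(cumsum.begin(), cumsum.end() - 1);
  for (int32_t i = 0; i < static_cast<int32_t>(topk_ids.size()); ++i) {
    sorted_token_ids[offset[topk_ids[i]]++] = i;  // scatter each token index into its expert's region
  }
}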
sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu
View file @ d052f4c8
...
@@ -23,10 +23,18 @@
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
__global__ void build_tree_efficient(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int64_t* retrive_next_token,
    int64_t* retrive_next_sibling,
    int topk,
    int depth,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;
...
@@ -99,10 +107,18 @@ __global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_ind
  }
}

void build_tree_kernel_efficient(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // TODO (ying) check shape
  // TODO (ying) check type
  int bs = parent_list.size(0);
...
@@ -111,11 +127,17 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  build_tree_efficient<<<grid, block, 0, stream>>>(
      static_cast<int64_t*>(parent_list.data_ptr()),
      static_cast<int64_t*>(selected_index.data_ptr()),
      static_cast<int32_t*>(verified_seq_len.data_ptr()),
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      static_cast<int64_t*>(retrive_next_token.data_ptr()),
      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));
}

// parent_list [bs, topk * (depth - 1) + 1)]
...
@@ -124,8 +146,16 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
__global__ void build_tree(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int topk,
    int depth,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;
...
@@ -191,9 +221,16 @@ __global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_
  }
}

void build_tree_kernel(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // TODO (ying) check shape
  // TODO (ying) check type
  int bs = parent_list.size(0);
...
@@ -202,8 +239,13 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  build_tree<<<grid, block, 0, stream>>>(
      static_cast<int64_t*>(parent_list.data_ptr()),
      static_cast<int64_t*>(selected_index.data_ptr()),
      static_cast<int32_t*>(verified_seq_len.data_ptr()),
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));
}
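For orientation, a hypothetical call site for the host wrapper above. Tensor shapes follow the comments in this file and dtypes follow the static_casts in the launch; the concrete bs/topk/depth/draft_token_num values, the selected_index shape, and the zero-filled inputs are made-up placeholders, not values from the commit.

// Hypothetical usage sketch for build_tree_kernel_efficient (shapes per the comments above).
void build_tree_example() {
  const int64_t bs = 2, topk = 4, depth = 5, draft_token_num = 8;
  auto cuda_i64 = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);
  auto cuda_i32 = at::TensorOptions().dtype(at::kInt).device(at::kCUDA);
  auto cuda_bool = at::TensorOptions().dtype(at::kBool).device(at::kCUDA);

  auto parent_list = at::zeros({bs, topk * (depth - 1) + 1}, cuda_i64);
  auto selected_index = at::zeros({bs, draft_token_num}, cuda_i64);  // shape is an assumption
  auto verified_seq_len = at::full({bs}, 16, cuda_i32);
  // tree_mask size = sum(verified_seq_len)*draft_token + bs*draft_token*draft_token (per the comment above).
  int64_t mask_len = (verified_seq_len.sum().item<int64_t>() + bs * draft_token_num) * draft_token_num;
  auto tree_mask = at::zeros({mask_len}, cuda_bool);
  auto positions = at::zeros({bs * draft_token_num}, cuda_i64);
  auto retrive_index = at::zeros({bs, draft_token_num}, cuda_i64);
  auto retrive_next_token = at::zeros({bs, draft_token_num}, cuda_i64);
  auto retrive_next_sibling = at::zeros({bs, draft_token_num}, cuda_i64);

  build_tree_kernel_efficient(
      parent_list, selected_index, verified_seq_len, tree_mask, positions, retrive_index,
      retrive_next_token, retrive_next_sibling, topk, depth, draft_token_num);
}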