sglang · Commits · 0f3eb1d2

Commit 0f3eb1d2 (unverified), authored Jan 06, 2025 by Ke Bao; committed by GitHub on Jan 06, 2025.
Parent: 06dd2eab

Support cutlass Int8 gemm (#2752)

Showing 12 changed files with 1434 additions and 0 deletions (+1434, -0)
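For context, the new op is exercised from Python as in the test and benchmark files added by this commit (sgl-kernel/tests/test_int8_gemm.py, sgl-kernel/benchmark/bench_int8_gemm.py). A minimal sketch mirroring those files; the shapes and the *5 scaling are illustrative only:

import torch
from sgl_kernel import int8_scaled_mm

def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    # Same quantization helper used by the test/benchmark in this commit.
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)

M, N, K = 128, 4096, 8192
a = to_int8(torch.randn((M, K), device="cuda") * 5)      # row-major int8 activations
b = to_int8(torch.randn((N, K), device="cuda").t() * 5)  # column-major int8 weights
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)  # per-row (per-token) scales
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)  # per-column (per-channel) scales
bias = torch.randn((N,), device="cuda", dtype=torch.float16)

out = int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias)  # (M, N) fp16 result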
sgl-kernel/CMakeLists.txt (+1, -0)
sgl-kernel/benchmark/bench_int8_gemm.py (+55, -0)
sgl-kernel/setup.py (+2, -0)
sgl-kernel/src/sgl-kernel/__init__.py (+2, -0)
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h (+278, -0)
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h (+346, -0)
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h (+456, -0)
sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu (+209, -0)
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu (+7, -0)
sgl-kernel/src/sgl-kernel/csrc/utils.hpp (+10, -0)
sgl-kernel/src/sgl-kernel/ops/__init__.py (+12, -0)
sgl-kernel/tests/test_int8_gemm.py (+56, -0)
sgl-kernel/CMakeLists.txt
@@ -31,6 +31,7 @@ add_library(_kernels SHARED
   src/sgl-kernel/csrc/trt_reduce_internal.cu
   src/sgl-kernel/csrc/trt_reduce_kernel.cu
   src/sgl-kernel/csrc/moe_align_kernel.cu
+  src/sgl-kernel/csrc/int8_gemm_kernel.cu
   src/sgl-kernel/csrc/sgl_kernel_ops.cu
 )
sgl-kernel/benchmark/bench_int8_gemm.py (new file, mode 100644)
import torch
import triton
from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="int8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider):
    M, N, K = batch_size, 4096, 8192
    a = to_int8(torch.randn((M, K), device="cuda") * 5)
    b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    bias = torch.randn((N,), device="cuda", dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )
    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )
    gbps = lambda ms: (
        ((2 * M * N * K - M * N) * a.element_size() + (3 * M * N) * scale_a.element_size())
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True, save_path="bench_int8_res")
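The GB/s number reported by the benchmark comes from the gbps lambda above: the int8 multiply-add work term plus three M*N float traffic terms, divided by the measured latency. A tiny worked sketch of that arithmetic for one data point; the 1.25 ms latency is a made-up placeholder, not a measurement:

# Reproduces the gbps formula from bench_int8_gemm.py for one assumed data point.
M, N, K = 2048, 4096, 8192
int8_bytes, fp32_bytes = 1, 4
ms = 1.25  # hypothetical latency in milliseconds
work = (2 * M * N * K - M * N) * int8_bytes + (3 * M * N) * fp32_bytes
print(work * 1e-9 / (ms * 1e-3), "GB/s")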
sgl-kernel/setup.py
@@ -26,6 +26,7 @@ cutlass = root / "3rdparty" / "cutlass"
 include_dirs = [
     cutlass.resolve() / "include",
     cutlass.resolve() / "tools" / "util" / "include",
+    root / "src" / "sgl-kernel" / "csrc",
 ]
 nvcc_flags = [
     "-O3",
@@ -48,6 +49,7 @@ ext_modules = [
             "src/sgl-kernel/csrc/trt_reduce_internal.cu",
             "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
             "src/sgl-kernel/csrc/moe_align_kernel.cu",
+            "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
         ],
         include_dirs=include_dirs,
sgl-kernel/src/sgl-kernel/__init__.py
@@ -2,6 +2,7 @@ from sgl_kernel.ops import (
     custom_dispose,
     custom_reduce,
     init_custom_reduce,
+    int8_scaled_mm,
     moe_align_block_size,
 )
@@ -10,4 +11,5 @@ __all__ = [
     "init_custom_reduce",
     "custom_dispose",
     "custom_reduce",
+    "int8_scaled_mm",
 ]
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h (new file, mode 100644)
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
#pragma once

#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/numeric_conversion.h"

namespace cutlass {
namespace epilogue {
namespace threadblock {

template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
          typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
 public:
  using ThreadblockShape = ThreadblockShape_;
  static int const kThreadCount = ThreadCount;

  using ScaleTileIterator = ScaleTileIterator_;
  using OutputTileIterator = OutputTileIterator_;
  using ElementwiseFunctor = ElementwiseFunctor_;

  static int const kIterations = OutputTileIterator::kIterations;
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  using ElementOutput = typename OutputTileIterator::Element;
  using LayoutOutput = cutlass::layout::RowMajor;
  using ElementAccumulator = ElementAccumulator_;

  using AlphaScaleElementType = typename ScaleTileIterator::Element;

  using ElementCompute = ElementCompute_;
  using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
  using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
  using OutputVector = Array<ElementOutput, kElementsPerAccess>;

  static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
  static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);

  /// Argument structure
  struct Arguments {
    typename ElementwiseFunctor::Params elementwise;
    int64_t batch_stride_alpha;
    int64_t batch_stride_C;
    int64_t batch_stride_D;

    //
    // Methods
    //
    Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

    Arguments(typename ElementwiseFunctor::Params elementwise_)
        : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

    Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_,
              int64_t batch_stride_D_)
        : elementwise(elementwise_),
          batch_stride_alpha(batch_stride_alpha_),
          batch_stride_C(batch_stride_C_),
          batch_stride_D(batch_stride_D_) {}
  };

  struct Params {
    typename ElementwiseFunctor::Params elementwise;
    int64_t batch_stride_alpha;
    int64_t batch_stride_C;
    int64_t batch_stride_D;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() {}

    CUTLASS_HOST_DEVICE
    Params(Arguments const& args)
        : elementwise(args.elementwise),
          batch_stride_alpha(args.batch_stride_alpha),
          batch_stride_C(args.batch_stride_C),
          batch_stride_D(args.batch_stride_D) {}
  };

  /// Shared storage
  struct SharedStorage {};

 private:
  Params const& params_;
  SharedStorage& shared_storage_;
  MatrixCoord extent_;
  MatrixCoord extent_real_;
  ElementwiseFunctor elementwise_;

  bool const with_bias_;
  bool const per_token_quant_;
  bool const per_channel_quant_;

  AlphaScaleElementType* ptr_alpha_row_;
  AlphaScaleElementType* ptr_alpha_col_;
  ScaleTileIterator iterator_alpha_col_;
  OutputTileIterator iterator_C_;
  OutputTileIterator iterator_D_;

  AlphaScaleElementType element_alpha_row_ = 1.0f;
  AlphaScaleElementType element_alpha_col_ = 1.0f;
  typename ScaleTileIterator::Fragment fragment_alpha_col_;
  typename OutputTileIterator::Fragment fragment_C_;
  typename OutputTileIterator::Fragment fragment_D_;

  ElementAccumulator beta_;

  int column_offset_;

  MatrixCoord thread_offset_;

 public:
  CUTLASS_DEVICE
  EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
                              cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
                              typename ScaleTileIterator::Params params_alpha_col,
                              typename OutputTileIterator::Params params_C,
                              typename OutputTileIterator::Params params_D, bool with_bias, bool per_token_quant,
                              bool per_channel_quant, AlphaScaleElementType* ptr_alpha_row,
                              AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
                              typename OutputTileIterator::Element* ptr_D,
                              cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
                              int column_offset = 0,
                              cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
      : params_(params),
        shared_storage_(shared_storage),
        extent_(problem_size),
        elementwise_(params.elementwise),
        with_bias_(with_bias),
        per_token_quant_(per_token_quant),
        per_channel_quant_(per_channel_quant),
        ptr_alpha_row_(ptr_alpha_row),
        ptr_alpha_col_(ptr_alpha_col),
        iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
        iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
        iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
        extent_real_(problem_size_real) {
    if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
      element_alpha_col_ = *ptr_alpha_col_;
    }
    if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
      element_alpha_row_ = *ptr_alpha_row_;
    }
  }

  /// Helper to indicate split-K behavior
  CUTLASS_DEVICE
  void set_k_partition(int split_k_index,    ///< Index of this threadblock within split-K partitioned scheme
                       int split_k_slices) { ///< Total number of split-K slices
  }

  /// Called to set the batch index
  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
    iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
    iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
  }

  /// Called at the start of the epilogue just before iterating over accumulator slices
  CUTLASS_DEVICE
  void begin_epilogue() {
    if (per_channel_quant_) {
      iterator_alpha_col_.load(fragment_alpha_col_);
    }
    if (with_bias_) {
      iterator_C_.load(fragment_C_);
    }
  }

  /// Called at the start of one step before starting accumulator exchange
  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    fragment_D_.clear();
  }

  /// Called at the start of a row
  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    // load alpha_row in begin_step only when per token(row) scaling is used
    if (per_token_quant_) {
      int thread_offset_row =
          iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();

      arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
          element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
    }
  }

  /// Called after accumulators have been exchanged for each accumulator vector
  CUTLASS_DEVICE
  void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) {
    NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;

    ComputeFragment result = source_converter(accum);
    if (per_channel_quant_) {
      ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
      result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
    } else {
      result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
    }

    if (with_bias_) {
      NumericArrayConverter<ElementCompute, ElementOutput, kElementsPerAccess> bias_converter;
      OutputVector bias = reinterpret_cast<OutputVector*>(&fragment_C_)[column_idx];
      result = bias_accumulator_(result, bias_converter(bias));
    }

    // Convert to the output
    NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
    OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
    output = output_converter(result);
  }

  /// Called at the end of a row
  CUTLASS_DEVICE
  void end_row(int row_idx) {}

  /// Called after all accumulator elements have been visited
  CUTLASS_DEVICE
  void end_step(int step_idx) {
    iterator_D_.store(fragment_D_);
    ++iterator_D_;
  }

  /// Called after all steps have been completed
  CUTLASS_DEVICE
  void end_epilogue() {}

 private:
  CUTLASS_DEVICE
  ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col,
                                                       AlphaScaleElementType const& scale_row) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ComputeFragment::kElements; ++i) {
      result[i] = accum[i] * (scale_col[i] * scale_row);
    }
    return result;
  }

  CUTLASS_DEVICE
  ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col,
                                               AlphaScaleElementType const& scale_row) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ComputeFragment::kElements; ++i) {
      result[i] = accum[i] * (scale_col * scale_row);
    }
    return result;
  }

  CUTLASS_DEVICE
  ComputeFragment bias_accumulator_(ComputeFragment const& accum, ComputeFragment const& bias) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < OutputVector::kElements; ++i) {
      result[i] = accum[i] + bias[i];
    }
    return result;
  }
};

}  // namespace threadblock
}  // namespace epilogue
}  // namespace cutlass
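As a reading aid, the dequantization that EpilogueVisitorPerRowPerCol::visit() applies per fragment is equivalent to the following whole-matrix PyTorch sketch; the function and tensor names here are illustrative, not part of the header:

import torch

def epilogue_reference(accum_i32, alpha_row, alpha_col, bias=None):
    # accum_i32: (M, N) int32 GEMM accumulators; alpha_row: (M,) per-token scales;
    # alpha_col: (N,) per-channel scales. Mirrors result[i] = accum[i] * (scale_col[i] * scale_row).
    out = accum_i32.to(torch.float32) * alpha_row.view(-1, 1) * alpha_col.view(1, -1)
    if bias is not None:
        # Mirrors bias_accumulator_: result[i] = accum[i] + bias[i].
        out = out + bias.to(torch.float32)
    return out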
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h (new file, mode 100644)
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
#pragma once

#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/numeric_types.h"
#include "cutlass/trace.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/*
    This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
    It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
    and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.

    Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
    that feature at the moment.
*/

template <typename GemmKernel_>
class GemmUniversalBaseCompat {
 public:
  using GemmKernel = GemmKernel_;
  using ThreadblockShape = typename GemmKernel::Mma::Shape;

  using ElementA = typename GemmKernel::ElementA;
  using LayoutA = typename GemmKernel::LayoutA;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  static ComplexTransform const kTransformA = GemmKernel::kTransformA;

  using ElementB = typename GemmKernel::ElementB;
  using LayoutB = typename GemmKernel::LayoutB;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  static ComplexTransform const kTransformB = GemmKernel::kTransformB;

  using ElementC = typename GemmKernel::ElementC;
  using LayoutC = typename GemmKernel::LayoutC;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;

  using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;

  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
  using Operator = typename GemmKernel::Operator;

  /// Argument structure
  using Arguments = typename GemmKernel::Arguments;

 protected:
  /// Kernel parameters object
  typename GemmKernel::Params params_;

 protected:
  /// Private helper to obtain the grid dimensions with fix-up for split-K
  static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) {
    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);

    gemm_k_size = args.problem_size.k();

    if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
      int const kAlignK =
          const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

      gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

      if (gemm_k_size) {
        grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
      }
    }
  }

 public:
  /// Constructs the GEMM.
  GemmUniversalBaseCompat() {}

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const& args) {
    // Determine grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    ThreadblockSwizzle threadblock_swizzle;
    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

    uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);

    if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
      return Status::kErrorInvalidProblem;
    }

    return GemmKernel::can_implement(args);
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const& args) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");

    size_t workspace_bytes = 0;

    // Determine grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
      // Split-K parallel always requires a temporary workspace
      workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
    } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
      // Serial split-K only requires a temporary workspace if the number of partitions along the
      // GEMM K dimension is greater than one.
      workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
    }

    CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

    workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);

    return workspace_bytes;
  }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const& args) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");

    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
    dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

    CUTLASS_TRACE_HOST("  grid_tiled_shape: " << grid_tiled_shape << "\n"
                                              << "  result = {" << result << "}");

    return result;
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");

    int max_active_blocks = -1;
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");

    if (smem_size <= (48 << 10)) {
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
                                                                         GemmKernel::kThreadCount, smem_size);

      if (result == cudaSuccess) {
        CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
        return max_active_blocks;
      }
    } else {
      // Query assuming zero shared memory then compute occupancy limit based on SMEM
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
                                                                         GemmKernel::kThreadCount, 0);

      if (result != cudaSuccess) {
        CUTLASS_TRACE_HOST("  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
                           << cudaGetErrorString(result));
        return -1;
      }

      if (smem_capacity < 0) {
        int device_idx = 0;
        result = cudaGetDevice(&device_idx);

        if (result != cudaSuccess) {
          return -1;
        }

        cudaDeviceProp properties;
        result = cudaGetDeviceProperties(&properties, device_idx);

        if (result != cudaSuccess) {
          return -1;
        }

        smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
      }

      int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);

      CUTLASS_TRACE_HOST("  occupancy: " << occupancy);

      return occupancy;
    }

    CUTLASS_TRACE_HOST("  returning internal error");

    return -1;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
                       << workspace << ", stream: " << (stream ? "non-null" : "null"));

    size_t workspace_bytes = get_workspace_size(args);

    CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

    if (workspace_bytes) {
      if (!workspace) {
        CUTLASS_TRACE_HOST("  error: device workspace must not be null");
        return Status::kErrorWorkspaceNull;
      }

      if (args.mode == GemmUniversalMode::kGemm) {
        CUTLASS_TRACE_HOST("  clearing device workspace");
        cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);

        if (result != cudaSuccess) {
          CUTLASS_TRACE_HOST("  cudaMemsetAsync() returned error " << cudaGetErrorString(result));
          return Status::kErrorInternal;
        }
      }
    }

    // Get CUDA grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    // Initialize the Params structure
    params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));

    // Specify shared memory capacity for kernel.
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    if (smem_size >= (48 << 10)) {
      cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                                smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);

    if (workspace_bytes && !workspace) {
      return Status::kErrorWorkspaceNull;
    }

    params_.update(args, workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");

    //
    // Configure grid and block dimensions
    //

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    //
    // Launch kernel
    //

    CUTLASS_TRACE_HOST("  grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");

    // Launch
    cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

    //
    // Query for errors
    //
    cudaError_t result = cudaGetLastError();

    if (result != cudaSuccess) {
      CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
      return Status::kErrorInternal;
    }

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace device
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
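A compact restatement of the workspace rule in get_workspace_size() above, as a hedged Python sketch. The sizeof values are assumptions for illustration, and GemmKernel::get_extra_workspace_size(), which the real method adds on top, is omitted here (it returns 0 for the visitor kernel in this commit):

def workspace_bytes(mode, grid_k, grid_m, grid_n, batch_stride_D,
                    sizeof_element_c=2, sizeof_int=4):
    # Mirrors GemmUniversalBaseCompat::get_workspace_size(): split-K parallel always
    # needs scratch output storage; serial split-K only needs one int per output tile,
    # and only when the K dimension is actually partitioned (grid_k > 1).
    if mode == "kGemmSplitKParallel":
        return sizeof_element_c * batch_stride_D * grid_k
    if mode == "kGemm" and grid_k > 1:
        return sizeof_int * grid_m * grid_n
    return 0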
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h (new file, mode 100644)
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
#pragma once

#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
          typename Epilogue_,           ///! Epilogue
          typename ThreadblockSwizzle_  ///! Threadblock swizzling function
          >
struct GemmWithEpilogueVisitor {
 public:
  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueVisitor = typename Epilogue::Visitor;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using TensorRefA = TensorRef<ElementA, LayoutA>;

  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using TensorRefB = TensorRef<ElementB, LayoutB>;

  using ElementCompute = typename EpilogueVisitor::ElementCompute;
  using LayoutAlphaCol = cutlass::layout::RowMajor;
  using LayoutAlphaRow = cutlass::layout::ColumnMajor;
  using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
  using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;

  using ElementC = typename EpilogueVisitor::ElementOutput;
  using LayoutC = typename Epilogue::Layout;
  using TensorRefC = TensorRef<ElementC, LayoutC>;

  static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;
  using Operator = typename Mma::Operator;

  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
  using ArchTag = typename Mma::ArchTag;
  using EpilogueOutputOp =
      typename Epilogue::Visitor::ElementwiseFunctor;  // Define type so GemmUniversalBase doesn't complain

  static int const kStages = Mma::kStages;
  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
  static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Split-K preserves splits that are 128b aligned
  static int const kSplitKAlignment =
      const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);

  //
  // Structures
  //

  /// Argument structure
  struct Arguments {
    //
    // Data members
    //
    GemmUniversalMode mode;
    GemmCoord problem_size;
    int batch_count;

    TensorRefA ref_A;
    TensorRefB ref_B;
    TensorRefAlphaCol ref_alpha_col;
    TensorRefAlphaRow ref_alpha_row;
    TensorRefC ref_C;
    TensorRefC ref_D;

    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int64_t batch_stride_D;

    typename EpilogueVisitor::Arguments epilogue_visitor;

    //
    // Methods
    //
    Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}

    /// constructs an arguments structure
    Arguments(GemmCoord problem_size_, TensorRefA ref_A_, TensorRefB ref_B_, TensorRefAlphaCol ref_alpha_col_,
              TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_,
              typename EpilogueVisitor::Arguments epilogue_visitor_)
        : mode(GemmUniversalMode::kGemm),
          problem_size(problem_size_),
          batch_count(1),
          ref_A(ref_A_),
          ref_B(ref_B_),
          ref_alpha_col(ref_alpha_col_),
          ref_alpha_row(ref_alpha_row_),
          ref_C(ref_C_),
          ref_D(ref_D_),
          batch_stride_A(0),
          batch_stride_B(0),
          batch_stride_D(0),
          epilogue_visitor(epilogue_visitor_) {}
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int swizzle_log_tile;

    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorB::Params params_B;
    typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
    typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
    typename EpilogueVisitor::OutputTileIterator::Params params_C;
    typename EpilogueVisitor::OutputTileIterator::Params params_D;

    GemmUniversalMode mode;
    int batch_count;
    int gemm_k_size;

    void* ptr_A;
    void* ptr_B;
    typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
    typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
    ElementC* ptr_C;
    ElementC* ptr_D;

    int64_t batch_stride_A;
    int64_t batch_stride_B;

    typename EpilogueVisitor::Params epilogue_visitor;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params()
        : swizzle_log_tile(0),
          params_A(0),
          params_B(0),
          params_alpha_col(0),
          params_C(0),
          params_D(0),
          batch_count(0),
          gemm_k_size(0),
          mode(cutlass::gemm::GemmUniversalMode::kGemm),
          ptr_A(nullptr),
          ptr_B(nullptr),
          ptr_alpha_col(nullptr),
          ptr_alpha_row(nullptr),
          ptr_C(nullptr),
          ptr_D(nullptr),
          batch_stride_A(0),
          batch_stride_B(0) {}

    Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
        : problem_size(args.problem_size),
          swizzle_log_tile(0),
          params_A(args.ref_A.layout()),
          params_B(args.ref_B.layout()),
          params_alpha_col(args.ref_alpha_col.layout()),
          params_alpha_row(args.ref_alpha_col.layout()),
          params_C(args.ref_C.layout()),
          params_D(args.ref_D.layout()),
          mode(args.mode),
          batch_count(args.batch_count),
          gemm_k_size(args.problem_size.k()),
          ptr_A(args.ref_A.data()),
          ptr_B(args.ref_B.data()),
          ptr_alpha_col(args.ref_alpha_col.data()),
          ptr_alpha_row(args.ref_alpha_row.data()),
          ptr_C(args.ref_C.data()),
          ptr_D(args.ref_D.data()),
          batch_stride_A(args.batch_stride_A),
          batch_stride_B(args.batch_stride_B),
          epilogue_visitor(args.epilogue_visitor) {
      ThreadblockSwizzle threadblock_swizzle;

      grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
          args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);

      if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
        int const kAlignK =
            const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

        gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

        if (gemm_k_size) {
          grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
        }
      }

      swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;

    struct {
      typename Epilogue::SharedStorage epilogue;
      typename EpilogueVisitor::SharedStorage visitor;
    } epilogue;
  };

 public:
  //
  // Methods
  //

  CUTLASS_DEVICE
  GemmWithEpilogueVisitor() {}

  /// Determines whether kernel satisfies alignment
  static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) {
    CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");

    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;

    bool isAMisaligned = false;
    bool isBMisaligned = false;
    bool isCMisaligned = false;

    if (platform::is_same<LayoutA, layout::RowMajor>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
      isAMisaligned = problem_size.m() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
               platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    }

    if (platform::is_same<LayoutB, layout::RowMajor>::value) {
      isBMisaligned = problem_size.n() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
               platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    }

    if (platform::is_same<LayoutC, layout::RowMajor>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
      isCMisaligned = problem_size.m() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
               platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    }

    if (isAMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for A operand");
      return Status::kErrorMisalignedOperand;
    }

    if (isBMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for B operand");
      return Status::kErrorMisalignedOperand;
    }

    if (isCMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for C operand");
      return Status::kErrorMisalignedOperand;
    }

    CUTLASS_TRACE_HOST("  returning kSuccess");

    return Status::kSuccess;
  }

  static Status can_implement(Arguments const& args) {
    return can_implement(args.problem_size);
  }

  static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) {
    return 0;
  }

#define SPLIT_K_ENABLED 1

  /// Executes one GEMM
  CUTLASS_DEVICE
  void run_kernel_(Params const& params, SharedStorage& shared_storage) {
    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
      return;
    }

    int offset_k = 0;
    int problem_size_k = params.problem_size.k();

    ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
    ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);

#if SPLIT_K_ENABLED
    //
    // Fetch pointers based on mode.
    //
    if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) {
      if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
        problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
      }

      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
    } else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
      ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
    } else if (params.mode == GemmUniversalMode::kArray) {
      ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
      ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
    }
#endif

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
        threadblock_tile_offset.m() * Mma::Shape::kM,
        offset_k,
    };

    cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx,
                                       tb_offset_A);

    typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx,
                                       tb_offset_B);

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Compute threadblock-scoped matrix multiply-add
    int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Compute threadblock-scoped matrix multiply-add
    mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // assume identity swizzle
    MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
                                   threadblock_tile_offset.n() * Mma::Shape::kN);

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    //
    // Construct the epilogue visitor
    //

    bool with_bias = true;
    if (params.ptr_C == nullptr) {
      with_bias = false;
    }

    EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor,
                                     params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col,
                                     params.params_C, params.params_D, with_bias, true, true, params.ptr_alpha_row,
                                     params.ptr_alpha_col, params.ptr_C, params.ptr_D, threadblock_offset,
                                     blockIdx.y * params.problem_size.m());

    if (params.mode == GemmUniversalMode::kGemm) {
      // Indicate which position in a serial reduction the output operator is currently updating
      epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
    } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
      epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
    }

    // Construct the epilogue
    Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);

    // Execute the epilogue operator to update the destination tensor.
    epilogue(epilogue_visitor, accumulators);
  }

  template <typename CompilationArch>
  CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) {
    if constexpr (platform::is_same<ArchTag, CompilationArch>::value) {
      run_kernel_(params, shared_storage);
    } else {
      CUTLASS_NOT_IMPLEMENTED();
    }
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const& params, SharedStorage& shared_storage) {
    run_kernel<ArchTag>(params, shared_storage);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
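The mainloop trip count computed in run_kernel_() above is a plain ceiling division of the K extent assigned to the threadblock by the tile depth Mma::Shape::kK. A tiny sketch with assumed numbers:

# Mirrors: gemm_k_iterations = (problem_size_k - offset_k + kK - 1) / kK
# The values below are illustrative only.
problem_size_k, offset_k, kK = 8192, 0, 64
gemm_k_iterations = (problem_size_k - offset_k + kK - 1) // kK
print(gemm_k_iterations)  # 128 mainloop iterations over K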
sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu (new file, mode 100644)
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/epilogue_with_visitor.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>

#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
#include "utils.hpp"

template <typename ElementOutput, typename ArchTag, typename ThreadblockShape, typename WarpShape,
          typename InstructionShape, int NumStages>
void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                            const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                            const c10::optional<torch::Tensor>& bias) {
  using ElementAccumulator = int32_t;
  using ElementCompute = float;
  using ElementInputA = int8_t;
  using ElementInputB = int8_t;

  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;
  using DefaultGemmConf = cutlass::gemm::device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA,
                                                                          ElementInputB, ElementOutput, ElementCompute>;
  using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;

  using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
      ElementInputA, cutlass::layout::RowMajor, DefaultGemmConf::kAlignmentA, ElementInputB,
      cutlass::layout::ColumnMajor, DefaultGemmConf::kAlignmentB, ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
      ThreadblockSwizzle, NumStages, true, typename DefaultGemmConf::Operator>::GemmKernel;

  using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
      cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
          GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
          GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits<ElementOutput>::value>,
      ElementCompute>;

  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
      ThreadblockShape, GemmKernel_::kThreadCount, AlphaColTileIterator,
      typename GemmKernel_::Epilogue::OutputTileIterator, ElementAccumulator, ElementCompute, EpilogueOutputOp>;

  using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
      EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;

  using GemmKernel =
      cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;

  using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;

  Gemm gemm_op;

  int m = mat_a.size(0);
  int k = mat_a.size(1);
  int n = mat_b.size(1);

  auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
  auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
  auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());

  int64_t lda = mat_a.stride(0);
  int64_t ldb = mat_b.stride(1);
  int64_t ldd = out.stride(0);

  ElementOutput* bias_ptr = nullptr;
  int64_t ldc = 0;
  if (bias) {
    bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
  }

  typename EpilogueOutputOp::Params linearScalingParams;
  typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};

  typename Gemm::Arguments args{{m, n, k},    {a_ptr, lda},    {b_ptr, ldb}, {b_s_ptr, 0},
                                {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};

  auto workspace = torch::empty(gemm_op.get_workspace_size(args),
                                torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess)

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess)
}

template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm75_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                         const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                         const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  if (m <= 32) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 128, 64>,
                           cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
                                                                                      scales_b, bias);
  } else if (m <= 64) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
                           cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
                                                                                      scales_b, bias);
  } else if (m <= 256) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 128>,
                           cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
                                                                                      scales_b, bias);
  } else {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 64>,
                           cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
                                                                                      scales_b, bias);
  }
}

template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                         const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                         const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  int n = mat_b.size(1);
  if (m <= 16) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<16, 64, 128>,
                             cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a,
                                                                                        scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<16, 64, 128>,
                             cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                        scales_b, bias);
    }
  } else if (m <= 32) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 64, 128>,
                             cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a,
                                                                                        scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 64, 128>,
                             cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                        scales_b, bias);
    }
  } else if (m <= 64 || (m <= 128 && n < 8192)) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
                           cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                      scales_b, bias);
  } else {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 64>,
                           cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                      scales_b, bias);
  }
}

torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                             const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(1) % 8 == 0, "mat_b.size(1) must be multiple of 8 for memory alignment");  // out.stride(0)
  TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8");
  TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  if (bias) {
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));

  auto sm_version = getSMVersion();

  if (sm_version >= 75 && sm_version < 80) {
    TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75");
    sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (sm_version >= 80 && sm_version <= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability.");
  }

  return out;
}
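A hedged sketch of the architecture dispatch performed by int8_scaled_mm() above; the sm value is an assumed example and the strings merely stand in for the template instantiations chosen in the .cu file:

def pick_path(sm, out_dtype):
    # Mirrors the getSMVersion()-based branching in int8_scaled_mm().
    if 75 <= sm < 80:
        assert out_dtype == "float16", "SM75 path only supports fp16 output"
        return "sm75_dispatch_shape<half_t, Sm75, GemmShape<8, 8, 16>>"
    if 80 <= sm <= 90:
        elem = "bfloat16_t" if out_dtype == "bfloat16" else "half_t"
        return f"sm80_dispatch_shape<{elem}, Sm80, GemmShape<16, 8, 32>>"
    raise NotImplementedError("No int8_scaled_mm for this compute capability")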
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -12,6 +12,11 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
                           torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
                           torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
 
+// int8_scaled_mm
+torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
+                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
+                             const c10::optional<torch::Tensor>& bias);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -19,4 +24,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
   // moe_align_block_size
   m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
+  // int8_scaled_mm
+  m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
 }
sgl-kernel/src/sgl-kernel/csrc/utils.hpp
@@ -34,3 +34,13 @@ struct cuda_error : public std::runtime_error {
 #define CHECK_CUDA_INPUT(x) \
   CHECK_IS_CUDA(x);         \
   CHECK_IS_CONTIGUOUS(x)
+
+inline int getSMVersion() {
+  int device{-1};
+  CHECK_CUDA_SUCCESS(cudaGetDevice(&device));
+  int sm_major = 0;
+  int sm_minor = 0;
+  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
+  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
+  return sm_major * 10 + sm_minor;
+}
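getSMVersion() packs the device compute capability into a single integer (major * 10 + minor, e.g. 8.0 becomes 80), which the kernel dispatcher compares against 75/80/90. From Python, the same value can be derived with a standard torch call; a minimal sketch:

import torch

major, minor = torch.cuda.get_device_capability()  # e.g. (8, 0) on an SM80 GPU
sm_version = major * 10 + minor                    # matches getSMVersion() in utils.hpp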
sgl-kernel/src/sgl-kernel/ops/__init__.py
 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
 from sgl_kernel.ops._kernels import dispose as _dispose
 from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
+from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
 from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
@@ -36,3 +37,14 @@ def moe_align_block_size(
         token_cnts_buffer,
         cumsum_buffer,
     )
+
+
+def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
+    return _int8_scaled_mm(
+        mat_a,
+        mat_b,
+        scales_a,
+        scales_b,
+        out_dtype,
+        bias,
+    )
sgl-kernel/tests/test_int8_gemm.py (new file, mode 100644)
import unittest

import torch
from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    if bias is not None:
        o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) + bias
    else:
        o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1)
    return o.to(out_dtype)


class TestInt8Gemm(unittest.TestCase):
    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
        a = to_int8(torch.randn((M, K), device=device) * 5)
        b = to_int8(torch.randn((N, K), device=device).t() * 5)
        scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
        scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
        if with_bias:
            bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10
        else:
            bias = None
        o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
        o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
        o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
        torch.testing.assert_close(o, o1)
        torch.testing.assert_close(o, o2)
        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        Ms = [1, 128, 512, 1024, 4096]
        Ns = [16, 128, 512, 1024, 4096]
        Ks = [512, 1024, 4096, 8192, 16384]
        bias_opts = [True, False]
        out_dtypes = [torch.float16, torch.bfloat16]
        for M in Ms:
            for N in Ns:
                for K in Ks:
                    for with_bias in bias_opts:
                        for out_dtype in out_dtypes:
                            self._test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")


if __name__ == "__main__":
    unittest.main()