sglang · Commit 3ee62235 (unverified)
Authored Jan 31, 2025 by Yineng Zhang; committed via GitHub on Jan 31, 2025

revert the MoE dependence (#3230)

Parent: 9829e77e
Changes: 94
Showing 20 changed files with 0 additions and 7134 deletions (+0 -7134):

sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h  +0 -207
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h  +0 -566
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh  +0 -218
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh  +0 -799
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh  +0 -215
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h  +0 -73
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp  +0 -70
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h  +0 -585
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h  +0 -143
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh  +0 -185
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h  +0 -553
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h  +0 -344
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp  +0 -646
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp  +0 -621
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h  +0 -494
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h  +0 -125
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h  +0 -302
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h  +0 -284
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h  +0 -351
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h  +0 -353
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h
deleted (file mode 100644 → 0)
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/permute.h"
#include "splitk_gemm_grouped.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Complex elementwise transformation on B operand
    ComplexTransform TransformB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Whether the schedule of problems to visit has been precomputed
    GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly,
    /// Operation performed by GEMM
    typename Operator = typename device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementA_, ElementB_,
        ElementC_, ElementAccumulator>::Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    /// Permute result D
    typename PermuteDLayout = layout::NoPermute,
    ///
    typename Enable = void>
struct DefaultSplitkGemmGrouped;

/////////////////////////////////////////////////////////////////////////////////////////////////

//
// Real-valued GEMM kernels
//

template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Layout type for C and D matrix operands
    typename LayoutC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Whether the schedule of problems to visit has been precomputed
    GroupScheduleMode GroupScheduleMode_,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Permute result D
    typename PermuteDLayout>
struct DefaultSplitkGemmGrouped<ElementA, LayoutA,
    ComplexTransform::kNone, // transform A
    kAlignmentA, ElementB, LayoutB,
    ComplexTransform::kNone, // transform B
    kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape,
    InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, GroupScheduleMode_, Operator, SharedMemoryClear,
    PermuteDLayout, typename platform::enable_if<!cutlass::is_complex<ElementAccumulator>::value>::type>
{
    // If true, we must construct a 'transposed-and-exchanged' Mma operator.
    static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;

    using MapArguments = kernel::detail::MapArguments<ElementA, LayoutA, ComplexTransform::kNone, kAlignmentA,
        ElementB, LayoutB, ComplexTransform::kNone, kAlignmentB, LayoutC, kInternalTranspose>;

    // Define the default GEMM kernel
    using DefaultGemmKernel = typename kernel::DefaultGemm<typename MapArguments::ElementA,
        typename MapArguments::LayoutA, MapArguments::kAlignmentA, typename MapArguments::ElementB,
        typename MapArguments::LayoutB, MapArguments::kAlignmentB, ElementC, typename MapArguments::LayoutC,
        ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
        ThreadblockSwizzle, Stages, true, Operator, SharedMemoryClear, false, /*GatherA*/ false, /*GatherB*/ false,
        /*ScatterD*/ PermuteDLayout>::GemmKernel;

    /// Define the kernel in terms of the default kernel
    using GemmKernel = kernel::SplitkGemmGrouped<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue,
        ThreadblockSwizzle, GroupScheduleMode_, kInternalTranspose>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
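Aside: the header above hinges on one compile-time decision, kInternalTranspose, which routes a column-major output through a row-major kernel by exchanging the A and B operands and transposing their layouts (since (AB)^T = B^T A^T). Below is a minimal, standalone C++ sketch of that mapping idea; it is not part of this commit and not CUTLASS code, and every name in it (MapOperands, Flip, RowMajor, ColumnMajor) is illustrative only.

// Sketch of the operand-exchange/transpose mapping used when the requested output layout is column-major.
#include <type_traits>

struct RowMajor {};
struct ColumnMajor {};

template <typename ElementA, typename LayoutA, typename ElementB, typename LayoutB, typename LayoutC>
struct MapOperands
{
    // A column-major C is produced by a row-major kernel computing B^T * A^T.
    static constexpr bool kTranspose = std::is_same<LayoutC, ColumnMajor>::value;

    // Transposing an operand flips its layout between row-major and column-major.
    template <typename L>
    using Flip = typename std::conditional<std::is_same<L, RowMajor>::value, ColumnMajor, RowMajor>::type;

    // When transposing, A and B exchange roles and take each other's flipped layouts.
    using KernelElementA = typename std::conditional<kTranspose, ElementB, ElementA>::type;
    using KernelLayoutA  = typename std::conditional<kTranspose, Flip<LayoutB>, LayoutA>::type;
    using KernelElementB = typename std::conditional<kTranspose, ElementA, ElementB>::type;
    using KernelLayoutB  = typename std::conditional<kTranspose, Flip<LayoutA>, LayoutB>::type;
};

// Row-major C: operands pass through unchanged.
static_assert(std::is_same<MapOperands<float, RowMajor, float, ColumnMajor, RowMajor>::KernelLayoutA, RowMajor>::value,
    "no transpose expected");
// Column-major C: the kernel's A operand is really B with its layout flipped.
static_assert(
    std::is_same<MapOperands<float, RowMajor, float, ColumnMajor, ColumnMajor>::KernelLayoutA, RowMajor>::value,
    "transposed-and-exchanged mapping expected");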
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h
deleted (file mode 100644 → 0)
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include <type_traits>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail
{
template <typename>
inline constexpr bool dependent_false_v = false;
}

template <typename Mma_,              ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,               ///! Epilogue
    typename ThreadblockSwizzle_,     ///! Threadblock swizzling function
    typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
                         /// arch.
    bool SplitKSerial    ///! If true, code supporting split-K via serial reduction is enabled.
    >
struct GemmFpAIntB
{
    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueOutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    static bool const kSplitKSerial = SplitKSerial;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Element;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename Mma::LayoutC;
    using ElementScale = ElementC;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformA;

    // Type definitions about the mainloop.
    using Operator = typename Mma::Operator;
    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;

    /// Parameters structure
    struct Arguments
    {
        GemmUniversalMode mode = GemmUniversalMode::kGemm;

        cutlass::gemm::GemmCoord problem_size;
        int group_size;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Mma::IteratorScale::TensorRef ref_zero;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;

        // Control serial split-k
        int batch_count;

        typename EpilogueOutputOp::Params output_op;

        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        // Included so we can use Gemm Universal
        int batch_stride_D = 0;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Arguments() {}

        CUTLASS_HOST_DEVICE
        Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
            typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
            typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
            typename Epilogue::OutputTileIterator::TensorRef ref_C,
            typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor,
            typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
            int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr,
            int const* scatter_D_indices = nullptr)
            : problem_size(problem_size)
            , group_size(group_size)
            , ref_A(ref_A)
            , ref_B(ref_B)
            , ref_scale(ref_scale)
            , ref_zero(ref_zero)
            , ref_C(ref_C)
            , ref_D(ref_D)
            , batch_count(serial_split_k_factor)
            , output_op(output_op)
            , gather_A_indices(gather_A_indices)
            , gather_B_indices(gather_B_indices)
            , scatter_D_indices(scatter_D_indices)
        {
        }
    };

    /// Parameters structure
    struct Params
    {
        cutlass::gemm::GemmCoord problem_size;
        int group_size;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int swizzle_log_tile;

        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::Params params_B;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::Params params_scale;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Mma::IteratorScale::TensorRef ref_zero;
        typename Epilogue::OutputTileIterator::Params params_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::Params params_D;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;
        typename EpilogueOutputOp::Params output_op;

        int* semaphore;
        int gemm_k_size;

        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Params()
            : swizzle_log_tile(0)
            , semaphore(0)
            , gemm_k_size(0)
        {
        }

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
            void* workspace = nullptr)
            : problem_size(args.problem_size)
            , group_size(args.group_size)
            , grid_tiled_shape(grid_tiled_shape)
            , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape))
            , params_A(args.ref_A.layout())
            , ref_A(args.ref_A)
            , params_B(args.ref_B.layout())
            , ref_B(args.ref_B)
            , params_scale(args.ref_scale.layout())
            , ref_scale(args.ref_scale)
            , ref_zero(args.ref_zero)
            , params_C(args.ref_C.layout())
            , ref_C(args.ref_C)
            , params_D(args.ref_D.layout())
            , ref_D(args.ref_D)
            , output_op(args.output_op)
            , semaphore(static_cast<int*>(workspace))
            , gemm_k_size(gemm_k_size)
            , gather_A_indices(args.gather_A_indices)
            , gather_B_indices(args.gather_B_indices)
            , scatter_D_indices(args.scatter_D_indices)
        {
        }
    };

    /// Shared memory storage structure
    union SharedStorage
    {
        typename Mma::SharedStorage main_loop;
        typename Epilogue::SharedStorage epilogue;
    };

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    GemmFpAIntB() {}

    /// Determines whether kernel satisfies alignment
    static Status can_implement(Arguments const& args)
    {
        static int const kAlignmentA
            = (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
            : (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
                ? 64
                : Mma::IteratorA::AccessType::kElements;
        static int const kAlignmentB
            = (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ? 32
            : (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value)
                ? 64
                : Mma::IteratorB::AccessType::kElements;

        static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;

        static int const kAlignmentC
            = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
                  layout::ColumnMajorInterleaved<32>>::value)
            ? 32
            : (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
                  layout::ColumnMajorInterleaved<64>>::value)
                ? 64
                : Epilogue::OutputTileIterator::kElementsPerAccess;

        if (!TensorRef_aligned(args.ref_A, kAlignmentA))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_B, kAlignmentB))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_scale, kAlignmentScale))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_zero, kAlignmentScale))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_C, kAlignmentC))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_D, kAlignmentC))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!args.ref_scale.good())
        {
            return Status::kErrorNotSupported;
        }

        if constexpr (hasZero(Mma::QuantOp))
        {
            if (!args.ref_zero.good())
            {
                return Status::kErrorNotSupported;
            }
        }
        else
        {
            if (args.ref_zero.good())
            {
                return Status::kErrorNotSupported;
            }
        }

        if constexpr (isFinegrained(Mma::QuantOp))
        {
            if (args.group_size != 64 && args.group_size != 128)
            {
                return Status::kErrorNotSupported;
            }
        }

        return Status::kSuccess;
    }

    static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
    {
        return 0;
    }

    // Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator
    // has a different constructor signature than a regular cutlass iterator
    template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
    CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
        typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
        typename IteratorScale::TensorCoord extent, int thread_id,
        typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
    {
        return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size);
    }

    template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<!isFinegrained(op), bool> = true>
    CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
        typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
        typename IteratorScale::TensorCoord extent, int thread_id,
        typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
    {
        return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
    }

    CUTLASS_DEVICE
    void run_kernel_(Params const& params, SharedStorage& shared_storage)
    {
        using LayoutB = typename Mma::IteratorB::Layout;
        static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
                || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
            "B must be row major/col major OR col major interleaved.");

        // Compute threadblock location
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord threadblock_tile_offset
            = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // Early exit if CTA is out of range
        if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
            || params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
        {
            return;
        }

        // Compute initial location in logical coordinates
        cutlass::MatrixCoord tb_offset_A{
            threadblock_tile_offset.m() * Mma::Shape::kM,
            threadblock_tile_offset.k() * params.gemm_k_size,
        };

        cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
            threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};

        typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
        typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
        cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};

        // Problem size is a function of threadblock index in the K dimension
        int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);

        // Compute threadblock-scoped matrix multiply-add
        int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

        // Compute position within threadblock
        int thread_idx = threadIdx.x;

        // Construct iterators to A and B operands
        typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
            {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);

        typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
            {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
            params.gather_B_indices);

        typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
        typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
            params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
            {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);

        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
        int lane_idx = threadIdx.x % 32;

        //
        // Main loop
        //
        // Construct thread-scoped matrix multiply
        Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);

        typename Mma::FragmentC accumulators;

        accumulators.clear();

        if (!kSplitKSerial || gemm_k_iterations > 0)
        {
            // Compute threadblock-scoped matrix multiply-add
            mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
        }

        //
        // Epilogue
        //

        EpilogueOutputOp output_op(params.output_op);

        //
        // Masked tile iterators constructed from members
        //

        threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // assume identity swizzle
        MatrixCoord threadblock_offset(
            threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);

        int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

        // Construct the semaphore.
        Semaphore semaphore(params.semaphore + block_idx, thread_idx);

        // If performing a reduction via split-K, fetch the initial synchronization
        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {
            // Fetch the synchronization lock initially but do not block.
            semaphore.fetch();

            // Indicate which position in a serial reduction the output operator is currently updating
            output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
        }

        // Tile iterator loading from source tensor.
        typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(),
            params.problem_size.mn(), thread_idx, threadblock_offset, params.scatter_D_indices);

        // Tile iterator writing to destination tensor.
        typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(),
            params.problem_size.mn(), thread_idx, threadblock_offset, params.scatter_D_indices);

        Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

        // Wait on the semaphore - this latency may have been covered by iterator construction
        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {
            // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
            if (threadblock_tile_offset.k())
            {
                iterator_C = iterator_D;
            }

            semaphore.wait(threadblock_tile_offset.k());
        }

        // Execute the epilogue operator to update the destination tensor.
        epilogue(output_op, iterator_D, accumulators, iterator_C);

        //
        // Release the semaphore
        //

        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {
            int lock = 0;
            if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
            {
                // The final threadblock resets the semaphore for subsequent grids.
                lock = 0;
            }
            else
            {
                // Otherwise, the semaphore is incremented
                lock = threadblock_tile_offset.k() + 1;
            }

            semaphore.release(lock);
        }
    }

    template <typename CompilationArch>
    CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
    {
        if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
        {
            run_kernel_(params, shared_storage);
        }
        else
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    }

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
    */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
        run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
        run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
        CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
#else
        static_assert(
            false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
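Aside: GemmFpAIntB::operator() above guards each run_kernel<arch::SmXX> instantiation behind __CUDA_ARCH__ checks so that only the code path matching the compilation architecture is emitted, and everything else compiles to a stub. A minimal, self-contained CUDA sketch of that dispatch pattern follows; it is not part of this commit, and the kernel and helper names in it are illustrative only.

// Sketch: per-architecture compile-time dispatch inside a single kernel, selected by __CUDA_ARCH__.
#include <cstdio>

__device__ void run_sm80_path() { std::printf("sm80+ path\n"); }   // stand-in for the Ampere+ specialization
__device__ void run_sm75_path() { std::printf("sm75 path\n"); }    // stand-in for the Turing specialization
__device__ void run_stub()      { std::printf("unsupported arch\n"); }

__global__ void dispatch_by_arch()
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 800)
    run_sm80_path();
#elif (__CUDA_ARCH__ >= 750)
    run_sm75_path();
#else
    run_stub();
#endif
#endif
}

int main()
{
    dispatch_by_arch<<<1, 1>>>();   // the path taken is fixed at compile time for the target architecture
    cudaDeviceSynchronize();
    return 0;
}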
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh
deleted (file mode 100644 → 0)
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass/gemm/kernel/gemm_grouped_problem_visitor.h>
#include <cutlass/trace.h>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
namespace fused_moe
{
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int MaxTileM_, int TileN_,
    int TileK_, int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_sm80
{
    static constexpr int kMaxTileM = MaxTileM_;
    static constexpr int kTileN = isGateActivation(activation_type_) ? TileN_ / 2 : TileN_;
    static constexpr int kTileK = TileK_;
    static constexpr int kStages = Stages_;
    static constexpr Activation_Type activation_type = activation_type_;

    using ElementInput = ElementInput_;
    using ElementWeight = ElementWeight_;
    using ElementOutput = ElementOutput_;

    using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, kMaxTileM,
        kTileN, kTileK, kStages, activation_type>;
    using Routine_Arguments = Routine_Arguments<ElementInput, ElementWeight, ElementOutput>;
    using Routine_Params = Routine_Params<ElementInput, ElementWeight, ElementOutput>;
    using ProblemVisitor
        = cutlass::gemm::kernel::MoeProblemVisitor<cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper<
                                                       cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, false>,
            cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
            BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>;

    struct Arguments
    {
        Routine_Arguments routine_args;
        int problem_count{};
        int threadblock_count{};
    };

    struct Params
    {
        Routine_Params routine_params;
        int threadblock_count{};
        typename ProblemVisitor::Params problem_visitor_param;
    };

    using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, 16, kTileN,
        kTileK, kStages, activation_type>;
    static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64

    static constexpr int kSmemSize = use_m16
        ? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize
                                                                         : BaseKernelTraits_m16::kSmemSize)
        : BaseKernelTraits::kSmemSize;
    static constexpr int kThreadCount = BaseKernelTraits::kThreadCount;

    static constexpr bool can_implement(int const avaliable_smem_size)
    {
        return BaseKernelTraits::can_implement(avaliable_smem_size);
    }

    static Params to_underlying_arguments(Arguments const& args)
    {
        return {{args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias,
                    args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert,
                    args.routine_args.gemm_n, args.routine_args.gemm_k, args.routine_args.num_expert,
                    args.routine_args.bias_is_broadcast},
            args.threadblock_count,
            {args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k,
                args.problem_count, nullptr, 0}};
    }

    CUTE_DEVICE
    void run_device(Params const& params)
    {
#define ROUTINE_PATH(kTileM_size)                                                                                     \
    {                                                                                                                 \
        constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 32 : (kTileM_size));                  \
        using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput, kTileM,       \
            kTileN, kTileK, kStages, activation_type>;                                                                \
        RoutineTraits routine{};                                                                                      \
        int const block_m_idx = (block_m_idx_temp) *kMaxTileM / kTileM;                                               \
        routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m);                  \
    }
        typename ProblemVisitor::SharedStorage dummy_storage{};
        ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x);

        while (problem_visitor.next_tile())
        {
            auto problem_size = problem_visitor.problem_size();
            auto grid_size = problem_visitor.grid_shape(problem_size);
            auto problem_index = problem_visitor.problem_index();
            int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
            int const gemm_m = problem_size.m();
            const int32_t block_m_idx_temp = cta_idx / grid_size.n();
            const int32_t block_n_idx = cta_idx % grid_size.n();

            int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp;
            if (residue_m > kMaxTileM / 2)
            {
                using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput,
                    kMaxTileM, kTileN, kTileK, kStages, activation_type>;
                RoutineTraits routine{};
                routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m);
            }
            else
            {
                if constexpr (kMaxTileM >= 128)
                {
                    if (residue_m > 32)
                    {
                        ROUTINE_PATH(64);
                    }
                    else if (residue_m > 16)
                    {
                        ROUTINE_PATH(32);
                    }
                    else
                    {
                        // TODO: use cuda core gemm here
                        ROUTINE_PATH(16);
                    }
                }
                else if (kMaxTileM == 64)
                {
                    if (residue_m > 16)
                    {
                        ROUTINE_PATH(32);
                    }
                    else
                    {
                        // TODO: use cuda core gemm here
                        ROUTINE_PATH(16);
                    }
                }
                else if (kMaxTileM == 32)
                {
                    // TODO: use cuda core gemm here
                    ROUTINE_PATH(16);
                }
                else
                {
                    // TODO: use cuda core gemm here
                    ROUTINE_PATH(16);
                }
            }
            problem_visitor.advance(gridDim.x);
        }
#undef ROUTINE_PATH
    }
};

template <typename GemmType>
__global__ void run_global(__grid_constant__ typename GemmType::Params const params)
{
    GemmType gemm;
    gemm.run_device(params);
}

/// Computes the maximum number of active blocks per multiprocessor
template <typename GemmType>
static int fused_gemm_maximum_active_blocks(int smem_capacity = -1)
{
    CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()");

    constexpr int smem_size = GemmType::kSmemSize;

    CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");

    cudaError_t result;
    if (smem_size > (48 << 10))
    {
        result = cudaFuncSetAttribute(run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

        if (result != cudaSuccess)
        {
            // Call cudaGetLastError() to clear the error bit
            result = cudaGetLastError();
            CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
            return -1;
        }
    }

    int max_active_blocks = -1;
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks, run_global<GemmType>, GemmType::kThreadCount, smem_size);

    if (result != cudaSuccess)
    {
        // Call cudaGetLastError() to clear the error bit
        result = cudaGetLastError();
        CUTLASS_TRACE_HOST(
            "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
        return -1;
    }

    CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
}
} // namespace fused_moe
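Aside: fused_gemm_maximum_active_blocks above follows the standard CUDA occupancy-query recipe: a kernel requesting more than 48 KiB of dynamic shared memory must first opt in with cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...), after which cudaOccupancyMaxActiveBlocksPerMultiprocessor reports how many blocks fit per SM. A minimal sketch of the same recipe outside CUTLASS follows; it is not part of this commit, and smem_kernel and the sizes used are illustrative only (the opt-in can fail on GPUs whose per-block shared-memory limit is lower).

// Sketch: opt-in to large dynamic shared memory, then query occupancy for that configuration.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void smem_kernel()
{
    extern __shared__ char smem[];
    smem[threadIdx.x] = 0; // touch dynamic shared memory so the buffer is actually used
}

int main()
{
    constexpr int kThreads = 128;
    constexpr int kSmemBytes = 96 * 1024; // > 48 KiB, so the opt-in attribute is required

    if (cudaFuncSetAttribute(smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemBytes) != cudaSuccess)
    {
        std::printf("opt-in failed: %s\n", cudaGetErrorString(cudaGetLastError()));
        return 1;
    }

    int max_active_blocks = -1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, smem_kernel, kThreads, kSmemBytes);
    std::printf("max active blocks per SM: %d\n", max_active_blocks);
    return 0;
}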
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh
deleted (file mode 100644 → 0)
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
namespace fused_moe
{
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
    int Stages_, Activation_Type activation_type_, typename Enable = void>
struct Fused_Moe_Kernel_routine_sm80;

template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
    int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
    activation_type_, std::enable_if_t<isGateActivation(activation_type_)>>
{
    using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
        Stages_, activation_type_>;
    using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;

    CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
    {
        using X = cute::Underscore;

        int const M = gemm_m;
        int const N1 = params.gemm_n;
        int const K1 = params.gemm_k;
        bool const bias_is_broadcast = params.bias_is_broadcast;

        int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
        typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
        typename KT::ElementWeight const* ptr_fc1_gate_
            = params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we only focus on gated activation..
        typename KT::ElementWeight const* ptr_fc1_
            = params.ptr_fc1 + 2 * problem_index * N1 * K1;       // TODO: we only focus on gated activation..
        typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
            ? nullptr
            : (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1);
        typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr)
            ? nullptr
            : (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1
                                 : params.ptr_bias + (2 * row_jump + 1) * N1);
        typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;

        cute::Tensor mInput_mk
            = cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
                cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));

        cute::Tensor mfc1_gate_nk
            = cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_gate_)),
                cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));

        cute::Tensor mfc1_nk
            = cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
                cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));

        cute::Tensor mBias_mn
            = cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)),
                cute::make_shape(M, N1),
                cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
                    cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].

        cute::Tensor mBias_gate_mn
            = cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_gate_)),
                cute::make_shape(M, N1),
                cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
                    cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].

        cute::Tensor mOutput_mn
            = cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput*>(ptr_output_)),
                cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));

        cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
            cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
        cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{},
            cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
        cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
            cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
        cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
            cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
        cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{},
            cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
        cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
            cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)

        return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn);
    }

    // be careful, m_idx will change when use another tile shape..
    CUTE_DEVICE void run_routine(
        Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
    {
        extern __shared__ char smem_[];
        typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
        int const thread_idx = threadIdx.x;
        bool const bias_is_broadcast = params.bias_is_broadcast;
        // gmem tensor partition ..
        auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn]
            = gmem_tensor_init(problem_index, gemm_m, params);
        int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
        auto const n_tile_count = cute::size<2>(gfc1_gate_nk);

        // smem tensor ..
        cute::Tensor sInput = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_input.data()),
            typename KT::SmemLayoutA{});                                         // (BLK_M, BLK_K, Stage)
        cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
            typename KT::SmemLayoutB{});                                         // (BLK_N, BLK_K, Stage)
        cute::Tensor sfc1_gate_weight
            = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()),
                typename KT::SmemLayoutB{});                                     // (BLK_N, BLK_K, Stage)
        cute::Tensor sO = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_o.data()),
            typename KT::SmemLayoutO{});                                         // (BLK_M, BLK_N)

        // (1) first step, get the fc1_res and fc1_gate

        // (1.1) get partition for gmem -> smem
        cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_);   // (BLK_M, BLK_K, k)
        cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_);       // (BLK_N, BLK_K, k)
        cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)

        typename KT::GmemTiledCopyA gmem_tiled_copy_A;
        typename KT::GmemTiledCopyB gmem_tiled_copy_B;
        auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
        auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);

        cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput);         // (ACPY,ACPY_M,ACPY_K,k)
        cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput);         // (ACPY,ACPY_M,ACPY_K,Stage)
        cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1);               // (BCPY,BCPY_N,BCPY_K,k)
        cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight);        // (BCPY,BCPY_N,BCPY_K,Stage)
        cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g);            // (BCPY,BCPY_N,BCPY_K,k)
        cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage)

        // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor)
        cute::Tensor tInputpInput
            = cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
                cute::Stride<cute::_1, cute::_0>{});
        // Construct identity layout for sInput
        cute::Tensor cInput = make_identity_tensor(
            make_shape(cute::size<0>(sInput), cute::size<1>(sInput)));           // (BLK_M,BLK_K) -> (blk_m,blk_k)

        // Repeat the partitioning with identity layouts
        cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput);         // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)

        // Set predicates for m bounds
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
        {
            tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
        }

        // (1.2) prefetch gmem -> smem
        cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
        auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput));      // emm, iter start from 0
        int k_tile_count = cute::size<2>(gInput);
        CUTLASS_PRAGMA_UNROLL
        for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
        {
            if (k_tile_count <= 0)
            {
                cute::clear(tInputpInput);
            }
            // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
            //     tInputsInput(cute::_, cute::_, cute::_, k_pipe));
            // use copy_if
            cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
                tInputsInput(cute::_, cute::_, cute::_, k_pipe));
            cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
                tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
            cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
                tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe));
            cute::cp_async_fence();
            k_tile_count--;
            if (k_tile_count > 0)
            {
                ++k_tile_iter;
            }
        }

        // (1.3) get partition for rf
        typename KT::TiledMma tiled_mma;
        auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
        cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0));          // (MMA,MMA_M,MMA_K)
        cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0));       // (MMA,MMA_N,MMA_K)
        cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)

        cute::Tensor accum
            = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{}));    // (MMA,MMA_M,MMA_N)
        cute::Tensor accum_gate
            = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{}));    // (MMA,MMA_M,MMA_N)
        cute::clear(accum);
        cute::clear(accum_gate);
        // checkout the shape
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum));      // MMA_M
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum));        // MMA_N
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate));   // MMA_N
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum));       // MMA_N
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate));  // MMA_N
        CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1));     // MMA_K
        CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g));    // MMA_K
        CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
        CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));

        // (1.4)retiling the smem and rf for copy..
        auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
        auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
        cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput);                        // (CPY,CPY_M,CPY_K,Stage)
        cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput);               // (CPY,CPY_M,CPY_K)
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
        CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K

        auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
        auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
        cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight);                       // (CPY,CPY_N,CPY_K,Stage)
        cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1);                     // (CPY,CPY_N,CPY_K)
        cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight);                 // (CPY,CPY_N,CPY_K,Stage)
        cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g);                   // (CPY,CPY_N,CPY_K)
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view));       // CPY_N
        CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view));       // CPY_K
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view));     // CPY_N
        CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view));     // CPY_K

        // (1.5) mainloop
        // Current pipe index in smem to read from
        int smem_pipe_read = 0;
        // Current pipe index in smem to write to
        int smem_pipe_write = KT::Stages - 1;

        cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
        cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
        cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);

        constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
        // prefetch register pipeline
        if constexpr (K_BLOCK_MAX > 1)
        {
            cute::cp_async_wait<KT::Stages - 2>();
            __syncthreads();

            // Prefetch the first rmem from the first k-tile
            cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
                tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
            cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
                tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
            cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}),
                tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{}));
        }
        // k loop for mainloop
        CUTLASS_PRAGMA_NO_UNROLL
        for (; k_tile_count > 0; --k_tile_count)
        {
            cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
                [&](auto k_block)
                {
                    if (k_block == K_BLOCK_MAX - 1)
                    {
                        tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
                        tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
                        tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
                        cute::cp_async_wait<KT::Stages - 2>();
                        __syncthreads();
                    }
                    // Load A, B shmem->regs for k_block+1
                    auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
                    cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
                        tOrInput_copy_view(cute::_, cute::_, k_block_next));
                    cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
                        tOrfc1_copy_view(cute::_, cute::_, k_block_next));
                    cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
                        tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
                    // Copy gmem to smem before computing gemm on each k-pipe
                    if (k_block == 0)
                    {
                        // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
                        //     tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::copy_if(gmem_tiled_copy_A, tInputpInput,
                            tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
                            tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
                            tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
                            tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::cp_async_fence();
                        if (k_tile_count - 1 > 0)
                        {
                            ++k_tile_iter;
                        }

                        // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
                        smem_pipe_write = smem_pipe_read;
                        ++smem_pipe_read;
                        smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
                    }
                    // Thread-level register gemm for k_block
                    cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
                        accum);
                    cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
                        tOrfc1g(cute::_, cute::_, k_block), accum_gate);
                });
        }
        // load tail
        cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
            [&](auto WaitIndex)
            {
                k_tile_count--;
                using WaitIndex_t = decltype(WaitIndex);
                cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
                    [&](auto k_block)
                    {
                        if (k_block == K_BLOCK_MAX - 1)
                        {
                            tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
                            tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
                            tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
                            cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
                            __syncthreads();
                        }
                        // Load A, B shmem->regs for k_block+1
                        auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
                        cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
                            tOrInput_copy_view(cute::_, cute::_, k_block_next));
                        cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
                            tOrfc1_copy_view(cute::_, cute::_, k_block_next));
                        cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
                            tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
                        if (k_block == 0)
                        {
                            // only update smem_pipe_read
                            ++smem_pipe_read;
                            smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
                        }
                        // Thread-level register gemm for k_block
                        cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
                            tOrfc1(cute::_, cute::_, k_block), accum);
                        cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
                            tOrfc1g(cute::_, cute::_, k_block), accum_gate);
                    });
            });
        // mma tail
        cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
            [&](auto k_block)
            {
                // Load A, B shmem->regs for k_block+1
                auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
                cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
                    tOrInput_copy_view(cute::_, cute::_, k_block_next));
                cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
                    tOrfc1_copy_view(cute::_, cute::_, k_block_next));
                cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
                    tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
                // Thread-level register gemm for k_block
                cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1
(
cute
::
_
,
cute
::
_
,
k_block
),
accum
);
cute
::
gemm
(
tiled_mma
,
accum_gate
,
tOrInput
(
cute
::
_
,
cute
::
_
,
k_block
),
tOrfc1g
(
cute
::
_
,
cute
::
_
,
k_block
),
accum_gate
);
});
// if (cute::thread0()) {
// cute::print(accum_gate(0, 0, 0));
// printf("\n");
// }
// (2) add bias if it has..
if
(
params
.
ptr_bias
!=
nullptr
)
{
cute
::
Tensor
gBias
=
gBias_mn
(
cute
::
_
,
cute
::
_
,
bias_is_broadcast
?
0
:
block_m_idx
,
block_n_idx
);
cute
::
Tensor
gBias_gate
=
gBias_gate_mn
(
cute
::
_
,
cute
::
_
,
bias_is_broadcast
?
0
:
block_m_idx
,
block_n_idx
);
cute
::
Tensor
tOgBias
=
thr_mma
.
partition_C
(
gBias
);
cute
::
Tensor
tOgBiasg
=
thr_mma
.
partition_C
(
gBias_gate
);
for
(
int
i
=
0
;
i
<
cute
::
size
(
accum
);
i
++
)
{
accum
(
i
)
+=
tOgBias
(
i
);
accum_gate
(
i
)
+=
tOgBiasg
(
i
);
}
}
// (3) calculate swiglu
using
ActivationFn
=
typename
KT
::
ActivationFn
;
ActivationFn
fn
{};
CUTLASS_PRAGMA_UNROLL
for
(
int
temp_iter
=
0
;
temp_iter
<
cute
::
size
(
accum
);
temp_iter
++
)
{
accum
(
temp_iter
)
=
fn
(
accum_gate
(
temp_iter
))
*
accum
(
temp_iter
);
}
// (4) push all the result to smem
// (4.1) convert result from ElementAccum to ElementInput
cute
::
Tensor
temp_accum
=
util_convert_type
<
KT
::
ElementOutput
>
(
accum
);
// if (cute::thread0()) {
// cute::print(temp_accum(0, 0, 0));
// printf("\n");
// }
// (4.2) retile rf and smem for copy back..
auto
smem_tiled_copy_O
=
cute
::
make_tiled_copy_C
(
typename
KT
::
SmemCopyAtomO
{},
tiled_mma
);
auto
smem_thr_copy_O
=
smem_tiled_copy_O
.
get_thread_slice
(
thread_idx
);
// cute::clear(sO);
cute
::
Tensor
taccumrO
=
smem_thr_copy_O
.
retile_S
(
temp_accum
);
cute
::
Tensor
taccumsO
=
smem_thr_copy_O
.
partition_D
(
sO
);
// (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..)
cute
::
copy
(
smem_tiled_copy_O
,
taccumrO
,
taccumsO
);
__syncthreads
();
// (4.4) sO -> rO -> gO
typename
KT
::
GmemTiledCopyO
gmem_tiled_copy_O
;
auto
gmem_thr_copy_O
=
gmem_tiled_copy_O
.
get_thread_slice
(
thread_idx
);
// auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); //
// remember, for all the threads in the same col, they have the same idx for bias..
cute
::
Tensor
gO
=
gOutput_mn
(
cute
::
_
,
cute
::
_
,
block_m_idx
,
block_n_idx
);
// cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row..
auto
tOsO
=
gmem_thr_copy_O
.
partition_S
(
sO
);
auto
tOgO
=
gmem_thr_copy_O
.
partition_D
(
gO
);
// auto tOgBias = gmem_thr_copy_O.partition_D(gBias);
cute
::
Tensor
cOutput
=
cute
::
make_identity_tensor
(
cute
::
make_shape
(
cute
::
size
<
0
>
(
typename
KT
::
TileShape
{}),
cute
::
size
<
1
>
(
typename
KT
::
TileShape
{})));
cute
::
Tensor
tOcO
=
gmem_thr_copy_O
.
partition_D
(
cOutput
);
cute
::
Tensor
tOrO
=
cute
::
make_tensor
<
KT
::
ElementOutput
>
(
cute
::
shape
(
tOgO
));
cute
::
copy
(
gmem_tiled_copy_O
,
tOsO
,
tOrO
);
CUTLASS_PRAGMA_UNROLL
for
(
int
m
=
0
;
m
<
cute
::
size
<
1
>
(
tOgO
);
++
m
)
{
if
(
cute
::
get
<
0
>
(
tOcO
(
0
,
m
,
0
))
<
residue_m
)
{
cute
::
copy
(
gmem_tiled_copy_O
,
tOrO
(
cute
::
_
,
m
,
cute
::
_
),
tOgO
(
cute
::
_
,
m
,
cute
::
_
));
}
}
}
};
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
    int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
    activation_type_, std::enable_if_t<!isGateActivation(activation_type_)>>
{
    using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
        Stages_, activation_type_>;
    using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;

    CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
    {
        using X = cute::Underscore;

        int const M = gemm_m;
        int const N1 = params.gemm_n;
        int const K1 = params.gemm_k;
        bool const bias_is_broadcast = params.bias_is_broadcast;

        int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
        typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
        typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1;
        typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
            ? nullptr
            : (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1);
        typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;

        cute::Tensor mInput_mk
            = cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
                cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));

        cute::Tensor mfc1_nk
            = cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
                cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));

        cute::Tensor mBias_mn
            = cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)),
                cute::make_shape(M, N1),
                cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1, cute::_1{}));  // trick: bias shape is [1, N], but we use [M, N].

        cute::Tensor mOutput_mn
            = cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput*>(ptr_output_)),
                cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));

        cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
            cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{});  // (BLK_M, BLK_K, m, k)
        cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
            cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{});  // (BLK_N, BLK_K, n, k)
        cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
            cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{});  // (BLK_M, BLK_N, m, n)
        cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
            cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{});  // (BLK_M, BLK_N, m, n)

        return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn);
    }

    // be careful: m_idx will change when another tile shape is used..
    CUTE_DEVICE void run_routine(
        Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
    {
        extern __shared__ char smem_[];
        typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
        int const thread_idx = threadIdx.x;
        bool const bias_is_broadcast = params.bias_is_broadcast;
        // gmem tensor partition ..
        auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params);
        int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
        auto const n_tile_count = cute::size<2>(gfc1_nk);

        // smem tensor ..
        cute::Tensor sInput = cute::make_tensor(
            cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{});  // (BLK_M, BLK_K, Stage)
        cute::Tensor sfc1_weight = cute::make_tensor(
            cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), typename KT::SmemLayoutB{});  // (BLK_N, BLK_K, Stage)
        cute::Tensor sO = cute::make_tensor(
            cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{});  // (BLK_M, BLK_N)

        // (1) first step, get the fc1_res and fc1_gate
        // (1.1) get partition for gmem -> smem
        cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_);  // (BLK_M, BLK_K, k)
        cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_);      // (BLK_N, BLK_K, k)

        typename KT::GmemTiledCopyA gmem_tiled_copy_A;
        typename KT::GmemTiledCopyB gmem_tiled_copy_B;
        auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
        auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);

        cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput);   // (ACPY,ACPY_M,ACPY_K,k)
        cute::Tensor tInputsInput = gmem_thr_copy_A.partition_S(sInput);   // (ACPY,ACPY_M,ACPY_K,Stage)
        cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1);         // (BCPY,BCPY_N,BCPY_K,k)
        cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight);  // (BCPY,BCPY_N,BCPY_K,Stage)

        // Allocate predicate tensors for input and fc weight (actually we only need the input predicate tensor)
        cute::Tensor tInputpInput
            = cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
                cute::Stride<cute::_1, cute::_0>{});
        // Construct identity layout for sInput
        cute::Tensor cInput = make_identity_tensor(
            make_shape(cute::size<0>(sInput), cute::size<1>(sInput)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)

        // Set predicates for m bounds
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
        {
            tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m;  // blk_m coord < residue_m
        }

        // (1.2) prefetch gmem -> smem
        cute::clear(tInputsInput);  // we don't need to clear tfc1sfc1..
        auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput));  // the iterator starts from 0
        int k_tile_count = cute::size<2>(gInput);
        CUTLASS_PRAGMA_UNROLL
        for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
        {
            if (k_tile_count <= 0)
            {
                cute::clear(tInputpInput);
            }
            // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
            //     tInputsInput(cute::_, cute::_, cute::_, k_pipe));
            // use copy_if
            cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
                tInputsInput(cute::_, cute::_, cute::_, k_pipe));
            cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
                tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
            cute::cp_async_fence();
            k_tile_count--;
            if (k_tile_count > 0)
            {
                ++k_tile_iter;
            }
        }

        // (1.3) get partition for rf
        typename KT::TiledMma tiled_mma;
        auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
        cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0));     // (MMA,MMA_M,MMA_K)
        cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0));  // (MMA,MMA_N,MMA_K)

        cute::Tensor accum
            = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{}));  // (MMA,MMA_M,MMA_N)
        cute::clear(accum);
        // check the shapes
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum));   // MMA_M
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum));     // MMA_N
        CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1));  // MMA_K
        CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
        CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));

        // (1.4) retile the smem and rf for the copy..
        auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
        auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
        cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput);           // (CPY,CPY_M,CPY_K,Stage)
        cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput);  // (CPY,CPY_M,CPY_K)
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view));  // CPY_M
        CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view));  // CPY_K

        auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
        auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
        cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight);    // (CPY,CPY_N,CPY_K,Stage)
        cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1);  // (CPY,CPY_N,CPY_K)
        CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view));  // CPY_N
        CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view));  // CPY_K

        // (1.5) mainloop
        // Current pipe index in smem to read from
        int smem_pipe_read = 0;
        // Current pipe index in smem to write to
        int smem_pipe_write = KT::Stages - 1;

        cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
        cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);

        constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);

        // prefetch register pipeline
        if constexpr (K_BLOCK_MAX > 1)
        {
            cute::cp_async_wait<KT::Stages - 2>();
            __syncthreads();
            // Prefetch the first rmem from the first k-tile
            cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
                tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
            cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
                tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
        }
        // k loop for mainloop
        CUTLASS_PRAGMA_NO_UNROLL
        for (; k_tile_count > 0; --k_tile_count)
        {
            cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
                [&](auto k_block)
                {
                    if (k_block == K_BLOCK_MAX - 1)
                    {
                        tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
                        tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
                        cute::cp_async_wait<KT::Stages - 2>();
                        __syncthreads();
                    }
                    // Load A, B shmem->regs for k_block+1
                    auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
                    cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
                        tOrInput_copy_view(cute::_, cute::_, k_block_next));
                    cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
                        tOrfc1_copy_view(cute::_, cute::_, k_block_next));
                    // Copy gmem to smem before computing gemm on each k-pipe
                    if (k_block == 0)
                    {
                        // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
                        //     tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::copy_if(gmem_tiled_copy_A, tInputpInput,
                            tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
                            tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
                            tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::cp_async_fence();
                        if (k_tile_count - 1 > 0)
                        {
                            ++k_tile_iter;
                        }
                        // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
                        smem_pipe_write = smem_pipe_read;
                        ++smem_pipe_read;
                        smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
                    }
                    // Thread-level register gemm for k_block
                    cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
                        tOrfc1(cute::_, cute::_, k_block), accum);
                });
        }
        // load tail
        cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
            [&](auto WaitIndex)
            {
                k_tile_count--;
                using WaitIndex_t = decltype(WaitIndex);
                cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
                    [&](auto k_block)
                    {
                        if (k_block == K_BLOCK_MAX - 1)
                        {
                            tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
                            tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
                            cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
                            __syncthreads();
                        }
                        // Load A, B shmem->regs for k_block+1
                        auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
                        cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
                            tOrInput_copy_view(cute::_, cute::_, k_block_next));
                        cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
                            tOrfc1_copy_view(cute::_, cute::_, k_block_next));
                        if (k_block == 0)
                        {
                            // only update smem_pipe_read
                            ++smem_pipe_read;
                            smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
                        }
                        // Thread-level register gemm for k_block
                        cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
                            tOrfc1(cute::_, cute::_, k_block), accum);
                    });
            });
        // mma tail
        cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
            [&](auto k_block)
            {
                // Load A, B shmem->regs for k_block+1
                auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
                cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
                    tOrInput_copy_view(cute::_, cute::_, k_block_next));
                cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
                    tOrfc1_copy_view(cute::_, cute::_, k_block_next));
                // Thread-level register gemm for k_block
                cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
                    tOrfc1(cute::_, cute::_, k_block), accum);
            });
        // if (cute::thread0()) {
        //     cute::print(accum_gate(0, 0, 0));
        //     printf("\n");
        // }
        // (2) add the bias if there is one
        if (params.ptr_bias != nullptr)
        {
            cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
            cute::Tensor tOgBias = thr_mma.partition_C(gBias);
            for (int i = 0; i < cute::size(accum); i++)
            {
                accum(i) += tOgBias(i);
            }
        }
        // (3) apply the activation
        using ActivationFn = typename KT::ActivationFn;
        ActivationFn fn{};
        CUTLASS_PRAGMA_UNROLL
        for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
        {
            accum(temp_iter) = fn(accum(temp_iter));
        }
        // (4) push all the results to smem
        // (4.1) convert result from ElementAccum to ElementInput
        cute::Tensor temp_accum = util_convert_type<KT::ElementOutput>(accum);
        // if (cute::thread0()) {
        //     cute::print(temp_accum(0, 0, 0));
        //     printf("\n");
        // }
        // (4.2) retile rf and smem for the copy back
        auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
        auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
        // cute::clear(sO);
        cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
        cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
        // (4.3) copy rf result to smem (TODO: maybe use a for loop for better performance)
        cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
        __syncthreads();
        // (4.4) sO -> rO -> gO
        typename KT::GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow);
        cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
        auto tOsO = gmem_thr_copy_O.partition_S(sO);
        auto tOgO = gmem_thr_copy_O.partition_D(gO);
        cute::Tensor cOutput = cute::make_identity_tensor(
            cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
        cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
        cute::Tensor tOrO = cute::make_tensor<KT::ElementOutput>(cute::shape(tOgO));
        cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < cute::size<1>(tOgO); ++m)
        {
            if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
            {
                cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
            }
        }
    }
};
} // namespace fused_moe
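For orientation: the mainloop above is a conventional cp.async multistage pipeline in which gmem-to-smem copies run KT::Stages - 1 k-tiles ahead of the smem-to-register copies and MMAs that drain the stage at smem_pipe_read. The following is a minimal host-side sketch of just the ring-buffer index rotation; Stages and k_tiles are illustrative stand-ins for KT::Stages and the real k-tile count, not values from the deleted kernel.

    // Hedged sketch of the smem pipeline index rotation used by the mainloop above.
    #include <cstdio>

    int main()
    {
        constexpr int Stages = 3;   // stand-in for KT::Stages
        int k_tiles = 8;            // stand-in for k_tile_count
        int smem_pipe_read = 0;
        int smem_pipe_write = Stages - 1;
        for (int k = 0; k < k_tiles; ++k)
        {
            // the producer fills stage `smem_pipe_write` while the consumer drains `smem_pipe_read`
            std::printf("k=%d read=%d write=%d\n", k, smem_pipe_read, smem_pipe_write);
            smem_pipe_write = smem_pipe_read;
            smem_pipe_read = (smem_pipe_read + 1 == Stages) ? 0 : smem_pipe_read + 1;
        }
        return 0;
    }

The write index always trails the read index by one full rotation, which is why the kernel only has to wait on Stages - 2 outstanding cp.async groups before consuming a stage.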
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass_extensions/epilogue_helpers.h>
#include <cutlass_extensions/gemm/kernel/moe_cute_util.cuh>
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
namespace fused_moe
{
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
struct Routine_Arguments
{
    ElementInput* ptr_input{};
    ElementWeight* ptr_fc1{};
    ElementInput* ptr_bias{};
    ElementOutput* ptr_output{};
    int64_t const* total_tokens_including_expert{};
    int gemm_n{};
    int gemm_k{};
    int num_expert{};
    bool bias_is_broadcast{};
};

template <typename ElementInput, typename ElementWeight, typename ElementOutput>
struct Routine_Params
{
    ElementInput* ptr_input{};
    ElementWeight* ptr_fc1{};
    ElementInput* ptr_bias{};
    ElementOutput* ptr_output{};
    int64_t const* total_tokens_including_expert{};
    int gemm_n{};
    int gemm_k{};
    int num_expert{};
    bool bias_is_broadcast{};
};

enum class Activation_Type
{
    Gelu = 0,
    Relu,
    Silu,
    Swiglu,
    Geglu,
    Identity,
    InvalidType
};

constexpr bool isGateActivation(Activation_Type const& activation_type)
{
    return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu;
}

template <typename CutlassExtensionEpilogueTag>
constexpr Activation_Type EpilogueRouting(bool /*is_gate*/)
{
    return Activation_Type::InvalidType;
}

template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefault>(bool /*is_gate*/)
{
    return Activation_Type::Identity;
}

template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultReLU>(bool /*is_gate*/)
{
    return Activation_Type::Relu;
}

template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu>(bool is_gate)
{
    return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu;
}

template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu>(bool is_gate)
{
    return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu;
}

/* fusing all three kernels has many limitations. This is the simpler version. Just fuse first two kernels..*/
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
    int Stages_, Activation_Type activation_type>
struct Fused_Moe_Kernel_traits_sm80
{
    using ElementInput = ElementInput_;
    using ElementWeight = ElementWeight_;
    using ElementAccum = float;
    using ElementOutput = ElementOutput_;

    using index_t = uint32_t;
    static_assert(TileM_ % 16 == 0);
    static_assert(TileN_ % 32 == 0);
    static_assert(TileK_ % 32 == 0);
    static constexpr int Stages = Stages_;
    static constexpr int kTileM = TileM_;
    static constexpr int kTileN = TileN_;
    static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64);

    // tile shape
    using TileShape = cute::Shape<cute::Int<kTileM>, cute::Int<kTileN>, cute::Int<kTileK>>;
    static constexpr int kWarpsCount = 4;
    static constexpr int kThreadCount = kWarpsCount * 32;

    // MMA atom arch and layout
    using MMA_Atom_Arch = std::conditional_t<std::is_same_v<ElementInput, cutlass::half_t>,
        cute::MMA_Atom<cute::SM80_16x8x16_F32F16F16F32_TN>, cute::MMA_Atom<cute::SM80_16x8x16_F32BF16BF16F32_TN>>;
    // using ValLayoutMNK = cute::Layout<cute::Shape<cute::_1, cute::_2, cute::_1>>;
    using ThreadLayoutMNK
        = std::conditional_t<kTileM == 16, cute::Layout<cute::Shape<cute::_1, cute::Int<kWarpsCount / 1>, cute::_1>>,
            cute::Layout<cute::Shape<cute::_2, cute::Int<kWarpsCount / 2>, cute::_1>>>;
    using ValLayoutMNK = std::conditional_t<kTileM == 16, cute::Tile<cute::_16, cute::_64, cute::_16>,
        cute::Tile<cute::_32, cute::_32, cute::_16>>;
    using TiledMma
        = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK, ValLayoutMNK>;  // 32x32x16 or 16x64x16 MMA for LDSM if kWarp = 4
    static constexpr int kAlignment = 8;
    static constexpr int kBlcokKSmem = (kTileM == 16) ? 64 : 32;

    // A memory copy operand
    using DefaultOperandA
        = DefaultGemm_TensorOpSm80_OperandA<ElementInput, cutlass::layout::RowMajor, kAlignment, kBlcokKSmem>;
    using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom;
    using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
    using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;

    // B memory copy operand
    using DefaultOperandB
        = DefaultGemm_TensorOpSm80_OperandB<ElementWeight, cutlass::layout::ColumnMajor, kAlignment, kBlcokKSmem>;
    using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom;
    using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
    using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;

    // Output memory copy operand
    using SmemLayoutAtomO = SmemLayoutAtomA;
    using SmemCopyAtomO = cute::Copy_Atom<cute::DefaultCopy, ElementOutput>;
    static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput);
    static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad;
    using GmemLayoutAtomO
        = cute::Layout<cute::Shape<cute::Int<kThreadCount / kGmemTrheadsPerRow>, cute::Int<kGmemTrheadsPerRow>>,
            cute::Stride<cute::Int<kGmemTrheadsPerRow>, cute::_1>>;
    using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, ElementOutput>{},
        GmemLayoutAtomO{}, cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));

    static_assert(cute::rank(SmemLayoutAtomA{}) == 2);
    static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0);  // M
    static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0);  // K
    static_assert(cute::rank(SmemLayoutAtomB{}) == 2);
    static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0);  // N
    static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0);  // K

    using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{},
        cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{})));  // BLK_M, BLK_K, Stages
    using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{},
        cute::make_shape(cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{})));  // BLK_N, BLK_K, Stages
    using SmemLayoutO = decltype(cute::tile_to_shape(SmemLayoutAtomO{},
        cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{}))));                       // BLK_M, BLK_N

    // we need at least 2 stages..
    static_assert(Stages >= 2);

    struct SharedStorageNormal : cute::aligned_struct<128>
    {
        cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
        cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
        cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
    };

    struct SharedStorageGate : cute::aligned_struct<128>
    {
        cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
        cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_gate_weight;
        cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
        cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
    };

    using SharedStorage = std::conditional_t<isGateActivation(activation_type), SharedStorageGate, SharedStorageNormal>;

    using ActivationFn = std::conditional_t<
        activation_type == Activation_Type::Gelu || activation_type == Activation_Type::Geglu,
        cutlass::epilogue::thread::GELU<float>,
        std::conditional_t<activation_type == Activation_Type::Relu, cutlass::epilogue::thread::ReLU<float>,
            std::conditional_t<activation_type == Activation_Type::Silu || activation_type == Activation_Type::Swiglu,
                cutlass::epilogue::thread::SiLu<float>, cutlass::epilogue::thread::Identity<float>>>>;

    static constexpr int kSmemSize = static_cast<int>(sizeof(SharedStorage));
    static constexpr bool can_implement(int const avaliable_smem_size)
    {
        return avaliable_smem_size > kSmemSize;
    }
    // #endif
};
} // namespace fused_moe
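The traits struct above does the routing from an epilogue tag to an Activation_Type, picks the gated or non-gated shared-storage layout, and exposes kSmemSize / can_implement for the launcher. A hedged host-side sketch of how it might be instantiated is shown below, assuming the deleted header is still on the include path; the 128x128x32, 3-stage tile configuration is illustrative and not taken from the original launcher.

    // Hedged usage sketch (not part of the deleted sources).
    #include <cstdio>
    #include "cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh"

    using namespace fused_moe;

    // EpilogueOpDefaultSilu with is_gate=true routes to Swiglu, a gated activation,
    // so the traits select SharedStorageGate (two fc1 weight buffers).
    constexpr Activation_Type kAct
        = EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu>(/*is_gate=*/true);
    static_assert(isGateActivation(kAct));

    using Traits = Fused_Moe_Kernel_traits_sm80<cutlass::half_t, cutlass::half_t, cutlass::half_t,
        /*TileM_=*/128, /*TileN_=*/128, /*TileK_=*/32, /*Stages_=*/3, kAct>;

    int main()
    {
        std::printf("smem bytes needed: %d, fits in 164 KB of dynamic smem: %d\n",
            Traits::kSmemSize, int(Traits::can_implement(164 * 1024)));
        return 0;
    }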
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Scheduler for grouped GEMM
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/matrix_coord.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{

/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount,
    bool Transposed = false>
struct GemmMoeProblemVisitor
    : public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>, ThreadblockShape,
          GroupScheduleMode_, PrefetchTileCount, ThreadCount>
{

    static bool const kTransposed = Transposed;

    using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
    using Base
        = MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
    using Params = typename Base::Params;
    using SharedStorage = typename Base::SharedStorage;

    //
    // Methods
    //
    CUTLASS_DEVICE
    GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
        : Base(params_, shared_storage_, block_idx)
    {
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp
deleted 100644 → 0
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel
{

////////////////////////////////////////////////////////////////////////////////

/*
 * Stateless universal device GEMM kernel type that treats GEMM as
 * a composition of a collective mainloop and a collective epilogue.
 *
 * Supports both the 2.x and 3.x APIs based on whether the first type is
 * a cute::tuple<> or not.
 * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
 * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
 *
 * In the following declaration, the name preceding the 'Or' refers to
 * 3.x API type argument order, and the name succeeding the 'Or' refers to
 * 2.x API type argument order. Template arguments without two names
 * belong to the 3.x API only.
 **/
template <class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
    class CollectiveMainloopOrEpilogue_, class CollectiveEpilogueOrThreadblockSwizzle_, class TileScheduler_ = void,
    class Enable = void>
class GemmUniversalGated;

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::kernel
////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp"
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp"
////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
deleted 100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief GEMM kernel to support the epilogue visitor model
for customized softmax partial reduction epilogue fusion.
This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
its usage has been stabilized. For now, it is included in this example to demonstrate
some basic output fusion options.
original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h"
namespace tk = tensorrt_llm::common;

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace kernel
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Mma_,          ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,           ///! Epilogue
    typename ThreadblockSwizzle_  ///! Threadblock swizzling function
    >
struct GemmWithEpilogueVisitor
{
public:
    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueVisitor = typename Epilogue::Visitor;
    using ThreadblockSwizzle = ThreadblockSwizzle_;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using TensorRefA = TensorRef<ElementA, LayoutA>;

    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using TensorRefB = TensorRef<ElementB, LayoutB>;

    using ElementCompute = typename EpilogueVisitor::ElementCompute;
    using LayoutAlphaCol = cutlass::layout::RowMajor;
    using LayoutAlphaRow = cutlass::layout::ColumnMajor;
    using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
    using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;

    using ElementC = typename EpilogueVisitor::ElementOutput;
    using LayoutC = typename Epilogue::Layout;
    using TensorRefC = TensorRef<ElementC, LayoutC>;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformB;
    using Operator = typename Mma::Operator;

    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;
    using EpilogueOutputOp
        = typename Epilogue::Visitor::ElementwiseFunctor;  // Define type so GemmUniversalBase doesn't complain

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    /// Split-K preserves splits that are 128b aligned
    static int const kSplitKAlignment
        = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);

    //
    // Structures
    //

    /// Argument structure
    struct Arguments
    {
        //
        // Data members
        //
        GemmUniversalMode mode;
        GemmCoord problem_size;
        int batch_count;

        TensorRefA ref_A;
        TensorRefB ref_B;
        tk::QuantMode quant_option;
        TensorRefAlphaCol ref_alpha_col;
        TensorRefAlphaRow ref_alpha_row;
        TensorRefC ref_C;
        TensorRefC ref_D;

        int64_t batch_stride_A;
        int64_t batch_stride_B;
        int64_t batch_stride_D;

        typename EpilogueVisitor::Arguments epilogue_visitor;

        //
        // Methods
        //
        Arguments()
            : mode(GemmUniversalMode::kGemm)
            , batch_count(1)
        {
        }

        /// constructs an arguments structure
        Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_,
            TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_,
            TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_,
            int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_)
            : mode(mode_)
            , problem_size(problem_size_)
            , batch_count(batch_count_)
            , ref_A(ref_A_)
            , ref_B(ref_B_)
            , quant_option(quant_option_)
            , ref_alpha_col(ref_alpha_col_)
            , ref_alpha_row(ref_alpha_row_)
            , ref_C(ref_C_)
            , ref_D(ref_D_)
            , batch_stride_A(batch_stride_A_)
            , batch_stride_B(batch_stride_B_)
            , batch_stride_D(0)
            , epilogue_visitor(epilogue_visitor_)
        {
        }
    };

    //
    // Structure for precomputing values in host memory and passing to kernels
    //

    /// Parameters structure
    struct Params
    {
        cutlass::gemm::GemmCoord problem_size;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int swizzle_log_tile;

        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorB::Params params_B;
        typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
        typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
        typename EpilogueVisitor::OutputTileIterator::Params params_C;
        typename EpilogueVisitor::OutputTileIterator::Params params_D;

        GemmUniversalMode mode;
        int batch_count;
        int gemm_k_size;

        void* ptr_A;
        void* ptr_B;
        tk::QuantMode quant_option;
        typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
        typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
        ElementC* ptr_C;
        ElementC* ptr_D;

        int64_t batch_stride_A;
        int64_t batch_stride_B;

        typename EpilogueVisitor::Params epilogue_visitor;

        //
        // Methods
        //
        CUTLASS_HOST_DEVICE
        Params()
            : swizzle_log_tile(0)
            , params_A(0)
            , params_B(0)
            , params_alpha_col(0)
            , params_C(0)
            , params_D(0)
            , batch_count(0)
            , gemm_k_size(0)
            , mode(cutlass::gemm::GemmUniversalMode::kGemm)
            , ptr_A(nullptr)
            , ptr_B(nullptr)
            , ptr_alpha_col(nullptr)
            , ptr_alpha_row(nullptr)
            , ptr_C(nullptr)
            , ptr_D(nullptr)
            , batch_stride_A(0)
            , batch_stride_B(0)
        {
        }

        Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_,
            int* workspace_)
            : problem_size(args.problem_size)
            , swizzle_log_tile(0)
            , params_A(args.ref_A.layout())
            , params_B(args.ref_B.layout())
            , params_alpha_col(args.ref_alpha_col.layout())
            , params_alpha_row(args.ref_alpha_col.layout())
            , params_C(args.ref_C.layout())
            , params_D(args.ref_D.layout())
            , mode(args.mode)
            , batch_count(args.batch_count)
            , gemm_k_size(args.problem_size.k())
            , ptr_A(args.ref_A.data())
            , ptr_B(args.ref_B.data())
            , quant_option(args.quant_option)
            , ptr_alpha_col(args.ref_alpha_col.data())
            , ptr_alpha_row(args.ref_alpha_row.data())
            , ptr_C(args.ref_C.data())
            , ptr_D(args.ref_D.data())
            , batch_stride_A(args.batch_stride_A)
            , batch_stride_B(args.batch_stride_B)
            , epilogue_visitor(args.epilogue_visitor)
        {
            ThreadblockSwizzle threadblock_swizzle;

            grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size,
                {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);

            if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
            {
                int const kAlignK
                    = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

                gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

                if (gemm_k_size)
                {
                    grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
                }
            }

            swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
        }
    };

    /// Shared memory storage structure
    union SharedStorage
    {
        typename Mma::SharedStorage main_loop;

        struct
        {
            typename Epilogue::SharedStorage epilogue;
            typename EpilogueVisitor::SharedStorage visitor;
        } epilogue;
    };

public:
    //
    // Methods
    //

    CUTLASS_DEVICE
    GemmWithEpilogueVisitor() {}

    /// Determines whether kernel satisfies alignment
    static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
    {
        CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");

        static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
        static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
        static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;

        bool isAMisaligned = false;
        bool isBMisaligned = false;
        bool isCMisaligned = false;

        if (platform::is_same<LayoutA, layout::RowMajor>::value)
        {
            isAMisaligned = problem_size.k() % kAlignmentA;
        }
        else if (platform::is_same<LayoutA, layout::ColumnMajor>::value)
        {
            isAMisaligned = problem_size.m() % kAlignmentA;
        }
        else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
            || platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value)
        {
            isAMisaligned = problem_size.k() % kAlignmentA;
        }

        if (platform::is_same<LayoutB, layout::RowMajor>::value)
        {
            isBMisaligned = problem_size.n() % kAlignmentB;
        }
        else if (platform::is_same<LayoutB, layout::ColumnMajor>::value)
        {
            isBMisaligned = problem_size.k() % kAlignmentB;
        }
        else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
            || platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value)
        {
            isBMisaligned = problem_size.k() % kAlignmentB;
        }

        if (platform::is_same<LayoutC, layout::RowMajor>::value)
        {
            isCMisaligned = problem_size.n() % kAlignmentC;
        }
        else if (platform::is_same<LayoutC, layout::ColumnMajor>::value)
        {
            isCMisaligned = problem_size.m() % kAlignmentC;
        }
        else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
            || platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value)
        {
            isCMisaligned = problem_size.n() % kAlignmentC;
        }

        if (isAMisaligned)
        {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for A operand");
            return Status::kErrorMisalignedOperand;
        }

        if (isBMisaligned)
        {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for B operand");
            return Status::kErrorMisalignedOperand;
        }

        if (isCMisaligned)
        {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for C operand");
            return Status::kErrorMisalignedOperand;
        }

        CUTLASS_TRACE_HOST("  returning kSuccess");

        return Status::kSuccess;
    }

    static Status can_implement(Arguments const& args)
    {
        return can_implement(args.problem_size);
    }

    static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
    {
        return 0;
    }

#define SPLIT_K_ENABLED 1

    /// Executes one GEMM
    CUTLASS_DEVICE
    void run_kernel_(Params const& params, SharedStorage& shared_storage)
    {
        // Compute threadblock location
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord threadblock_tile_offset
            = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // Early exit if CTA is out of range
        if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
            || params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
        {
            return;
        }

        int offset_k = 0;
        int problem_size_k = params.problem_size.k();

        ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
        ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);

#if SPLIT_K_ENABLED
        //
        // Fetch pointers based on mode.
        //
        if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel)
        {
            if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k())
            {
                problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
            }

            offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
        }
        else if (params.mode == GemmUniversalMode::kBatched)
        {
            ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
            ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
        }
        else if (params.mode == GemmUniversalMode::kArray)
        {
            ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
            ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
        }
#endif

        // Compute initial location in logical coordinates
        cutlass::MatrixCoord tb_offset_A{
            threadblock_tile_offset.m() * Mma::Shape::kM,
            offset_k,
        };

        cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};

        // Compute position within threadblock
        int thread_idx = threadIdx.x;

        // Construct iterators to A and B operands
        typename Mma::IteratorA iterator_A(
            params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);

        typename Mma::IteratorB iterator_B(
            params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);

        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

        int lane_idx = threadIdx.x % 32;

        //
        // Main loop
        //

        // Construct thread-scoped matrix multiply
        Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

        typename Mma::FragmentC accumulators;

        accumulators.clear();

        // Compute threadblock-scoped matrix multiply-add
        int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add
        mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

        //
        // Masked tile iterators constructed from members
        //

        threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // assume identity swizzle
        MatrixCoord threadblock_offset(
            threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);

        int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

        //
        // Construct the epilogue visitor
        //

        EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor,
            params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
            params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C,
            params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m());

        if (params.mode == GemmUniversalMode::kGemm)
        {
            // Indicate which position in a serial reduction the output operator is currently updating
            epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
        }
        else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray)
        {
            epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
        }

        // Construct the epilogue
        Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);

        // Execute the epilogue operator to update the destination tensor.
        epilogue(epilogue_visitor, accumulators);
    }

    template <typename CompilationArch>
    CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
    {
        if constexpr (platform::is_same<ArchTag, CompilationArch>::value)
        {
            run_kernel_(params, shared_storage);
        }
        else
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    }

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
    */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750)
        run_kernel<arch::Sm72>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
        run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
        // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels.
        run_kernel<arch::Sm80>(params, shared_storage);
#else
        static_assert(
            false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
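As a rough host-side sketch of the split-K slice arithmetic used in the kernel above (not part of the deleted file; gemm_k_size, the K extent, and the tile depth kK below are invented values standing in for params.gemm_k_size, params.problem_size.k(), and Mma::Shape::kK):

#include <cstdio>

int main()
{
    const int kK = 64;            // stand-in for Mma::Shape::kK
    const int problem_k = 4096;   // full K extent of the GEMM
    const int gemm_k_size = 1024; // K elements assigned to each K-partition
    const int grid_k = 4;         // stand-in for params.grid_tiled_shape.k()

    for (int tile_k = 0; tile_k < grid_k; ++tile_k)
    {
        // Mirror the kGemm branch: every partition but the last clamps its end to a gemm_k_size boundary.
        int problem_size_k = (tile_k + 1 < grid_k) ? (tile_k + 1) * gemm_k_size : problem_k;
        int offset_k = tile_k * gemm_k_size;
        // Same ceil-division the kernel uses to size the mainloop.
        int gemm_k_iterations = (problem_size_k - offset_k + kK - 1) / kK;
        std::printf("k-partition %d covers [%d, %d) -> %d mainloop iterations\n", tile_k, offset_k, problem_size_k,
            gemm_k_iterations);
    }
    return 0;
}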
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
to be consumed by CUTLASS.
Note that for int4, ThreadBlockK MUST be 64.
*/
#pragma once
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/platform/platform.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{

template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB
{
};

// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
// TODO - Switch this to column major for weights since gemms should be more performant.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

template <typename TypeA>
struct LayoutDetailsB<TypeA, cutlass::float_e4m3_t, arch::Sm89>
{
    static constexpr int ThreadblockK = 64;

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
    // for fast accumulation
    // using Operator = cutlass::arch::OpMultiplyAddFastAccum;
};

// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability < 90>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability < 90>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass
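A small host-side sketch of the layout arithmetic these LayoutDetailsB specializations encode (not part of the deleted header; the bit widths below are illustrative, chosen to match fp16 activations with int8/int4 weights, and the names are local stand-ins for the template's constants):

#include <cstdio>

int main()
{
    const int bits_A = 16;                     // sizeof_bits<half_t>::value
    const int ThreadblockK = 128 * 8 / bits_A; // 64, matching the "for int4, ThreadBlockK MUST be 64" note

    const int b_bits[] = {8, 4}; // uint8_t and uint4b_t weights
    for (int bits_B : b_bits)
    {
        int ElementsPerCacheLine = 128 * 8 / bits_B; // weight elements packed into one 128-byte cache line
        int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
        int ElementsPerAccess = 128 / bits_B;        // elements per 128-bit vectorized access
        std::printf("B bits=%d: ThreadblockK=%d, ColumnsInterleaved=%d, ElementsPerAccess=%d\n", bits_B,
            ThreadblockK, ColumnsInterleaved, ElementsPerAccess);
    }
    return 0;
}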
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cute/algorithm/copy.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/numeric_conversion.h>
template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA;

template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB;

template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 64>
{
    // Smem
    using SmemLayoutAtom = decltype(cute::composition(cute::Swizzle<3, 3, 3>{},
        cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
    using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;

    // Gmem
    using GmemTiledCopy = decltype(cute::make_tiled_copy(
        cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
        cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
        cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};

template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
{
    // Smem
    using SmemLayoutAtom = decltype(cute::composition(cute::Swizzle<3, 3, 3>{},
        cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
    using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;

    // Gmem
    using GmemTiledCopy = decltype(cute::make_tiled_copy(
        cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
        cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
        cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};

/// Operand A - Column-major (M-major)
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
    // Smem
    using SmemLayoutAtom = decltype(cute::composition(cute::Swizzle<3, 3, 3>{},
        cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
    using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::half_t>;

    // Gmem
    using GmemTiledCopy = decltype(cute::make_tiled_copy(
        cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
        cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
        cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
};

template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
    // Smem
    using SmemLayoutAtom = decltype(cute::composition(cute::Swizzle<3, 3, 3>{},
        cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
    using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::bfloat16_t>;

    // Gmem
    using GmemTiledCopy = decltype(cute::make_tiled_copy(
        cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
        cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
        cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
};

// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands

// Operand B - Column-Major (K-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
    : DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};

template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
    : DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};

// Operand B - Row-Major (N-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
    : DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};

template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
    : DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};

//
// F16: 128-by-128-by-32 (small k-block)
//

/// Operand A - Row-major (K-Major)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 32>
{
    // Smem
    using SmemLayoutAtom = decltype(cute::composition(cute::Swizzle<2, 3, 3>{},
        cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
    using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;

    // Gmem
    using GmemTiledCopy = decltype(cute::make_tiled_copy(
        cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
        cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
        cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};

template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 32>
{
    // Smem
    using SmemLayoutAtom = decltype(cute::composition(cute::Swizzle<2, 3, 3>{},
        cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
    using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;

    // Gmem
    using GmemTiledCopy = decltype(cute::make_tiled_copy(
        cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
        cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
        cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};

template <typename To_type, typename Engine, typename Layout>
CUTE_DEVICE auto util_convert_type(cute::Tensor<Engine, Layout> const& tensor)
{
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(cute::size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<cutlass::Array<From_type, numel> const*>(tensor.data()));
    return cute::make_tensor(cute::make_rmem_ptr<To_type>(&frag), tensor.layout());
}

template <typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
CUTE_DEVICE void util_copy(
    TiledCopy const& tiled_copy, cute::Tensor<Engine0, Layout0> const& S, cute::Tensor<Engine1, Layout1>& D)
{
    CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{});
    CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{});
    CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D));
    CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D));
    CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D));

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < cute::size<1>(S); ++m)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int k = 0; k < cute::size<2>(S); ++k)
        {
            cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k));
        }
    }
}
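For intuition, a plain host-side sketch of what util_convert_type does in registers (not from the deleted file, and not the CUTLASS API): treat a contiguous fragment as a flat array, convert every element, and hand back a fragment of the destination type. std::array and static_cast are illustrative stand-ins for cute::Tensor and NumericArrayConverter, which perform the same elementwise conversion with vectorized instructions where possible.

#include <array>
#include <cstdio>

template <typename To, typename From, int N>
std::array<To, N> convert_fragment(const std::array<From, N>& src)
{
    std::array<To, N> dst{};
    for (int i = 0; i < N; ++i)
    {
        dst[i] = static_cast<To>(src[i]); // elementwise conversion over the contiguous fragment
    }
    return dst;
}

int main()
{
    std::array<float, 8> acc{0.5f, 1.25f, -2.0f, 3.75f, 4.0f, -5.5f, 6.125f, 7.0f};
    auto out = convert_fragment<double, float, 8>(acc); // e.g. converting between accumulator and output types
    std::printf("out[3] = %f\n", out[3]);
    return 0;
}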
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{

/////////////////////////////////////////////////////////////////////////////////////////////////
// This section exists so that we can use the same kernel code for regular gemm and dequantizing gemms.
// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global.
template <typename...>
using void_t = void;

template <typename Mma, typename = void>
struct use_dq_gemm : platform::false_type
{
};

template <typename Mma>
struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : platform::true_type
{
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Mma_,                 ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,                  ///! Epilogue
    typename ThreadblockSwizzle_,        ///! Threadblock swizzling function
    typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
                         /// arch.
    GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform
    >
struct MoeFCGemm
{
public:
    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueOutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
    static bool const kTransposed = false;

    // Optional transpose
    using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element,
        typename Mma::IteratorA::Layout, Mma::kTransformA, Mma::IteratorA::AccessType::kElements,
        typename Mma::IteratorB::Element, typename Mma::IteratorB::Layout, Mma::kTransformB,
        Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC, kTransposed>;

    // Public-facing type definitions related to operand element type, layout, and complex conjugate
    // operation. Must interact with the 'kTransposed' notion.
    static_assert(!kTransposed, "Transpose problem not supported");
    using ElementA = typename MapArguments::ElementA;
    using LayoutA = typename MapArguments::LayoutA;
    using ElementB = typename MapArguments::ElementB;
    using LayoutB = typename MapArguments::LayoutB;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename MapArguments::LayoutC;
    using ElementScale = ElementC;

    static ComplexTransform const kTransformA = MapArguments::kTransformA;
    static ComplexTransform const kTransformB = MapArguments::kTransformB;

    // Type definitions about the mainloop.
    using Operator = typename Mma::Operator;
    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = MapArguments::kAlignmentA;
    static int const kAlignmentB = MapArguments::kAlignmentB;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    using ProblemVisitor
        = GemmMoeProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;

    //
    // Structures
    //

    /// Argument structure
    struct Arguments
    {
        //
        // Data members
        //
        int problem_count;
        int threadblock_count;
        int group_size;

        typename EpilogueOutputOp::Params output_op;

        ElementA* ptr_A;
        ElementB* ptr_B;
        ElementScale* weight_scales;
        ElementC* ptr_C;
        ElementC* ptr_D;
        bool C_is_broadcast;

        int64_t const* total_tokens_including_expert;

        int64_t gemm_n;
        int64_t gemm_k;

        // Only used by device-level operator
        GemmCoord* host_problem_sizes;

        //
        // Methods
        //

        /// Default ctor
        CUTLASS_HOST_DEVICE
        Arguments()
            : problem_count(0)
            , threadblock_count(0)
            , ptr_A(nullptr)
            , ptr_B(nullptr)
            , weight_scales(nullptr)
            , ptr_C(nullptr)
            , ptr_D(nullptr)
            , total_tokens_including_expert(nullptr)
            , gemm_n(0)
            , gemm_k(0)
            , host_problem_sizes(nullptr)
            , C_is_broadcast{true}
        {
        }

        /// Ctor
        CUTLASS_HOST_DEVICE
        Arguments(int problem_count, int threadblock_count, int group_size,
            typename EpilogueOutputOp::Params output_op, ElementA const* ptr_A, ElementB const* ptr_B,
            ElementScale const* weight_scales, ElementC const* ptr_C, bool C_is_broadcast, ElementC* ptr_D,
            int64_t const* total_tokens_including_expert, int64_t gemm_n, int64_t gemm_k,
            GemmCoord* host_problem_sizes = nullptr)
            : problem_count(problem_count)
            , threadblock_count(threadblock_count)
            , group_size(group_size)
            , output_op(output_op)
            , ptr_A(const_cast<ElementA*>(ptr_A))
            , ptr_B(const_cast<ElementB*>(ptr_B))
            , weight_scales(const_cast<ElementScale*>(weight_scales))
            , ptr_C(const_cast<ElementC*>(ptr_C))
            , C_is_broadcast{C_is_broadcast}
            , ptr_D(ptr_D)
            , total_tokens_including_expert(total_tokens_including_expert)
            , gemm_n(gemm_n)
            , gemm_k(gemm_k)
            , host_problem_sizes(nullptr)
        {
            if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
            {
                assert(weight_scales);
            }
        }
    };

    //
    // Structure for precomputing values in host memory and passing to kernels
    //

    /// Parameters structure
    struct Params
    {
        typename ProblemVisitor::Params problem_visitor;
        int threadblock_count;
        int group_size;
        bool C_is_broadcast;

        typename EpilogueOutputOp::Params output_op;

        ElementA* ptr_A;
        ElementB* ptr_B;
        ElementScale* weight_scales;
        ElementC* ptr_C;
        ElementC* ptr_D;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Params()
            : ptr_A(nullptr)
            , ptr_B(nullptr)
            , weight_scales(nullptr)
            , ptr_C(nullptr)
            , ptr_D(nullptr)
            , C_is_broadcast(true)
        {
        }

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
            : problem_visitor(args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count,
                workspace, tile_count)
            , threadblock_count(args.threadblock_count)
            , group_size(args.group_size)
            , output_op(args.output_op)
            , ptr_A(args.ptr_A)
            , ptr_B(args.ptr_B)
            , weight_scales(args.weight_scales)
            , ptr_C(args.ptr_C)
            , ptr_D(args.ptr_D)
            , C_is_broadcast(args.C_is_broadcast)
        {
        }

        CUTLASS_HOST_DEVICE
        void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
        {
            problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n,
                args.gemm_k, args.problem_count, workspace, tile_count);
            threadblock_count = args.threadblock_count;
            output_op = args.output_op;
            ptr_A = args.ptr_A;
            ptr_B = args.ptr_B;
            weight_scales = args.weight_scales;
            ptr_C = args.ptr_C;
            ptr_D = args.ptr_D;
            C_is_broadcast = args.C_is_broadcast;
        }
    };

    /// Shared memory storage structure
    union SharedStorage
    {
        typename ProblemVisitor::SharedStorage problem_visitor;
        typename Mma::SharedStorage main_loop;
        typename Epilogue::SharedStorage epilogue;
    };

public:
    //
    // Methods
    //

    CUTLASS_DEVICE
    MoeFCGemm() {}

    /// Determines whether kernel satisfies alignment
    static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
    {
        return Status::kSuccess;
    }

    static Status can_implement(Arguments const& args)
    {
        if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
        {
            if (args.weight_scales == nullptr)
            {
                CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t");
                return Status::kInvalid;
            }
        }
        else if (args.weight_scales != nullptr)
        {
            CUTLASS_TRACE_HOST(
                "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t");
            return Status::kInvalid;
        }
        else if (args.group_size != args.gemm_k)
        {
            CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)");
            return Status::kInvalid;
        }
        // Handle the case where the input is too short
        else if (args.gemm_n < Mma::IteratorB::AccessType::kElements)
        {
            CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment");
            return Status::kInvalid;
        }
        return Status::kSuccess;
    }

    static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
    {
        return 0;
    }

    CUTLASS_DEVICE
    void run_kernel_(Params const& params, SharedStorage& shared_storage)
    {
        //
        // These types shadow the type-level definitions and support the ability to implement
        // a 'transposed' GEMM that computes the transposed problems.
        //
        using ElementA = typename Mma::IteratorA::Element;
        using LayoutA = typename Mma::IteratorA::Layout;
        using ElementB = typename Mma::IteratorB::Element;
        using LayoutB = typename Mma::IteratorB::Layout;
        using ElementC = typename Epilogue::OutputTileIterator::Element;
        using LayoutC = typename Epilogue::OutputTileIterator::Layout;
        static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
        static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
                || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
            "B must be row major/col major OR col major interleaved.");

        //
        // Problem visitor.
        //
        ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);

        const int64_t gemm_k = params.problem_visitor.gemm_k;
        const int64_t gemm_n = params.problem_visitor.gemm_n;
        int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;

        // Outer 'persistent' loop to iterate over tiles
        int loop = 0;
        while (problem_visitor.next_tile())
        {
            loop++;

            GemmCoord problem_size = problem_visitor.problem_size();
            int32_t problem_idx = problem_visitor.problem_index();
            int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());

            GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

            cutlass::gemm::GemmCoord threadblock_offset(
                int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);

            // Load element pointers. Exchange pointers and strides if working on the transpose
            const int64_t rows_to_jump
                = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
            ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
            typename LayoutA::LongIndex ldm_A = gemm_k;

            char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
            ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
            typename LayoutB::LongIndex ldm_B
                = platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;

            // Compute initial location in logical coordinates
            cutlass::MatrixCoord tb_offset_A{
                threadblock_offset.m(),
                0,
            };

            cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};

            cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};

            // Compute position within threadblock
            int thread_idx = threadIdx.x;

            // Construct iterators to A and B operands
            typename Mma::IteratorA iterator_A(
                LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);

            typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
                {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);

            typename Mma::FragmentC accumulators;

            accumulators.clear();

            // Broadcast the warp_id computed by lane 0 to ensure dependent code
            // is compiled as warp-uniform.
            int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

            int lane_idx = threadIdx.x % 32;

            //
            // Matrix multiply phase
            //

            // Construct thread-scoped matrix multiply
            auto CreateMMA = [&]()
            {
                if constexpr (use_dq_gemm<Mma>::value)
                    return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
                else
                    return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
            };
            Mma mma = CreateMMA();

            // Compute threadblock-scoped matrix multiply-add
            int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

            // Wait for all threads to finish their epilogue phases from the previous tile.
            __syncthreads();

            // Compute threadblock-scoped matrix multiply-add
            ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();

            if constexpr (use_dq_gemm<Mma>::value)
            {
                const MatrixCoord scale_extent = {1, problem_size.n()};
                typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
                    weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);

                mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
            }
            else
            {
                mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
            }

            //
            // Epilogue
            //

            ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C)
                + (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n;
            ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;

            // lora need to set as layout_C(gemm_n)
            LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n);
            LayoutC layout_D(gemm_n);

            typename Epilogue::OutputTileIterator::Params params_C(layout_C);
            typename Epilogue::OutputTileIterator::Params params_D(layout_D);

            // Tile iterator loading from source tensor.
            typename Epilogue::OutputTileIterator iterator_C(
                params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());

            // Tile iterator writing to destination tensor.
            typename Epilogue::OutputTileIterator iterator_D(
                params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());

            Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

            // Execute the epilogue operator to update the destination tensor.
            if constexpr (platform::is_same<EpilogueOutputOp,
                              cutlass::epilogue::thread::LinearCombination<typename EpilogueOutputOp::ElementOutput,
                                  EpilogueOutputOp::kCount, typename EpilogueOutputOp::ElementAccumulator,
                                  typename EpilogueOutputOp::ElementCompute, EpilogueOutputOp::kScale,
                                  EpilogueOutputOp::kRound>>::value)
            {
                EpilogueOutputOp output_op(params.output_op, problem_idx);
                epilogue(output_op, iterator_D, accumulators, iterator_C);
            }
            else
            {
                EpilogueOutputOp output_op(params.output_op);
                epilogue(output_op, iterator_D, accumulators, iterator_C);
            }

            // Next tile
            problem_visitor.advance(gridDim.x);
        }
    }

    template <typename CompilationArch>
    CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
    {
        if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
        {
            run_kernel_(params, shared_storage);
        }
        else
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    }

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
    */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
        run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900)
        constexpr bool isFp8 = platform::is_same<ElementA, cutlass::float_e4m3_t>::value
            || platform::is_same<ElementA, cutlass::float_e5m2_t>::value;
        if constexpr (isFp8)
        {
            run_kernel<arch::Sm89>(params, shared_storage);
        }
        else
        {
            // reuse sm80 kernel for other types, align with dispatchToArch
            run_kernel<arch::Sm80>(params, shared_storage);
        }
#elif (__CUDA_ARCH__ >= 900)
        run_kernel<arch::Sm80>(params, shared_storage);
#else
        static_assert(
            false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
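A host-side sketch of the per-expert pointer arithmetic that run_kernel_ performs above (not part of the deleted kernel; the token counts, gemm_n/gemm_k, and weight bit width are invented). last_row_for_problem stands in for params.problem_visitor.last_row_for_problem, an inclusive prefix sum of rows per expert, and the B offset mirrors bytes_per_expert_matrix for 4-bit weights:

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    const int64_t gemm_n = 4096, gemm_k = 1024;
    const int bits_B = 4;                                          // uint4b_t weights
    std::vector<int64_t> last_row_for_problem = {32, 32, 80, 128}; // experts 0..3; expert 1 received no tokens

    const int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * bits_B;

    for (size_t expert = 0; expert < last_row_for_problem.size(); ++expert)
    {
        int64_t rows_to_jump = expert == 0 ? 0 : last_row_for_problem[expert - 1];
        int64_t rows = last_row_for_problem[expert] - rows_to_jump;         // gemm_m for this expert
        int64_t a_elem_offset = rows_to_jump * gemm_k;                      // ptr_A advance, in elements
        int64_t b_byte_offset = int64_t(expert) * bytes_per_expert_matrix;  // ptr_B advance, in bytes
        int64_t d_elem_offset = rows_to_jump * gemm_n;                      // ptr_D advance, in elements
        std::printf("expert %zu: m=%lld, A+=%lld elems, B+=%lld bytes, D+=%lld elems\n", expert, (long long) rows,
            (long long) a_elem_offset, (long long) b_byte_offset, (long long) d_elem_offset);
    }
    return 0;
}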
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Base scheduler for grouped problems, using MoE
*/
#pragma once
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ProblemSizeHelper, typename ThreadblockShape_>
struct BaseMoeProblemVisitor
{
    using ThreadblockShape = ThreadblockShape_;

    struct ProblemInfo
    {
        static int32_t const kNoPrefetchEntry = -1;
        int32_t problem_idx;
        int32_t problem_start;

        CUTLASS_DEVICE
        ProblemInfo()
            : problem_idx(kNoPrefetchEntry)
            , problem_start(kNoPrefetchEntry)
        {
        }

        CUTLASS_DEVICE
        ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
            : problem_idx(problem_idx_)
            , problem_start(problem_start_)
        {
        }
    };

    struct Params
    {
        int64_t const* last_row_for_problem;
        int64_t gemm_n;
        int64_t gemm_k;
        int32_t problem_count;
        void const* workspace;
        int32_t tile_count;

        //
        // Methods
        //

        /// Ctor
        CUTLASS_HOST_DEVICE
        Params()
            : last_row_for_problem(nullptr)
            , gemm_n(0)
            , gemm_k(0)
            , problem_count(0)
            , workspace(nullptr)
            , tile_count(0)
        {
        }

        /// Ctor
        CUTLASS_HOST_DEVICE
        Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count,
            void const* workspace = nullptr, int32_t tile_count = 0)
            : last_row_for_problem(last_row_for_problem)
            , gemm_n(gemm_n)
            , gemm_k(gemm_k)
            , problem_count(problem_count)
            , workspace(workspace)
            , tile_count(tile_count)
        {
        }
    };

    Params const& params;
    int32_t tile_idx;
    int32_t problem_tile_start;
    int32_t problem_idx;

    //
    // Methods
    //
    CUTLASS_DEVICE
    BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
        : params(params_)
        , tile_idx(block_idx)
        , problem_tile_start(0)
        , problem_idx(0)
    {
    }

    /// Get the grid shape
    CUTLASS_HOST_DEVICE
    static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem)
    {
        return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
            ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1);
    }

    /// Gets the global tile index
    CUTLASS_HOST_DEVICE
    int32_t tile_index() const
    {
        return tile_idx;
    }

    /// Gets the index of the problem
    CUTLASS_HOST_DEVICE
    int32_t problem_index() const
    {
        return problem_idx;
    }

    CUTLASS_HOST_DEVICE
    int32_t threadblock_idx() const
    {
        return tile_idx - problem_tile_start;
    }

    CUTLASS_DEVICE
    void advance(int32_t grid_size)
    {
        tile_idx += grid_size;
    }

    CUTLASS_HOST_DEVICE
    static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem)
    {
        ProblemSizeHelper::possibly_transpose_problem(problem);
    }

    /// Returns the problem size for the current problem
    CUTLASS_HOST_DEVICE
    cutlass::gemm::GemmCoord problem_size() const
    {
        return problem_size(problem_idx);
    }

    CUTLASS_HOST_DEVICE
    cutlass::gemm::GemmCoord problem_size(int idx) const
    {
        const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
        const int64_t current_problem_row = params.last_row_for_problem[idx];
        const int64_t gemm_m = current_problem_row - prev_problem_row;
        GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k));
        ProblemSizeHelper::possibly_transpose_problem(problem);
        return problem;
    }

    CUTLASS_HOST_DEVICE
    static int32_t tile_count(cutlass::gemm::GemmCoord const& grid)
    {
        return ProblemSizeHelper::tile_count(grid);
    }

    static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count)
    {
        int32_t total_tiles = 0;
        for (int32_t i = 0; i < problem_count; ++i)
        {
            auto problem = host_problem_sizes_ptr[i];
            possibly_transpose_problem(problem);
            auto grid = grid_shape(problem);
            total_tiles += tile_count(grid);
        }

        return total_tiles;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ProblemSizeHelper, typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_,
    int PrefetchTileCount, int ThreadCount>
struct MoeProblemVisitor;

/////////////////////////////////////////////////////////////////////////////////////////////////
// ProblemVisitor that performs all scheduling on device
//
template <typename ProblemSizeHelper, typename ThreadblockShape, int PrefetchTileCount, int ThreadCount>
struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode::kDeviceOnly, PrefetchTileCount,
    ThreadCount> : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>
{
    using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
    using Params = typename Base::Params;
    static int const kThreadCount = ThreadCount;
    static bool const kRequiresPrecomputation = false;
    static int const kThreadsPerWarp = 32;

    struct SharedStorage
    {
    };

    // Final tile of the problem loaded by this thread. Each thread will hold
    // a separate value.
    int32_t problem_ending_tile;

    SharedStorage& shared_storage;

    //
    // Methods
    //
    CUTLASS_DEVICE
    MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
        : Base(params_, block_idx)
        , problem_ending_tile(0)
        , shared_storage(shared_storage_)
    {
        this->problem_idx = -1 * kThreadsPerWarp;
        this->problem_tile_start = 0;
    }

    CUTLASS_DEVICE
    bool next_tile()
    {
        // Check whether the tile to compute is within the range of the current problem.
        int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
        if (this->tile_idx < problem_tile_end)
        {
            return true;
        }

        // Check whether the tile to compute is within the current group of problems fetched by the warp.
        // The last tile for this group is the final tile of the problem held by the final thread in the warp.
        int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);

        // Keep the starting problem for this group in `problem_idx`. This is done to reduce
        // register pressure. The starting problem for this group is simply the first problem
        // in the group most recently fetched by the warp.
        int32_t& group_problem_start = this->problem_idx;
        group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;

        // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
        // register pressure.
        int32_t& group_tile_start = this->problem_tile_start;

        // Each thread in the warp processes a separate problem to advance until
        // reaching a problem whose starting tile is less than tile_idx.
        while (group_tile_end <= this->tile_idx)
        {
            group_problem_start += kThreadsPerWarp;
            if (group_problem_start > this->params.problem_count)
            {
                return false;
            }

            // Since `group_tile_start` is a reference to `this->problem_tile_start`, this
            // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
            // is also set here is used later in `next_tile`.
            group_tile_start = group_tile_end;

            int lane_idx = threadIdx.x % kThreadsPerWarp;
            int32_t lane_problem = group_problem_start + lane_idx;

            // Compute the number of tiles in the problem assigned to each thread.
            problem_ending_tile = 0;
            if (lane_problem < this->params.problem_count)
            {
                cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
                cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
                problem_ending_tile = this->tile_count(grid);
            }

            // Compute a warp-wide inclusive prefix sum to compute the ending tile index of
            // each thread's problem.
            CUTLASS_PRAGMA_UNROLL
            for (int i = 1; i < kThreadsPerWarp; i <<= 1)
            {
                int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
                if (lane_idx >= i)
                {
                    problem_ending_tile += val;
                }
            }

            // The total tile count for this group is now in the final position of the prefix sum
            int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);

            problem_ending_tile += group_tile_start;
            group_tile_end += tiles_in_group;
        }

        // The next problem to process is the first one that does not have ending tile position
        // that is greater than or equal to tile index.
        int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));

        this->problem_idx = group_problem_start + problem_idx_in_group;

        // The starting tile for this problem is the ending tile of the previous problem. In cases
        // where `problem_idx_in_group` is the first problem in the group, we do not need to reset
        // `problem_tile_start`, because it is set to the previous group's ending tile in the while
        // loop above.
        if (problem_idx_in_group > 0)
        {
            this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
        }

        return true;
    }

    static size_t get_workspace_size(
        cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
    {
        return 0;
    }

    static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count,
        int32_t block_count, void* host_workspace_ptr)
    {
    }
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass
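A host-side sketch of the mapping next_tile() computes above with warp shuffles, a ballot, and an inclusive prefix sum: given per-problem tile counts, find which problem a global tile index belongs to and its tile offset inside that problem. This is not from the deleted visitor, and the tile counts are invented for illustration; the sequential search replaces the __popc(__ballot_sync(...)) step.

#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> tiles_per_problem = {6, 0, 10, 4}; // grid_shape.m() * grid_shape.n() per expert

    // Inclusive prefix sum of tile counts, i.e. the per-lane problem_ending_tile after the warp scan.
    std::vector<int> ending_tile(tiles_per_problem.size());
    int running = 0;
    for (size_t i = 0; i < tiles_per_problem.size(); ++i)
    {
        running += tiles_per_problem[i];
        ending_tile[i] = running;
    }

    for (int tile_idx = 0; tile_idx < running; ++tile_idx)
    {
        // Equivalent of __popc(__ballot_sync(..., problem_ending_tile <= tile_idx)): count the problems
        // whose ending tile is at or before this tile.
        int problem_idx = 0;
        while (ending_tile[problem_idx] <= tile_idx)
        {
            ++problem_idx;
        }
        int problem_tile_start = problem_idx == 0 ? 0 : ending_tile[problem_idx - 1];
        int threadblock_idx = tile_idx - problem_tile_start; // tile offset inside this problem's grid
        std::printf("tile %2d -> problem %d, threadblock_idx %d\n", tile_idx, problem_idx, threadblock_idx);
    }
    return 0;
}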
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp
deleted 100644 → 0
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
///////////////////////////////////////////////////////////////////////////////
namespace
cutlass
::
gemm
::
kernel
{
///////////////////////////////////////////////////////////////////////////////
template
<
class
ProblemShape_
,
class
CollectiveMainloop_
,
class
CollectiveEpilogue_
,
class
TileScheduler_
>
class
GemmUniversalGated
<
ProblemShape_
,
CollectiveMainloop_
,
CollectiveEpilogue_
,
TileScheduler_
,
cute
::
enable_if_t
<
cute
::
is_base_of_v
<
KernelTmaWarpSpecializedCooperative
,
typename
CollectiveMainloop_
::
DispatchPolicy
::
Schedule
>
&&
CollectiveMainloop_
::
isGated
>>
{
public:
//
// Type Aliases
//
using
ProblemShape
=
ProblemShape_
;
static_assert
(
cute
::
rank
(
ProblemShape
{})
==
3
or
cute
::
rank
(
ProblemShape
{})
==
4
,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>"
);
// Mainloop derived types
using
CollectiveMainloop
=
CollectiveMainloop_
;
using
TileShape
=
typename
CollectiveMainloop
::
TileShape
;
using
TiledMma
=
typename
CollectiveMainloop
::
TiledMma
;
using
ArchTag
=
typename
CollectiveMainloop
::
ArchTag
;
using
ElementA
=
typename
CollectiveMainloop
::
ElementA
;
using
StrideA
=
typename
CollectiveMainloop
::
StrideA
;
using
ElementB
=
typename
CollectiveMainloop
::
ElementB
;
using
StrideB
=
typename
CollectiveMainloop
::
StrideB
;
using
DispatchPolicy
=
typename
CollectiveMainloop
::
DispatchPolicy
;
using
ElementAccumulator
=
typename
CollectiveMainloop
::
ElementAccumulator
;
using
ClusterShape
=
typename
DispatchPolicy
::
ClusterShape
;
using
MainloopArguments
=
typename
CollectiveMainloop
::
Arguments
;
using
MainloopParams
=
typename
CollectiveMainloop
::
Params
;
using
Activation
=
typename
CollectiveMainloop
::
Activation
;
// Epilogue derived types
using
CollectiveEpilogue
=
CollectiveEpilogue_
;
using
ElementC
=
typename
CollectiveEpilogue
::
ElementC
;
using
StrideC
=
typename
CollectiveEpilogue
::
StrideC
;
using
ElementD
=
typename
CollectiveEpilogue
::
ElementD
;
using
StrideD
=
typename
CollectiveEpilogue
::
StrideD
;
using
EpilogueArguments
=
typename
CollectiveEpilogue
::
Arguments
;
using
EpilogueParams
=
typename
CollectiveEpilogue
::
Params
;
static_assert
(
ArchTag
::
kMinComputeCapability
>=
90
);
using
TileSchedulerTag
=
TileScheduler_
;
using
TileScheduler
=
typename
detail
::
TileSchedulerSelector
<
TileScheduler_
,
ArchTag
,
TileShape
,
ClusterShape
>::
Scheduler
;
using
TileSchedulerArguments
=
typename
TileScheduler
::
Arguments
;
using
TileSchedulerParams
=
typename
TileScheduler
::
Params
;
static
constexpr
uint32_t
NumLoadWarpGroups
=
1
;
static
constexpr
uint32_t
NumMmaWarpGroups
=
CUTE_STATIC_V
(
size
(
TiledMma
{}))
/
NumThreadsPerWarpGroup
;
static
constexpr
uint32_t
MaxThreadsPerBlock
=
CUTE_STATIC_V
(
size
(
TiledMma
{}))
+
(
NumLoadWarpGroups
*
NumThreadsPerWarpGroup
);
static
constexpr
uint32_t
MinBlocksPerMultiprocessor
=
1
;
/// Register requirement for Load and Math WGs
static
constexpr
uint32_t
LoadRegisterRequirement
=
40
;
static
constexpr
uint32_t
MmaRegisterRequirement
=
232
;
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using
LoadWarpOrderBarrier
=
cutlass
::
OrderedSequenceBarrier
<
1
,
2
>
;
// Kernel level shared memory storage
struct
SharedStorage
{
struct
TensorStorage
:
cute
::
aligned_struct
<
128
>
{
using
MainloopTensorStorage
=
typename
CollectiveMainloop
::
TensorStorage
;
using
EpilogueTensorStorage
=
typename
CollectiveEpilogue
::
TensorStorage
;
MainloopTensorStorage
mainloop
;
EpilogueTensorStorage
epilogue
;
}
tensors
;
struct
PipelineStorage
:
cute
::
aligned_struct
<
16
>
{
using
MainloopPipelineStorage
=
typename
CollectiveMainloop
::
PipelineStorage
;
using
EpiLoadPipelineStorage
=
typename
CollectiveEpilogue
::
PipelineStorage
;
alignas
(
16
)
MainloopPipelineStorage
mainloop
;
alignas
(
16
)
EpiLoadPipelineStorage
epi_load
;
alignas
(
16
)
typename
LoadWarpOrderBarrier
::
SharedStorage
load_order
;
}
pipelines
;
};
static
constexpr
int
SharedStorageSize
=
sizeof
(
SharedStorage
);
// Device side arguments
struct
Arguments
{
GemmUniversalMode
mode
{};
ProblemShape
problem_shape
{};
MainloopArguments
mainloop
{};
EpilogueArguments
epilogue
{};
KernelHardwareInfo
hw_info
{};
TileSchedulerArguments
scheduler
{};
};
// Kernel entry point API
struct
Params
{
GemmUniversalMode
mode
{};
ProblemShape
problem_shape
{};
MainloopParams
mainloop
{};
EpilogueParams
epilogue
{};
KernelHardwareInfo
hw_info
{};
TileSchedulerParams
scheduler
{};
void
*
workspace
{
nullptr
};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
static
Params
to_underlying_arguments
(
Arguments
const
&
args
,
void
*
workspace
)
{
CUTLASS_TRACE_HOST
(
"to_underlying_arguments():"
);
auto
problem_shape
=
args
.
problem_shape
;
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
// // swap M/N
// get<0>(problem_shape) = get<1>(args.problem_shape);
// get<1>(problem_shape) = get<0>(args.problem_shape);
// }
auto
problem_shape_MNKL
=
append
<
4
>
(
problem_shape
,
1
);
// Get SM count if needed, otherwise use user supplied SM count
int
sm_count
=
args
.
hw_info
.
sm_count
;
if
(
sm_count
<=
0
)
{
CUTLASS_TRACE_HOST
(
" WARNING: Arguments do not include a valid SM count.
\n
"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."
);
sm_count
=
KernelHardwareInfo
::
query_device_multiprocessor_count
(
args
.
hw_info
.
device_id
);
}
CUTLASS_TRACE_HOST
(
"to_underlying_arguments(): Setting persistent grid SM count to "
<<
sm_count
);
KernelHardwareInfo
hw_info
{
args
.
hw_info
.
device_id
,
sm_count
};
// Calculate workspace pointers
uint8_t
*
workspace_ptr
=
reinterpret_cast
<
uint8_t
*>
(
workspace
);
size_t
workspace_offset
=
0
;
void
*
scheduler_workspace
=
workspace_ptr
;
workspace_offset
+=
TileScheduler
::
template
get_workspace_size
<
ProblemShape
,
ElementAccumulator
>(
args
.
scheduler
,
args
.
problem_shape
,
args
.
hw_info
,
NumMmaWarpGroups
);
workspace_offset
=
round_nearest
(
workspace_offset
,
MinWorkspaceAlignment
);
void
*
epilogue_workspace
=
workspace_ptr
+
workspace_offset
;
workspace_offset
+=
CollectiveEpilogue
::
get_workspace_size
(
args
.
problem_shape
,
args
.
epilogue
);
workspace_offset
=
round_nearest
(
workspace_offset
,
MinWorkspaceAlignment
);
void
*
mainloop_workspace
=
nullptr
;
// Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used
// in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means
// subtile will not be used, therefore separate reduction will not be enabled.
constexpr
uint32_t
NumEpilogueSubTiles
=
CollectiveEpilogue
::
get_store_pipe_increment
(
TileShape
{});
TileSchedulerParams
scheduler
=
TileScheduler
::
to_underlying_arguments
(
problem_shape_MNKL
,
TileShape
{},
ClusterShape
{},
hw_info
,
args
.
scheduler
,
scheduler_workspace
,
NumEpilogueSubTiles
);
return
{
args
.
mode
,
problem_shape
,
CollectiveMainloop
::
to_underlying_arguments
(
args
.
problem_shape
,
args
.
mainloop
,
mainloop_workspace
),
CollectiveEpilogue
::
to_underlying_arguments
(
args
.
problem_shape
,
args
.
epilogue
,
epilogue_workspace
),
hw_info
,
scheduler
,
workspace
};
}
static
bool
can_implement
(
Arguments
const
&
args
)
{
bool
implementable
=
(
args
.
mode
==
GemmUniversalMode
::
kGemm
)
or
(
args
.
mode
==
GemmUniversalMode
::
kBatched
&&
cute
::
rank
(
ProblemShape
{})
==
4
);
if
(
!
implementable
)
{
CUTLASS_TRACE_HOST
(
" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.
\n
"
);
return
implementable
;
}
implementable
&=
CollectiveMainloop
::
can_implement
(
args
.
problem_shape
,
args
.
mainloop
);
implementable
&=
CollectiveEpilogue
::
can_implement
(
args
.
problem_shape
,
args
.
epilogue
);
implementable
&=
TileScheduler
::
can_implement
(
args
.
scheduler
);
return
implementable
;
}
static
size_t
get_workspace_size
(
Arguments
const
&
args
)
{
size_t
workspace_size
=
0
;
constexpr
uint32_t
NumEpilogueSubTiles
=
CollectiveEpilogue
::
get_store_pipe_increment
(
TileShape
{});
workspace_size
+=
TileScheduler
::
template
get_workspace_size
<
ProblemShape
,
ElementAccumulator
>(
args
.
scheduler
,
args
.
problem_shape
,
args
.
hw_info
,
NumMmaWarpGroups
,
NumEpilogueSubTiles
);
workspace_size
=
round_nearest
(
workspace_size
,
MinWorkspaceAlignment
);
workspace_size
+=
CollectiveEpilogue
::
get_workspace_size
(
args
.
problem_shape
,
args
.
epilogue
);
workspace_size
=
round_nearest
(
workspace_size
,
MinWorkspaceAlignment
);
return
workspace_size
;
}
static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr,
    cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr)
{
    Status status = Status::kSuccess;
    uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
    size_t workspace_offset = 0;
    constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});

    status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
        workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups,
        NumEpilogueSubTiles);
    workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
        args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
    workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
    if (status != Status::kSuccess)
    {
        return status;
    }

    status = CollectiveEpilogue::initialize_workspace(
        args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
    workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
    workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
    if (status != Status::kSuccess)
    {
        return status;
    }

    return status;
}
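For reference, the three host-side helpers above always carve the single workspace allocation the same way: the tile-scheduler region comes first, then the epilogue region, with each region rounded up to MinWorkspaceAlignment so that get_workspace_size() and initialize_workspace() compute identical offsets. A minimal stand-alone sketch of that convention follows; the helper names and the 16-byte alignment value are placeholders chosen for illustration, not taken from the CUTLASS headers.

#include <cstddef>

// Hypothetical stand-in for cutlass::MinWorkspaceAlignment / round_nearest().
constexpr std::size_t kMinWorkspaceAlignment = 16;

constexpr std::size_t round_up(std::size_t x, std::size_t align)
{
    return (x + align - 1) / align * align;
}

// Mirrors the layout used above: [ scheduler region | pad | epilogue region | pad ].
struct WorkspaceLayout
{
    std::size_t scheduler_offset;
    std::size_t epilogue_offset;
    std::size_t total_bytes;
};

inline WorkspaceLayout make_workspace_layout(std::size_t scheduler_bytes, std::size_t epilogue_bytes)
{
    WorkspaceLayout layout{};
    layout.scheduler_offset = 0;
    std::size_t offset = round_up(scheduler_bytes, kMinWorkspaceAlignment);
    layout.epilogue_offset = offset;
    offset += round_up(epilogue_bytes, kMinWorkspaceAlignment);
    layout.total_bytes = offset;
    return layout;
}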
// Computes the kernel launch grid shape based on runtime parameters
static dim3 get_grid_shape(Params const& params)
{
    // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
    TileSchedulerArguments args{};
    if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>)
    {
        args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
    }
    args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
        ? TileScheduler::RasterOrderOptions::AlongN
        : TileScheduler::RasterOrderOptions::AlongM;
    return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
}

static dim3 get_block_shape()
{
    return dim3(MaxThreadsPerBlock, 1, 1);
}
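get_grid_shape()/get_block_shape() only size the launch; the kernel body is the operator()(Params, char*) defined next. As a rough sketch of how the pieces fit together, the hypothetical wrapper below assumes a 1x1x1 cluster and passes Params by value; real CUTLASS builds launch SM90 kernels through their own cluster-aware launch utilities (and typically mark the parameter __grid_constant__), so treat this purely as an illustration of the grid/block/shared-memory plumbing.

#include <cuda_runtime.h>

// Hypothetical __global__ wrapper around the device-side kernel functor.
template <class GemmKernel>
__global__ void gated_gemm_wrapper(typename GemmKernel::Params params)
{
    extern __shared__ char smem[];
    GemmKernel op;
    op(params, smem); // matches operator()(Params const&, char* smem_buf)
}

template <class GemmKernel>
void launch_gated_gemm(typename GemmKernel::Params const& params, cudaStream_t stream)
{
    dim3 grid = GemmKernel::get_grid_shape(params);
    dim3 block = GemmKernel::get_block_shape();
    int smem_bytes = GemmKernel::SharedStorageSize;
    // Dynamic shared memory above 48 KiB must be opted into explicitly.
    cudaFuncSetAttribute(gated_gemm_wrapper<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    gated_gemm_wrapper<GemmKernel><<<grid, block, smem_bytes, stream>>>(params);
}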
CUTLASS_DEVICE
void operator()(Params const& params, char* smem_buf)
{
    using namespace cute;
    using X = Underscore;

// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
    printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else

    // Preconditions
    static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
    static_assert(size<0>(TileShape{}) >= 128,
        "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
    static_assert(cute::rank(StrideA{}) == 3,
        "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
    static_assert(cute::rank(StrideB{}) == 3,
        "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
    static_assert(cute::rank(StrideC{}) == 3,
        "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
    static_assert(cute::rank(StrideD{}) == 3,
        "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

    /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
    enum class WarpGroupRole
    {
        Producer = 0,
        Consumer0 = 1,
        Consumer1 = 2
    };

    enum class ProducerWarpRole
    {
        Mainloop = 0,
        Warp1 = 1,
        Epilogue = 2,
        Warp3 = 3
    };

    // Kernel level shared memory storage
    SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

    int thread_idx = int(threadIdx.x);
    int lane_idx = canonical_lane_idx();
    int warp_idx = canonical_warp_idx_sync();
    int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
    int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
    int mma_thread_idx = thread_idx % size(TiledMma{});
    auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
    auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
    int lane_predicate = cute::elect_one_sync();
    uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();

    // Issue Tma Descriptor Prefetch from a single thread
    if ((warp_idx == 0) && lane_predicate)
    {
        CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
        CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
    }
    // Mainloop Load pipeline
    using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
    typename MainloopPipeline::Params mainloop_pipeline_params;
    if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
    {
        mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
    }
    if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
    {
        mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
    }
    mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
    mainloop_pipeline_params.num_consumers = size(TiledMma{});
    mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes;
    MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});

    // Epilogue Load pipeline
    using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
    typename EpiLoadPipeline::Params epi_load_pipeline_params;
    if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue)
    {
        epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
    }
    if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
    {
        epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
    }
    epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
    epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
    epi_load_pipeline_params.consumer_arv_count = size(TiledMma{});
    epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
    EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);

    // Epilogue Store pipeline
    using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
    typename EpiStorePipeline::Params epi_store_pipeline_params;
    epi_store_pipeline_params.always_wait = true;
    EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);

    typename LoadWarpOrderBarrier::Params params_load_order_barrier;
    params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
    params_load_order_barrier.group_size = NumThreadsPerWarp;
    LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
    // Initialize starting pipeline states for the collectives
    // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
    typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
    typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;

    // For the DMA Load (producer) we start with an opposite phase
    // i.e., we skip all waits since we know that the buffer is indeed empty
    PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
    PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
    PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();

    auto cluster_wait_fn = []()
    {
        // We need this to guarantee that the Pipeline init is visible
        // To all producers and consumer thread blocks in the Cluster
        if constexpr (size(ClusterShape{}) > 1)
        {
            cute::cluster_arrive_relaxed();
            return []() { cute::cluster_wait(); };
        }
        else
        {
            __syncthreads();
            return []() {}; // do nothing
        }
    }();

    // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
    auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});

    // Get the appropriate blocks for this thread block -- potential for thread block locality
    TiledMma tiled_mma;
    auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)

    TileScheduler scheduler{params.scheduler};
    auto work_tile_info = scheduler.get_current_work();

    // In a warp specialized kernel, collectives expose data movement and compute operations separately
    CollectiveMainloop collective_mainloop;
    CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);

    // Prepare and partition the input tensors. Expects a tuple of tensors where:
    // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
    // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
    auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
    static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 3,
        "Output of load_init must have at least three elements (A, B, Aux)");

    // Extract out partitioned A and B.
    Tensor gA_mkl = get<0>(load_inputs);
    Tensor gB_nkl = get<1>(load_inputs);
    Tensor gAux_xkl = get<2>(load_inputs);

    // Get pipeline stage increments from tensor shapes
    auto k_tile_count = size<3>(gA_mkl);

    // Wait for all thread blocks in the Cluster
    cluster_wait_fn();
    if (warp_group_role == WarpGroupRole::Producer)
    {
        cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();

        // Mainloop Producer Warp
        if (producer_warp_role == ProducerWarpRole::Mainloop)
        {
            bool do_load_order_arrive = true;
            while (work_tile_info.is_valid())
            {
                if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info))
                {
                    work_tile_info = fetch_next_work(work_tile_info, scheduler);
                    continue;
                }

                // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
                auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
                auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
                auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
                auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);

                // Get the number of K tiles to compute for this work as well as the starting K tile offset of the
                // work.
                auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
                auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
                auto k_tile_iter
                    = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));

                collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, load_inputs,
                    blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster,
                    shared_storage.tensors.mainloop);
                // Update starting pipeline state for the next tile
                mainloop_pipe_producer_state.advance(work_k_tile_count);

                // Signal for the epilogue load warp to begin
                if (do_load_order_arrive)
                {
                    load_order_barrier.arrive();
                    do_load_order_arrive = false;
                }

                // Get next work tile
                work_tile_info = fetch_next_work(work_tile_info, scheduler);
            } // Scheduler work fetch loop

            // Make sure all Consumer Warp Groups have been waited upon
            collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
        } // Mainloop Producer Warp End

        // Epilogue Producer Warp
        else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
        {
            while (work_tile_info.is_valid())
            {
                if (!TileScheduler::requires_separate_reduction(params.scheduler))
                {
                    load_order_barrier.wait();
                }
                if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler))
                {
                    // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
                    auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
                    auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
                    auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
                    auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);

                    epi_load_pipe_producer_state = collective_epilogue.load(epi_load_pipeline,
                        epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
                        shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx());
                }

                // Get next work tile
                work_tile_info = fetch_next_work(work_tile_info, scheduler);
            } // Scheduler work fetch loop

            // Make sure all Consumer Warp Groups have been waited upon
            collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
        } // Epilogue Producer Warp End
    } // Producer Warp Group End
    else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
    {
        cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();

        // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it
        bool do_store_tail = false;
        float scale_d0 = params.mainloop.scale_d0;
        float scale_d1 = params.mainloop.scale_d1;
        while (work_tile_info.is_valid())
        {
            // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
            auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
            auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
            auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
            auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
            auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);

            // Allocate the accumulators for the (M,N) blk_shape
            //
            // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
            auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
            auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
            if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info))
            {
                collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1,
                    work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop, params.mainloop);

                // Make sure the math instructions are done and free buffers before entering the epilogue
                collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count);

                // Update starting mainloop pipeline state for the next tile
                mainloop_pipe_consumer_state.advance(work_k_tile_count);
            }

            // Index of warp group within consumer warp groups
            int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups;

            // Perform reduction across splits, if needed
            TileScheduler::fixup(
                params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx);
            TileScheduler::fixup(
                params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx);

            Activation elt_op;
            CUTLASS_PRAGMA_UNROLL
            for (int i = 0; i < size(accumulators0); i++)
            {
                accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]);
            }
            if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler))
            {
                // Epilogue and write to gD
                auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next]
                    = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
                        epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
                        tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue,
                        work_tile_info.reduction_subtile_idx());
                epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
                epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
                do_store_tail = true;
            }

            // Get next work tile
            work_tile_info = fetch_next_work(work_tile_info, scheduler);
        } // Scheduler work fetch loop

        if (do_store_tail)
        {
            collective_epilogue.store_tail(
                epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state);
        }
    } // Consumer Warp Groups End
#endif
}
private:
    // Kernel helper function to get next work unit
    CUTLASS_DEVICE
    typename TileScheduler::WorkTileInfo fetch_next_work(
        typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const
    {
        // Check whether we should continue on with the current work unit. If this is the case,
        // the work unit will have been updated in continue_current_work to reflect the new
        // tile to be computed.
        if (scheduler.continue_current_work(work_tile_info))
        {
            return work_tile_info;
        }

        // Get next work tile
        scheduler.advance_to_next_work();
        return scheduler.get_current_work();
    }
};

///////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::kernel
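The elementwise loop in the consumer warp group section above is where the "gated" part of this kernel happens: the mainloop fills two accumulators per tile (one per weight operand, B and the auxiliary Aux tensor returned by load_init), and they are fused as acc0 * scale_d0 * Activation(scale_d1 * acc1) before a single epilogue writes D. The Activation functor is whatever CollectiveMainloop::Activation was instantiated with; the SiLU gate below is only an assumed example, written on plain floats to show the fusion in isolation.

#include <cmath>

// Hypothetical stand-in for CollectiveMainloop::Activation (a SiLU gate is assumed here).
struct SiLU
{
    float operator()(float x) const { return x / (1.0f + std::exp(-x)); }
};

// The elementwise fusion performed per accumulator element:
// d = (acc0 * scale_d0) * act(scale_d1 * acc1)
inline float fuse_gated(float acc0, float acc1, float scale_d0, float scale_d1)
{
    SiLU act;
    return (acc0 * scale_d0) * act(scale_d1 * acc1);
}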
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp
deleted
100644 → 0
View file @
9829e77e
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
#include "cute/tensor.hpp"
#include "cute/util/debug.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace
cutlass
::
gemm
::
kernel
{
///////////////////////////////////////////////////////////////////////////////
template
<
class
ProblemShape_
,
class
CollectiveMainloop_
,
class
CollectiveEpilogue_
,
class
TileScheduler_
>
class
GemmUniversalGated
<
ProblemShape_
,
CollectiveMainloop_
,
CollectiveEpilogue_
,
TileScheduler_
,
cute
::
enable_if_t
<
cute
::
is_base_of_v
<
KernelTmaWarpSpecializedPingpong
,
typename
CollectiveMainloop_
::
DispatchPolicy
::
Schedule
>
&&
CollectiveMainloop_
::
isGated
>>
{
public:
//
// Type Aliases
//
using
ProblemShape
=
ProblemShape_
;
static_assert
(
cute
::
rank
(
ProblemShape
{})
==
3
or
cute
::
rank
(
ProblemShape
{})
==
4
,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>"
);
// Mainloop derived types
using
CollectiveMainloop
=
CollectiveMainloop_
;
using
TileShape
=
typename
CollectiveMainloop
::
TileShape
;
using
TiledMma
=
typename
CollectiveMainloop
::
TiledMma
;
using
ArchTag
=
typename
CollectiveMainloop
::
ArchTag
;
using
ElementA
=
typename
CollectiveMainloop
::
ElementA
;
using
StrideA
=
typename
CollectiveMainloop
::
StrideA
;
using
ElementB
=
typename
CollectiveMainloop
::
ElementB
;
using
StrideB
=
typename
CollectiveMainloop
::
StrideB
;
using
DispatchPolicy
=
typename
CollectiveMainloop
::
DispatchPolicy
;
using
ElementAccumulator
=
typename
CollectiveMainloop
::
ElementAccumulator
;
using
ClusterShape
=
typename
DispatchPolicy
::
ClusterShape
;
using
MainloopArguments
=
typename
CollectiveMainloop
::
Arguments
;
using
MainloopParams
=
typename
CollectiveMainloop
::
Params
;
using
Activation
=
typename
CollectiveMainloop
::
Activation
;
static_assert
(
ArchTag
::
kMinComputeCapability
>=
90
);
// Epilogue derived types
using
CollectiveEpilogue
=
CollectiveEpilogue_
;
using
ElementC
=
typename
CollectiveEpilogue
::
ElementC
;
using
StrideC
=
typename
CollectiveEpilogue
::
StrideC
;
using
ElementD
=
typename
CollectiveEpilogue
::
ElementD
;
using
StrideD
=
typename
CollectiveEpilogue
::
StrideD
;
using
EpilogueArguments
=
typename
CollectiveEpilogue
::
Arguments
;
using
EpilogueParams
=
typename
CollectiveEpilogue
::
Params
;
static_assert
(
!
cute
::
is_same_v
<
TileScheduler_
,
StreamKScheduler
>
,
"Ping-pong kernel does not currently support stream-K scheduler."
);
using
TileSchedulerTag
=
TileScheduler_
;
using
TileScheduler
=
typename
detail
::
TileSchedulerSelector
<
TileScheduler_
,
ArchTag
,
TileShape
,
ClusterShape
>::
Scheduler
;
using
TileSchedulerArguments
=
typename
TileScheduler
::
Arguments
;
using
TileSchedulerParams
=
typename
TileScheduler
::
Params
;
static
constexpr
uint32_t
NumLoadWarpGroups
=
1
;
static
constexpr
uint32_t
NumMmaWarpGroups
=
2
;
static
constexpr
uint32_t
MaxThreadsPerBlock
=
CUTE_STATIC_V
(
size
(
TiledMma
{}))
+
(
NumMmaWarpGroups
*
NumThreadsPerWarpGroup
);
static
constexpr
uint32_t
MinBlocksPerMultiprocessor
=
1
;
/// Register requirement for Load and Math WGs
static
constexpr
uint32_t
LoadRegisterRequirement
=
40
;
static
constexpr
uint32_t
MmaRegisterRequirement
=
232
;
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using
LoadWarpOrderBarrier
=
cutlass
::
OrderedSequenceBarrier
<
1
,
2
>
;
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
static
constexpr
uint32_t
StagesPerMathWarpGroup
=
2
;
using
MathWarpGroupOrderBarrier
=
cutlass
::
OrderedSequenceBarrier
<
StagesPerMathWarpGroup
,
NumMmaWarpGroups
>
;
// Kernel level shared memory storage
struct
SharedStorage
{
struct
TensorStorage
:
cute
::
aligned_struct
<
128
>
{
using
MainloopTensorStorage
=
typename
CollectiveMainloop
::
TensorStorage
;
using
EpilogueTensorStorage
=
typename
CollectiveEpilogue
::
TensorStorage
;
MainloopTensorStorage
mainloop
;
EpilogueTensorStorage
epilogue
;
}
tensors
;
struct
PipelineStorage
:
cute
::
aligned_struct
<
16
>
{
using
MainloopPipelineStorage
=
typename
CollectiveMainloop
::
PipelineStorage
;
using
EpiLoadPipelineStorage
=
typename
CollectiveEpilogue
::
PipelineStorage
;
using
MathWarpGroupOrderBarrierStorage
=
typename
MathWarpGroupOrderBarrier
::
SharedStorage
;
alignas
(
16
)
MainloopPipelineStorage
mainloop
;
alignas
(
16
)
EpiLoadPipelineStorage
epi_load
;
alignas
(
16
)
MathWarpGroupOrderBarrierStorage
math_wg_order
;
alignas
(
16
)
typename
LoadWarpOrderBarrier
::
SharedStorage
load_order
;
}
pipelines
;
};
static
constexpr
int
SharedStorageSize
=
sizeof
(
SharedStorage
);
// Device side arguments
struct
Arguments
{
GemmUniversalMode
mode
{};
ProblemShape
problem_shape
{};
MainloopArguments
mainloop
{};
EpilogueArguments
epilogue
{};
KernelHardwareInfo
hw_info
{};
TileSchedulerArguments
scheduler
{};
};
// Kernel entry point API
struct
Params
{
GemmUniversalMode
mode
{};
ProblemShape
problem_shape
{};
MainloopParams
mainloop
{};
EpilogueParams
epilogue
{};
KernelHardwareInfo
hw_info
{};
TileSchedulerParams
scheduler
{};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
static
Params
to_underlying_arguments
(
Arguments
const
&
args
,
void
*
workspace
)
{
CUTLASS_TRACE_HOST
(
"to_underlying_arguments():"
);
(
void
)
workspace
;
auto
problem_shape
=
args
.
problem_shape
;
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
// // swap M/N
// get<0>(problem_shape) = get<1>(args.problem_shape);
// get<1>(problem_shape) = get<0>(args.problem_shape);
// }
auto
problem_shape_MNKL
=
append
<
4
>
(
problem_shape
,
1
);
// Get SM count if needed, otherwise use user supplied SM count
int
sm_count
=
args
.
hw_info
.
sm_count
;
if
(
sm_count
<=
0
)
{
CUTLASS_TRACE_HOST
(
" WARNING: Arguments do not include a valid SM count.
\n
"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."
);
sm_count
=
KernelHardwareInfo
::
query_device_multiprocessor_count
(
args
.
hw_info
.
device_id
);
}
CUTLASS_TRACE_HOST
(
"to_underlying_arguments(): Setting persistent grid SM count to "
<<
sm_count
);
KernelHardwareInfo
hw_info
{
args
.
hw_info
.
device_id
,
sm_count
};
// Calculate workspace pointers
uint8_t
*
workspace_ptr
=
reinterpret_cast
<
uint8_t
*>
(
workspace
);
size_t
workspace_offset
=
0
;
void
*
scheduler_workspace
=
workspace_ptr
;
workspace_offset
+=
TileScheduler
::
template
get_workspace_size
<
ProblemShape
,
ElementAccumulator
>(
args
.
scheduler
,
args
.
problem_shape
,
args
.
hw_info
,
NumMmaWarpGroups
);
workspace_offset
=
round_nearest
(
workspace_offset
,
MinWorkspaceAlignment
);
void
*
epilogue_workspace
=
workspace_ptr
+
workspace_offset
;
workspace_offset
+=
CollectiveEpilogue
::
get_workspace_size
(
args
.
problem_shape
,
args
.
epilogue
);
workspace_offset
=
round_nearest
(
workspace_offset
,
MinWorkspaceAlignment
);
void
*
mainloop_workspace
=
nullptr
;
return
{
args
.
mode
,
problem_shape
,
CollectiveMainloop
::
to_underlying_arguments
(
args
.
problem_shape
,
args
.
mainloop
,
mainloop_workspace
),
CollectiveEpilogue
::
to_underlying_arguments
(
args
.
problem_shape
,
args
.
epilogue
,
epilogue_workspace
),
hw_info
,
TileScheduler
::
to_underlying_arguments
(
problem_shape_MNKL
,
TileShape
{},
ClusterShape
{},
hw_info
,
args
.
scheduler
,
scheduler_workspace
)};
}
static
bool
can_implement
(
Arguments
const
&
args
)
{
bool
implementable
=
(
args
.
mode
==
GemmUniversalMode
::
kGemm
)
or
(
args
.
mode
==
GemmUniversalMode
::
kBatched
&&
cute
::
rank
(
ProblemShape
{})
==
4
);
if
(
!
implementable
)
{
CUTLASS_TRACE_HOST
(
" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.
\n
"
);
return
implementable
;
}
implementable
&=
CollectiveMainloop
::
can_implement
(
args
.
problem_shape
,
args
.
mainloop
);
implementable
&=
CollectiveEpilogue
::
can_implement
(
args
.
problem_shape
,
args
.
epilogue
);
implementable
&=
TileScheduler
::
can_implement
(
args
.
scheduler
);
return
implementable
;
}
static
size_t
get_workspace_size
(
Arguments
const
&
args
)
{
size_t
workspace_size
=
0
;
workspace_size
+=
TileScheduler
::
template
get_workspace_size
<
ProblemShape
,
ElementAccumulator
>(
args
.
scheduler
,
args
.
problem_shape
,
args
.
hw_info
,
NumMmaWarpGroups
);
workspace_size
=
round_nearest
(
workspace_size
,
MinWorkspaceAlignment
);
workspace_size
+=
CollectiveEpilogue
::
get_workspace_size
(
args
.
problem_shape
,
args
.
epilogue
);
workspace_size
=
round_nearest
(
workspace_size
,
MinWorkspaceAlignment
);
return
workspace_size
;
}
static
cutlass
::
Status
initialize_workspace
(
Arguments
const
&
args
,
void
*
workspace
=
nullptr
,
cudaStream_t
stream
=
nullptr
,
CudaHostAdapter
*
cuda_adapter
=
nullptr
)
{
Status
status
=
Status
::
kSuccess
;
uint8_t
*
workspace_ptr
=
reinterpret_cast
<
uint8_t
*>
(
workspace
);
size_t
workspace_offset
=
0
;
status
=
TileScheduler
::
template
initialize_workspace
<
ProblemShape
,
ElementAccumulator
>(
args
.
scheduler
,
workspace_ptr
+
workspace_offset
,
stream
,
args
.
problem_shape
,
args
.
hw_info
,
NumMmaWarpGroups
);
workspace_offset
+=
TileScheduler
::
template
get_workspace_size
<
ProblemShape
,
ElementAccumulator
>(
args
.
scheduler
,
args
.
problem_shape
,
args
.
hw_info
,
NumMmaWarpGroups
);
workspace_offset
=
round_nearest
(
workspace_offset
,
MinWorkspaceAlignment
);
if
(
status
!=
Status
::
kSuccess
)
{
return
status
;
}
status
=
CollectiveEpilogue
::
initialize_workspace
(
args
.
problem_shape
,
args
.
epilogue
,
workspace_ptr
+
workspace_offset
,
stream
,
cuda_adapter
);
workspace_offset
+=
CollectiveEpilogue
::
get_workspace_size
(
args
.
problem_shape
,
args
.
epilogue
);
workspace_offset
=
round_nearest
(
workspace_offset
,
MinWorkspaceAlignment
);
if
(
status
!=
Status
::
kSuccess
)
{
return
status
;
}
return
status
;
}
// Computes the kernel launch grid shape based on runtime parameters
static
dim3
get_grid_shape
(
Params
const
&
params
)
{
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
TileSchedulerArguments
args
{};
if
constexpr
(
!
std
::
is_const_v
<
decltype
(
args
.
max_swizzle_size
)
>
)
{
args
.
max_swizzle_size
=
1
<<
params
.
scheduler
.
log_swizzle_size_
;
}
args
.
raster_order
=
params
.
scheduler
.
raster_order_
==
TileScheduler
::
RasterOrder
::
AlongN
?
TileScheduler
::
RasterOrderOptions
::
AlongN
:
TileScheduler
::
RasterOrderOptions
::
AlongM
;
return
TileScheduler
::
get_grid_shape
(
params
.
problem_shape
,
TileShape
{},
ClusterShape
{},
params
.
hw_info
,
args
);
}
static
dim3
get_block_shape
()
{
return
dim3
(
MaxThreadsPerBlock
,
1
,
1
);
}
CUTLASS_DEVICE
void
operator
()(
Params
const
&
params
,
char
*
smem_buf
)
{
using
namespace
cute
;
using
X
=
Underscore
;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf
(
"ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.
\n
"
);
#else
// Preconditions
static_assert
(
cute
::
rank
(
StrideA
{})
==
3
,
"StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."
);
static_assert
(
cute
::
rank
(
StrideB
{})
==
3
,
"StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."
);
static_assert
(
cute
::
rank
(
StrideC
{})
==
3
,
"StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."
);
static_assert
(
cute
::
rank
(
StrideD
{})
==
3
,
"StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."
);
enum
class
WarpGroupRole
{
Producer
=
0
,
Consumer0
=
1
,
Consumer1
=
2
};
enum
class
ProducerWarpRole
{
Mainloop
=
0
,
Warp1
=
1
,
Epilogue
=
2
,
Warp3
=
3
};
// Kernel level shared memory storage
SharedStorage
&
shared_storage
=
*
reinterpret_cast
<
SharedStorage
*>
(
smem_buf
);
int
thread_idx
=
int
(
threadIdx
.
x
);
int
lane_idx
=
canonical_lane_idx
();
int
warp_idx
=
canonical_warp_idx_sync
();
int
warp_idx_in_warp_group
=
warp_idx
%
NumWarpsPerWarpGroup
;
int
warp_group_thread_idx
=
thread_idx
%
NumThreadsPerWarpGroup
;
auto
warp_group_role
=
WarpGroupRole
(
canonical_warp_group_idx
());
auto
producer_warp_role
=
ProducerWarpRole
(
warp_idx_in_warp_group
);
int
lane_predicate
=
cute
::
elect_one_sync
();
uint32_t
block_rank_in_cluster
=
cute
::
block_rank_in_cluster
();
// Issue Tma Descriptor Prefetch from a single thread
if
((
warp_idx
==
0
)
&&
lane_predicate
)
{
CollectiveMainloop
::
prefetch_tma_descriptors
(
params
.
mainloop
);
CollectiveEpilogue
::
prefetch_tma_descriptors
(
params
.
epilogue
);
}
// Mainloop Load pipeline
using
MainloopPipeline
=
typename
CollectiveMainloop
::
MainloopPipeline
;
typename
MainloopPipeline
::
Params
mainloop_pipeline_params
;
if
(
warp_group_role
==
WarpGroupRole
::
Producer
&&
producer_warp_role
==
ProducerWarpRole
::
Mainloop
)
{
mainloop_pipeline_params
.
role
=
MainloopPipeline
::
ThreadCategory
::
Producer
;
}
if
(
warp_group_role
==
WarpGroupRole
::
Consumer0
||
warp_group_role
==
WarpGroupRole
::
Consumer1
)
{
mainloop_pipeline_params
.
role
=
MainloopPipeline
::
ThreadCategory
::
Consumer
;
}
mainloop_pipeline_params
.
is_leader
=
warp_group_thread_idx
==
0
;
mainloop_pipeline_params
.
num_consumers
=
NumThreadsPerWarpGroup
;
mainloop_pipeline_params
.
transaction_bytes
=
CollectiveMainloop
::
TmaTransactionBytes
;
MainloopPipeline
mainloop_pipeline
(
shared_storage
.
pipelines
.
mainloop
,
mainloop_pipeline_params
,
ClusterShape
{});
// Epilogue Load pipeline
using
EpiLoadPipeline
=
typename
CollectiveEpilogue
::
LoadPipeline
;
typename
EpiLoadPipeline
::
Params
epi_load_pipeline_params
;
if
(
warp_group_role
==
WarpGroupRole
::
Producer
&&
producer_warp_role
==
ProducerWarpRole
::
Epilogue
)
{
epi_load_pipeline_params
.
role
=
EpiLoadPipeline
::
ThreadCategory
::
Producer
;
}
if
(
warp_group_role
==
WarpGroupRole
::
Consumer0
||
warp_group_role
==
WarpGroupRole
::
Consumer1
)
{
epi_load_pipeline_params
.
role
=
EpiLoadPipeline
::
ThreadCategory
::
Consumer
;
}
epi_load_pipeline_params
.
dst_blockid
=
cute
::
block_rank_in_cluster
();
epi_load_pipeline_params
.
producer_arv_count
=
NumThreadsPerWarp
;
epi_load_pipeline_params
.
consumer_arv_count
=
NumThreadsPerWarpGroup
;
epi_load_pipeline_params
.
transaction_bytes
=
CollectiveEpilogue
::
TmaTransactionBytes
;
EpiLoadPipeline
epi_load_pipeline
(
shared_storage
.
pipelines
.
epi_load
,
epi_load_pipeline_params
);
// Epilogue Store pipeline
using
EpiStorePipeline
=
typename
CollectiveEpilogue
::
StorePipeline
;
typename
EpiStorePipeline
::
Params
epi_store_pipeline_params
;
epi_store_pipeline_params
.
always_wait
=
true
;
EpiStorePipeline
epi_store_pipeline
(
epi_store_pipeline_params
);
typename
LoadWarpOrderBarrier
::
Params
params_load_order_barrier
;
params_load_order_barrier
.
group_id
=
producer_warp_role
==
ProducerWarpRole
::
Mainloop
?
0
:
1
;
params_load_order_barrier
.
group_size
=
NumThreadsPerWarp
;
LoadWarpOrderBarrier
load_order_barrier
(
shared_storage
.
pipelines
.
load_order
,
params_load_order_barrier
);
    typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
    // DMA Load WG will not participate in these Ordered Barrier syncs
    params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
    params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group
    MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename
CollectiveMainloop
::
PipelineState
mainloop_pipe_consumer_state
;
typename
CollectiveEpilogue
::
LoadPipelineState
epi_load_pipe_consumer_state
;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState
mainloop_pipe_producer_state
=
cutlass
::
make_producer_start_state
<
MainloopPipeline
>
();
PipelineState
epi_load_pipe_producer_state
=
cutlass
::
make_producer_start_state
<
EpiLoadPipeline
>
();
PipelineState
epi_store_pipe_producer_state
=
cutlass
::
make_producer_start_state
<
EpiStorePipeline
>
();
auto
cluster_wait_fn
=
[
&
]()
{
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer thread blocks in the Cluster
if
constexpr
(
size
(
ClusterShape
{})
>
1
)
{
cute
::
cluster_arrive_relaxed
();
return
[]()
{
cute
::
cluster_wait
();
};
}
else
{
__syncthreads
();
return
[]()
{};
// do nothing
}
}();
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto
problem_shape_MNKL
=
append
<
4
>
(
params
.
problem_shape
,
Int
<
1
>
{});
// Get the appropriate blocks for this thread block -- potential for thread block locality
TiledMma
tiled_mma
;
auto
blk_shape
=
TileShape
{};
// (BLK_M,BLK_N,BLK_K)
// In a warp specialized kernel, collectives expose data movement and compute operations separately
CollectiveMainloop
collective_mainloop
;
CollectiveEpilogue
collective_epilogue
(
params
.
epilogue
,
shared_storage
.
tensors
.
epilogue
);
// Prepare and partition the input tensors. Expects a tuple of tensors where:
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto
load_inputs
=
collective_mainloop
.
load_init
(
problem_shape_MNKL
,
params
.
mainloop
);
static_assert
(
cute
::
tuple_size_v
<
decltype
(
load_inputs
)
>
>=
3
,
"Output of load_init must have at least three elements (A, B, Aux)"
);
// Extract out partitioned A and B.
Tensor
gA_mkl
=
get
<
0
>
(
load_inputs
);
Tensor
gB_nkl
=
get
<
1
>
(
load_inputs
);
Tensor
gAux_xkl
=
get
<
2
>
(
load_inputs
);
// Get pipeline stage increments from tensor shapes
auto
k_tile_count
=
size
<
3
>
(
gA_mkl
);
auto
c_tile_count
=
CollectiveEpilogue
::
get_load_pipe_increment
(
blk_shape
);
auto
d_tile_count
=
CollectiveEpilogue
::
get_store_pipe_increment
(
blk_shape
);
    TileScheduler scheduler{params.scheduler};

    if (warp_group_role == WarpGroupRole::Consumer1)
    {
        // Advance 2nd Math WG to the next work tile for the startup
        scheduler.advance_to_next_work();
        // Advance 2nd Math WG pipeline states to the end of 1st Math WG
        mainloop_pipe_consumer_state.advance(k_tile_count);
        epi_load_pipe_consumer_state.advance(c_tile_count);
        epi_store_pipe_producer_state.advance(d_tile_count);
    }
    auto work_tile_info = scheduler.get_current_work();

    // Wait for all thread blocks in the Cluster
    cluster_wait_fn();
if
(
warp_group_role
==
WarpGroupRole
::
Producer
)
{
cutlass
::
arch
::
warpgroup_reg_dealloc
<
LoadRegisterRequirement
>
();
// Mainloop Producer Warp
if
(
producer_warp_role
==
ProducerWarpRole
::
Mainloop
)
{
bool
do_load_order_arrive
=
true
;
while
(
work_tile_info
.
is_valid
())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto
m_coord
=
idx2crd
(
work_tile_info
.
M_idx
,
shape
<
2
>
(
gA_mkl
));
auto
n_coord
=
idx2crd
(
work_tile_info
.
N_idx
,
shape
<
2
>
(
gB_nkl
));
auto
l_coord
=
idx2crd
(
work_tile_info
.
L_idx
,
shape
<
4
>
(
gB_nkl
));
auto
blk_coord
=
make_coord
(
m_coord
,
n_coord
,
_
,
l_coord
);
auto
k_tile_iter
=
cute
::
make_coord_iterator
(
shape
<
3
>
(
gA_mkl
));
collective_mainloop
.
load
(
params
.
mainloop
,
mainloop_pipeline
,
mainloop_pipe_producer_state
,
load_inputs
,
blk_coord
,
k_tile_iter
,
k_tile_count
,
lane_idx
,
block_rank_in_cluster
,
shared_storage
.
tensors
.
mainloop
);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state
.
advance
(
k_tile_count
);
// Signal for the epilogue load warp to begin
if
(
do_load_order_arrive
)
{
load_order_barrier
.
arrive
();
do_load_order_arrive
=
false
;
}
// Get next work tile
scheduler
.
advance_to_next_work
();
work_tile_info
=
scheduler
.
get_current_work
();
}
// Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop
.
load_tail
(
mainloop_pipeline
,
mainloop_pipe_producer_state
);
}
// Mainloop Producer Warp End
// Epilogue Producer Warp
else
if
(
producer_warp_role
==
ProducerWarpRole
::
Epilogue
&&
collective_epilogue
.
is_producer_load_needed
())
{
load_order_barrier
.
wait
();
while
(
work_tile_info
.
is_valid
())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto
m_coord
=
idx2crd
(
work_tile_info
.
M_idx
,
shape
<
2
>
(
gA_mkl
));
auto
n_coord
=
idx2crd
(
work_tile_info
.
N_idx
,
shape
<
2
>
(
gB_nkl
));
auto
l_coord
=
idx2crd
(
work_tile_info
.
L_idx
,
shape
<
4
>
(
gB_nkl
));
auto
blk_coord
=
make_coord
(
m_coord
,
n_coord
,
_
,
l_coord
);
epi_load_pipe_producer_state
=
collective_epilogue
.
load
(
epi_load_pipeline
,
epi_load_pipe_producer_state
,
problem_shape_MNKL
,
blk_shape
,
blk_coord
,
tiled_mma
,
lane_idx
,
shared_storage
.
tensors
.
epilogue
);
// Get next work tile
scheduler
.
advance_to_next_work
();
work_tile_info
=
scheduler
.
get_current_work
();
}
// Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue
.
load_tail
(
epi_load_pipeline
,
epi_load_pipe_producer_state
);
}
// Epilogue Producer Warp End
}
// Producer Warp Group End
else
if
(
warp_group_role
==
WarpGroupRole
::
Consumer0
||
warp_group_role
==
WarpGroupRole
::
Consumer1
)
{
cutlass
::
arch
::
warpgroup_reg_alloc
<
MmaRegisterRequirement
>
();
float
scale_d0
=
params
.
mainloop
.
scale_d0
;
float
scale_d1
=
params
.
mainloop
.
scale_d1
;
while
(
work_tile_info
.
is_valid
())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto
m_coord
=
idx2crd
(
work_tile_info
.
M_idx
,
shape
<
2
>
(
gA_mkl
));
auto
n_coord
=
idx2crd
(
work_tile_info
.
N_idx
,
shape
<
2
>
(
gB_nkl
));
auto
l_coord
=
idx2crd
(
work_tile_info
.
L_idx
,
shape
<
4
>
(
gB_nkl
));
auto
blk_coord
=
make_coord
(
m_coord
,
n_coord
,
_
,
l_coord
);
// Allocate the accumulators for the (M,N) blk_shape
Tensor
accumulators0
=
partition_fragment_C
(
tiled_mma
,
take
<
0
,
2
>
(
blk_shape
));
// (MMA,MMA_M,MMA_N)
Tensor
accumulators1
=
partition_fragment_C
(
tiled_mma
,
take
<
0
,
2
>
(
blk_shape
));
// (MMA,MMA_M,MMA_N)
// Order two Math WG's MMA one after the other, helps hide Epilogue
math_wg_order_barrier
.
wait
();
collective_mainloop
.
mma
(
mainloop_pipeline
,
mainloop_pipe_consumer_state
,
accumulators0
,
accumulators1
,
k_tile_count
,
warp_group_thread_idx
,
shared_storage
.
tensors
.
mainloop
,
params
.
mainloop
);
// Cue for next Math WG's MMA to start
math_wg_order_barrier
.
arrive
();
// Make sure the math instructions are done and free buffers before entering the epilogue
collective_mainloop
.
mma_tail
(
mainloop_pipeline
,
mainloop_pipe_consumer_state
,
k_tile_count
);
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state
.
advance
(
k_tile_count
*
NumMmaWarpGroups
);
Activation
elt_op
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
accumulators0
);
i
++
)
{
accumulators0
[
i
]
=
(
accumulators0
[
i
]
*
scale_d0
)
*
elt_op
(
scale_d1
*
accumulators1
[
i
]);
}
// Order two Math WG's Epilogue one after the other
math_wg_order_barrier
.
wait
();
// Epilogue and write to gD
auto
[
epi_load_pipe_consumer_state_next
,
epi_store_pipe_producer_state_next
]
=
collective_epilogue
.
store
(
epi_load_pipeline
,
epi_load_pipe_consumer_state
,
epi_store_pipeline
,
epi_store_pipe_producer_state
,
problem_shape_MNKL
,
blk_shape
,
blk_coord
,
accumulators0
,
tiled_mma
,
warp_group_thread_idx
,
shared_storage
.
tensors
.
epilogue
);
// TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels
// we need to wait for all TMA stores to complete before issuing consumer order barrier arrives
// to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer.
auto
[
epi_load_pipe_consumer_state_next_
,
epi_store_pipe_producer_state_next_
]
=
collective_epilogue
.
store_tail
(
epi_load_pipeline
,
epi_load_pipe_consumer_state_next
,
epi_store_pipeline
,
epi_store_pipe_producer_state_next
);
// Update starting load/store pipeline states for the next tile
// state has already been incremented by 1 tile in collective calls, advance once again for ping pong
epi_load_pipe_consumer_state
=
epi_load_pipe_consumer_state_next_
;
epi_store_pipe_producer_state
=
epi_store_pipe_producer_state_next_
;
epi_load_pipe_consumer_state
.
advance
(
c_tile_count
);
epi_store_pipe_producer_state
.
advance
(
d_tile_count
);
// Cue for next Math WG's Epilogue to start
math_wg_order_barrier
.
arrive
();
// Get next work tile
scheduler
.
advance_to_next_work
(
NumMmaWarpGroups
);
work_tile_info
=
scheduler
.
get_current_work
();
}
// Scheduler work fetch loop
}
// Consumer Warp Groups End
#endif
}
};
///////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass::gemm::kernel
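The Consumer1 startup block in this ping-pong variant, together with the scheduler.advance_to_next_work(NumMmaWarpGroups) call at the end of the consumer loop, is what interleaves the two math warp groups: Consumer1 starts one work tile (and one tile's worth of pipeline stages) ahead, and both groups then stride the work list by NumMmaWarpGroups, so one group runs its epilogue while the other computes. A tiny host-side sketch of the resulting tile assignment (tile count is assumed, purely to show the interleaving):

#include <cstdio>

int main()
{
    constexpr int kNumMmaWarpGroups = 2; // matches NumMmaWarpGroups above
    constexpr int kNumTiles = 8;         // assumed total number of work tiles

    // Consumer0 starts at tile 0; Consumer1 advances once at startup (tile 1);
    // both then advance by kNumMmaWarpGroups per iteration.
    for (int wg = 0; wg < kNumMmaWarpGroups; ++wg)
    {
        std::printf("Consumer%d tiles:", wg);
        for (int tile = wg; tile < kNumTiles; tile += kNumMmaWarpGroups)
        {
            std::printf(" %d", tile);
        }
        std::printf("\n");
    }
    return 0;
}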
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h
deleted
100644 → 0
View file @
9829e77e
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace
cutlass
{
namespace
gemm
{
namespace
kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Mma_
,
///! Threadblock-scoped matrix multiply-accumulate
typename
Epilogue_
,
///! Epilogue
typename
ThreadblockSwizzle_
,
///! Threadblock swizzling function
GroupScheduleMode
GroupScheduleMode_
,
///! Type of scheduling to perform
bool
Transposed
=
false
>
struct
SplitkGemmGrouped
{
public:
using
Mma
=
Mma_
;
using
Epilogue
=
Epilogue_
;
using
EpilogueOutputOp
=
typename
Epilogue
::
OutputOp
;
using
ThreadblockSwizzle
=
ThreadblockSwizzle_
;
static
GroupScheduleMode
const
kGroupScheduleMode
=
GroupScheduleMode_
;
static
bool
const
kTransposed
=
Transposed
;
// Optional transpose
using
MapArguments
=
kernel
::
detail
::
MapArguments
<
typename
Mma
::
IteratorA
::
Element
,
typename
Mma
::
IteratorA
::
Layout
,
Mma
::
kTransformA
,
Mma
::
IteratorA
::
AccessType
::
kElements
,
typename
Mma
::
IteratorB
::
Element
,
typename
Mma
::
IteratorB
::
Layout
,
Mma
::
kTransformB
,
Mma
::
IteratorB
::
AccessType
::
kElements
,
typename
Mma
::
LayoutC
,
kTransposed
>
;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
using
ElementA
=
typename
MapArguments
::
ElementA
;
using
LayoutA
=
typename
MapArguments
::
LayoutA
;
using
ElementB
=
typename
MapArguments
::
ElementB
;
using
LayoutB
=
typename
MapArguments
::
LayoutB
;
using
ElementC
=
typename
Epilogue
::
OutputTileIterator
::
Element
;
using
LayoutC
=
typename
MapArguments
::
LayoutC
;
using
ElementFinalOutput
=
typename
MapArguments
::
ElementA
;
static
ComplexTransform
const
kTransformA
=
MapArguments
::
kTransformA
;
static
ComplexTransform
const
kTransformB
=
MapArguments
::
kTransformB
;
// Type definitions about the mainloop.
using
Operator
=
typename
Mma
::
Operator
;
using
OperatorClass
=
typename
Mma
::
Operator
::
OperatorClass
;
using
ThreadblockShape
=
typename
Mma
::
Shape
;
using
WarpShape
=
typename
Mma
::
Operator
::
Shape
;
using
InstructionShape
=
typename
Mma
::
Policy
::
Operator
::
InstructionShape
;
using
ArchTag
=
typename
Mma
::
ArchTag
;
static
int
const
kStages
=
Mma
::
kStages
;
static
int
const
kAlignmentA
=
MapArguments
::
kAlignmentA
;
static
int
const
kAlignmentB
=
MapArguments
::
kAlignmentB
;
static
int
const
kAlignmentC
=
Epilogue
::
OutputTileIterator
::
kElementsPerAccess
;
/// Warp count (concept: GemmShape)
using
WarpCount
=
typename
Mma
::
WarpCount
;
static
int
const
kThreadCount
=
32
*
WarpCount
::
kCount
;
using
ProblemVisitor
=
GemmGroupedProblemVisitor
<
ThreadblockShape
,
kGroupScheduleMode
,
kThreadCount
,
kThreadCount
,
kTransposed
>
;
//
// Structures
//
/// Argument structure
struct
Arguments
{
//
// Data members
//
GemmCoord
*
problem_sizes
;
int
problem_count
;
int
threadblock_count
;
typename
EpilogueOutputOp
::
Params
output_op
;
ElementA
**
ptr_A
;
ElementB
**
ptr_B
;
ElementFinalOutput
**
ptr_C
;
ElementFinalOutput
**
ptr_D
;
typename
LayoutA
::
Stride
::
LongIndex
*
lda
;
typename
LayoutB
::
Stride
::
LongIndex
*
ldb
;
typename
LayoutC
::
Stride
::
LongIndex
*
ldc
;
typename
LayoutC
::
Stride
::
LongIndex
*
ldd
;
// Only used by device-level operator
GemmCoord
*
host_problem_sizes
;
// splitK
int
split_k_slices
;
int64_t
*
splitk_buffer_offsets
;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments
()
:
problem_count
(
0
)
,
threadblock_count
(
0
)
,
ptr_A
(
nullptr
)
,
ptr_B
(
nullptr
)
,
ptr_C
(
nullptr
)
,
ptr_D
(
nullptr
)
,
lda
(
nullptr
)
,
ldb
(
nullptr
)
,
ldc
(
nullptr
)
,
ldd
(
nullptr
)
,
host_problem_sizes
(
nullptr
)
,
split_k_slices
(
1
)
,
splitk_buffer_offsets
(
nullptr
)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments
(
GemmCoord
*
problem_sizes
,
int
problem_count
,
int
threadblock_count
,
typename
EpilogueOutputOp
::
Params
output_op
,
ElementA
**
ptr_A
,
ElementB
**
ptr_B
,
ElementFinalOutput
**
ptr_C
,
ElementFinalOutput
**
ptr_D
,
typename
LayoutA
::
Stride
::
LongIndex
*
lda
,
typename
LayoutB
::
Stride
::
LongIndex
*
ldb
,
typename
LayoutC
::
Stride
::
LongIndex
*
ldc
,
typename
LayoutC
::
Stride
::
LongIndex
*
ldd
,
GemmCoord
*
host_problem_sizes
,
int
split_k_slices
,
int64_t
*
splitk_buffer_offsets
)
:
problem_sizes
(
problem_sizes
)
,
problem_count
(
problem_count
)
,
threadblock_count
(
threadblock_count
)
,
output_op
(
output_op
)
,
ptr_A
(
ptr_A
)
,
ptr_B
(
ptr_B
)
,
ptr_C
(
ptr_C
)
,
ptr_D
(
ptr_D
)
,
lda
(
lda
)
,
ldb
(
ldb
)
,
ldc
(
ldc
)
,
ldd
(
ldd
)
,
host_problem_sizes
(
host_problem_sizes
)
,
split_k_slices
(
split_k_slices
)
,
splitk_buffer_offsets
(
splitk_buffer_offsets
)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct
Params
{
typename
ProblemVisitor
::
Params
problem_visitor
;
int
threadblock_count
;
typename
EpilogueOutputOp
::
Params
output_op
;
ElementA
**
ptr_A
;
ElementB
**
ptr_B
;
ElementFinalOutput
**
ptr_C
;
ElementFinalOutput
**
ptr_D
;
ElementC
*
ptr_C_split
;
ElementC
*
ptr_D_split
;
typename
LayoutA
::
Stride
::
LongIndex
*
lda
;
typename
LayoutB
::
Stride
::
LongIndex
*
ldb
;
typename
LayoutC
::
Stride
::
LongIndex
*
ldc
;
typename
LayoutC
::
Stride
::
LongIndex
*
ldd
;
//
// Methods
//
// splitk
GemmCoord
grid_tiled_shape
;
int
swizzle_log_tile
;
int
gemm_k_size
;
GemmCoord
*
host_problem_sizes
;
int
split_k_slices
;
int64_t
*
splitk_buffer_offsets
;
CUTLASS_HOST_DEVICE
Params
()
:
ptr_A
(
nullptr
)
,
ptr_B
(
nullptr
)
,
ptr_C
(
nullptr
)
,
ptr_D
(
nullptr
)
,
ptr_C_split
(
nullptr
)
,
ptr_D_split
(
nullptr
)
,
lda
(
nullptr
)
,
ldb
(
nullptr
)
,
ldc
(
nullptr
)
,
ldd
(
nullptr
)
,
swizzle_log_tile
(
0
)
,
gemm_k_size
(
0
)
,
host_problem_sizes
(
nullptr
)
,
split_k_slices
(
1
)
,
splitk_buffer_offsets
(
nullptr
)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
    : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count)
    , host_problem_sizes(args.host_problem_sizes)
    , threadblock_count(args.threadblock_count)
    , output_op(args.output_op)
    , ptr_A(args.ptr_A)
    , ptr_B(args.ptr_B)
    , ptr_C(args.ptr_C)
    , ptr_D(args.ptr_D)
    , ptr_C_split((ElementC*) workspace)
    , ptr_D_split((ElementC*) workspace)
    , lda(args.lda)
    , ldb(args.ldb)
    , ldc(args.ldc)
    , ldd(args.ldd)
    , split_k_slices(args.split_k_slices)
    , splitk_buffer_offsets(args.splitk_buffer_offsets)
{
    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;
    grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0],
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices);
    swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);

    // only support same k
    int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK;
    int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();
    gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
}
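A quick worked example of the split-K sizing above, with illustrative values and assuming K divides evenly across slices: for host_problem_sizes[0].k() == 4096, Mma::Shape::kK == 64 and split_k_slices == 4 (so grid_tiled_shape.k() == 4), full_gemm_k_iterations is 64, gemm_k_iterations is 16, and gemm_k_size is 1024, i.e. each of the four K-slices covers 1024 elements of K. As a tiny check:

#include <cassert>

int main()
{
    // Hypothetical values mirroring the arithmetic in the constructor above.
    int problem_k = 4096;       // assumed GEMM K extent (same for all grouped problems)
    int mma_shape_k = 64;       // assumed Mma::Shape::kK
    int split_k_slices = 4;     // equals grid_tiled_shape.k() here

    int full_gemm_k_iterations = problem_k / mma_shape_k;             // 64
    int gemm_k_iterations = full_gemm_k_iterations / split_k_slices;  // 16
    int gemm_k_size = gemm_k_iterations * mma_shape_k;                // 1024 per split
    assert(gemm_k_size * split_k_slices == problem_k);
    return 0;
}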
CUTLASS_HOST_DEVICE
void
update
(
Arguments
const
&
args
,
void
*
workspace
=
nullptr
,
int
tile_count
=
0
)
{
problem_visitor
=
typename
ProblemVisitor
::
Params
(
args
.
problem_sizes
,
args
.
problem_count
,
workspace
,
tile_count
);
threadblock_count
=
args
.
threadblock_count
;
output_op
=
args
.
output_op
;
ptr_A
=
args
.
ptr_A
;
ptr_B
=
args
.
ptr_B
;
ptr_C
=
args
.
ptr_C
;
ptr_D
=
args
.
ptr_D
;
ptr_C_split
=
workspace
;
ptr_D_split
=
workspace
;
lda
=
args
.
lda
;
ldb
=
args
.
ldb
;
ldc
=
args
.
ldc
;
ldd
=
args
.
ldd
;
}
};
/// Shared memory storage structure
struct
SharedStorage
{
union
{
typename
Mma
::
SharedStorage
main_loop
;
typename
Epilogue
::
SharedStorage
epilogue
;
}
kernel
;
// ProblemVisitor shared storage can't be overlapped with others
typename
ProblemVisitor
::
SharedStorage
problem_visitor
;
};
public:
//
// Methods
//
CUTLASS_DEVICE
SplitkGemmGrouped
()
{}
/// Determines whether kernel satisfies alignment
static
Status
can_implement
(
cutlass
::
gemm
::
GemmCoord
const
&
problem_size
)
{
return
Status
::
kSuccess
;
}
static
Status
can_implement
(
Arguments
const
&
args
)
{
return
Status
::
kSuccess
;
}
/// Executes one GEMM
CUTLASS_DEVICE
void
operator
()(
Params
const
&
params
,
SharedStorage
&
shared_storage
)
{
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using
ElementA
=
typename
Mma
::
IteratorA
::
Element
;
using
LayoutA
=
typename
Mma
::
IteratorA
::
Layout
;
using
ElementB
=
typename
Mma
::
IteratorB
::
Element
;
using
LayoutB
=
typename
Mma
::
IteratorB
::
Layout
;
using
ElementC
=
typename
Epilogue
::
OutputTileIterator
::
Element
;
using
LayoutC
=
typename
Epilogue
::
OutputTileIterator
::
Layout
;
//
// Problem visitor.
//
ProblemVisitor
problem_visitor
(
params
.
problem_visitor
,
shared_storage
.
problem_visitor
,
blockIdx
.
x
);
// Outer 'persistent' loop to iterate over tiles
while
(
problem_visitor
.
next_tile
())
{
GemmCoord
problem_size
=
problem_visitor
.
problem_size
();
int32_t
problem_idx
=
problem_visitor
.
problem_index
();
int32_t
threadblock_idx
=
int32_t
(
problem_visitor
.
threadblock_idx
());
GemmCoord
grid_shape
=
problem_visitor
.
grid_shape
(
problem_size
);
// Load element pointers. Exchange pointers and strides if working on the transpose
ElementA
*
ptr_A
=
reinterpret_cast
<
ElementA
*>
((
kTransposed
?
params
.
ptr_B
[
problem_idx
]
:
params
.
ptr_A
[
problem_idx
]));
typename
LayoutA
::
LongIndex
ldm_A
=
(
kTransposed
?
params
.
ldb
[
problem_idx
]
:
params
.
lda
[
problem_idx
]);
ElementB
*
ptr_B
=
reinterpret_cast
<
ElementB
*>
((
kTransposed
?
params
.
ptr_A
[
problem_idx
]
:
params
.
ptr_B
[
problem_idx
]));
typename
LayoutB
::
LongIndex
ldm_B
=
(
kTransposed
?
params
.
lda
[
problem_idx
]
:
params
.
ldb
[
problem_idx
]);
// Compute threadblock location
ThreadblockSwizzle
threadblock_swizzle
;
GemmCoord
threadblock_tile_offset
=
threadblock_swizzle
.
get_tile_offset
(
params
.
swizzle_log_tile
);
cutlass
::
gemm
::
GemmCoord
threadblock_offset
(
int
(
threadblock_idx
/
grid_shape
.
n
())
*
Mma
::
Shape
::
kM
,
int
(
threadblock_idx
%
grid_shape
.
n
())
*
Mma
::
Shape
::
kN
,
0
);
// Compute initial location in logical coordinates
cutlass
::
MatrixCoord
tb_offset_A
{
threadblock_offset
.
m
(),
threadblock_tile_offset
.
k
()
*
params
.
gemm_k_size
,
};
cutlass
::
MatrixCoord
tb_offset_B
{
threadblock_tile_offset
.
k
()
*
params
.
gemm_k_size
,
threadblock_offset
.
n
()};
// Problem size is a function of threadblock index in the K dimension
int
problem_size_k
;
if
(
threadblock_tile_offset
.
k
()
+
1
==
params
.
grid_tiled_shape
.
k
())
{
problem_size_k
=
problem_size
.
k
();
}
else
{
problem_size_k
=
(
threadblock_tile_offset
.
k
()
+
1
)
*
params
.
gemm_k_size
;
}
// Compute threadblock-scoped matrix multiply-add
int
gemm_k_iterations
=
(
problem_size_k
-
tb_offset_A
.
column
()
+
Mma
::
Shape
::
kK
-
1
)
/
Mma
::
Shape
::
kK
;
// Compute position within threadblock
int
thread_idx
=
threadIdx
.
x
;
// Construct iterators to A and B operands
typename
Mma
::
IteratorA
iterator_A
(
LayoutA
(
ldm_A
),
ptr_A
,
{
problem_size
.
m
(),
problem_size_k
},
thread_idx
,
tb_offset_A
);
typename
Mma
::
IteratorB
iterator_B
(
LayoutB
(
ldm_B
),
ptr_B
,
{
problem_size_k
,
problem_size
.
n
()},
thread_idx
,
tb_offset_B
);
typename
Mma
::
FragmentC
accumulators
;
accumulators
.
clear
();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int
warp_idx
=
canonical_warp_idx_sync
();
int
lane_idx
=
threadIdx
.
x
%
32
;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
Mma
mma
(
shared_storage
.
kernel
.
main_loop
,
thread_idx
,
warp_idx
,
lane_idx
);
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads
();
// Compute threadblock-scoped matrix multiply-add
mma
(
gemm_k_iterations
,
accumulators
,
iterator_A
,
iterator_B
,
accumulators
);
//
// Epilogue
//
EpilogueOutputOp
output_op
(
params
.
output_op
);
ElementC
*
ptr_C
=
params
.
ptr_C_split
;
ElementC
*
ptr_D
=
params
.
ptr_D_split
;
LayoutC
layout_C
(
params
.
ldc
[
problem_idx
]);
LayoutC
layout_D
(
params
.
ldd
[
problem_idx
]);
typename
Epilogue
::
OutputTileIterator
::
Params
params_C
(
layout_C
);
typename
Epilogue
::
OutputTileIterator
::
Params
params_D
(
layout_D
);
// assume identity swizzle
MatrixCoord
threadblock_offset_C
(
threadblock_offset
.
m
(),
threadblock_offset
.
n
());
// Tile iterator loading from source tensor.
typename
Epilogue
::
OutputTileIterator
iterator_C
(
params_C
,
ptr_C
,
problem_size
.
mn
(),
thread_idx
,
threadblock_offset_C
);
iterator_C
.
add_pointer_offset
(
problem_size
.
m
()
*
problem_size
.
n
()
*
threadblock_tile_offset
.
k
()
+
gridDim
.
z
*
params
.
splitk_buffer_offsets
[
problem_idx
]);
// Tile iterator writing to destination tensor.
typename
Epilogue
::
OutputTileIterator
iterator_D
(
params_D
,
ptr_D
,
problem_size
.
mn
(),
thread_idx
,
threadblock_offset_C
);
iterator_D
.
add_pointer_offset
(
problem_size
.
m
()
*
problem_size
.
n
()
*
threadblock_tile_offset
.
k
()
+
gridDim
.
z
*
params
.
splitk_buffer_offsets
[
problem_idx
]);
Epilogue
epilogue
(
shared_storage
.
kernel
.
epilogue
,
thread_idx
,
warp_idx
,
lane_idx
);
// Execute the epilogue operator to update the destination tensor.
epilogue
(
output_op
,
iterator_D
,
accumulators
,
iterator_C
);
// Next tile
problem_visitor
.
advance
(
gridDim
.
x
);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace kernel
}
// namespace gemm
}
// namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
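The split-K bookkeeping in the deleted kernel above is compact but easy to misread: Params rounds K down to whole Mma::Shape::kK tiles and splits them evenly across the K slabs, and the kernel then lets the last slab run out to the true K extent and rounds its mainloop iteration count up. The following standalone host-side sketch mirrors that arithmetic; the concrete sizes and the helper names slab_k_begin/slab_k_end are illustrative assumptions, not part of the deleted file.

// Minimal host-side sketch of the split-K partitioning used by the kernel above.
// kK stands in for Mma::Shape::kK and k_slabs for grid_tiled_shape.k(); all values are illustrative.
#include <cstdio>

int main()
{
    int const problem_k = 1000;  // GEMM K extent
    int const kK = 64;           // threadblock tile size along K
    int const k_slabs = 4;       // number of split-K partitions

    // Mirrors Params: round K down to whole tiles, split evenly, re-scale to elements.
    int const full_gemm_k_iterations = problem_k / kK;
    int const gemm_k_iterations_per_slab = full_gemm_k_iterations / k_slabs;
    int const gemm_k_size = gemm_k_iterations_per_slab * kK;

    for (int slab = 0; slab < k_slabs; ++slab)
    {
        // Mirrors the kernel: the last slab extends to the true K so no work is dropped.
        int const slab_k_begin = slab * gemm_k_size;
        int const slab_k_end = (slab + 1 == k_slabs) ? problem_k : (slab + 1) * gemm_k_size;
        int const iterations = (slab_k_end - slab_k_begin + kK - 1) / kK;
        std::printf("slab %d: K [%d, %d), %d mainloop iterations\n", slab, slab_k_begin, slab_k_end, iterations);
    }
    return 0;
}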
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

// We need to distinguish here, since we want volta support. It is too much effort
// to write shared memory iterators that are probably needed for volta to function
// properly. As a result, we allow converters both after the LDG (for volta) and after
// the LDS for Turing+.
template <
    typename IteratorB,     /// Iterator for B matrix in global memory
    typename MmaOperator,   /// Warp level Mma
    typename MathOperator>  /// Math operation perform by warp level operator
struct SetConverters
{
};

// Dequantize after LDG, so set transforms accordingly
template <
    typename IteratorB,    /// Iterator for B matrix in global memory
    typename MmaOperator>  /// Mma Policy
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd>
{
    using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
        typename IteratorB::Element, IteratorB::Fragment::kElements>;

    using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
        typename MmaOperator::ArchMmaOperator::ElementB, MmaOperator::FragmentB::kElements>;
};

// Dequantize after LDS, so set transforms accordingly
template <
    typename IteratorB,    /// Iterator for B matrix in global memory
    typename MmaOperator>  /// Mma Policy
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA>
{
    using TransformAfterLDG
        = NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element, IteratorB::Fragment::kElements>;

    using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
        typename TransformAfterLDG::result_type::Element, MmaOperator::FragmentB::kElements>;
};

////////////////////////////////////////////////////////////////////////////////

template <
    typename ElementA_,            /// Element type for A matrix operand
    typename LayoutA_,             /// Layout type for A matrix operand
    int kAlignmentA,               /// Access granularity of A matrix in units of elements
    typename ElementB_,            /// Element type for B matrix operand
    typename LayoutB_,             /// Layout type for B matrix operand
    int kAlignmentB,               /// Access granularity of B matrix in units of elements
    typename ElementScale_,        /// Element type for the input scale
    typename LayoutScale_,         /// Layout for the scale operand
    int kAlignmentScale,           /// Access granularity of Scales in unit of elements
    typename ElementAccumulator_,  /// Element type for internal accumulation
    typename LayoutC_,             /// Layout type for C and D matrix operands
    typename OperatorClass_,       /// Operator class tag
    typename ArchTag_,             /// Tag indicating architecture to tune for
    typename ThreadblockShape_,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape_,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape_,    /// Instruction-level tile size (concept: GemmShape)
    int Stages,                    /// Number of stages used in the pipelined mainloop
    typename Operator_,            /// Operation performed by GEMM
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,  /// Use zfill or predicate for out-of-bound cp.async
    typename Enable = void>        ///
struct DqMma;

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
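The SetConverters trait above is a small tag dispatch: the warp-level math-operator tag decides whether the B operand is converted right after the global load (LDG, for Volta) or after the shared-memory load (LDS, for Turing and newer). Below is a toy, self-contained sketch of the same selection pattern; the converter types are stand-ins, not the real CUTLASS converters.

// Toy illustration of the tag-dispatch pattern used by SetConverters above.
#include <type_traits>

struct OpMultiplyAdd {};                           // tag: dequantize right after the global load (LDG)
struct OpMultiplyAddDequantizeInterleavedBToA {};  // tag: dequantize after the shared-memory load (LDS)

struct ConvertNow {};   // stand-in for FastInterleavedAndBiasedNumericArrayConverter
struct PassThrough {};  // stand-in for NumericArrayConverter<T, T, N>

template <typename MathOperator>
struct SelectConverters;  // primary template left undefined, like SetConverters

template <>
struct SelectConverters<OpMultiplyAdd>
{
    using TransformAfterLDG = ConvertNow;
    using TransformAfterLDS = PassThrough;
};

template <>
struct SelectConverters<OpMultiplyAddDequantizeInterleavedBToA>
{
    using TransformAfterLDG = PassThrough;
    using TransformAfterLDS = ConvertNow;
};

static_assert(std::is_same<SelectConverters<OpMultiplyAdd>::TransformAfterLDG, ConvertNow>::value,
    "LDG path converts early");
static_assert(std::is_same<SelectConverters<OpMultiplyAddDequantizeInterleavedBToA>::TransformAfterLDS, ConvertNow>::value,
    "LDS path converts late");

int main() { return 0; }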
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
    typename Enable = void>
struct DefaultScaleIteratorsMultistage;

// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment, std::enable_if_t<isFinegrained(QuantOp)>>
{
    using IteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
        Element, Layout, 0, Alignment>;

    using SmemIteratorScale = IteratorScale;
};

// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment, std::enable_if_t<!isFinegrained(QuantOp)>>
{
    // ThreadMap for scale iterator
    static_assert((MmaShape::kN % Alignment) == 0, "");

private:
    using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
        MmaShape::kN / Alignment, Alignment>;

public:
    // Define iterators over tiles from the scale operand
    using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
        Element, Layout, 0, IteratorScaleThreadMap, Alignment>;

    using SmemIteratorScale = IteratorScale;
};

////////////////////////////////////////////////////////////////////////////////

template <
    typename ElementA,            /// Type for element A
    typename LayoutA,             /// Layout type for A matrix operand
    int kAlignmentA,              /// Access granularity of A matrix in units of elements
    typename ElementB,            /// Type for element B
    typename LayoutB,             /// Layout type for B matrix operand
    int kAlignmentB,              /// Access granularity of B matrix in units of elements
    typename ElementScale,        /// Element type for the input scale
    typename LayoutScale,         /// Layout for the scale operand
    int kAlignmentScale,          /// Access granularity of Scales in unit of elements
    typename ElementAccumulator,  /// Element type for internal accumulation
    typename OperatorClass,       /// Operator class tag
    typename ArchTag,             /// Tag indicating architecture to tune for
    typename ThreadblockShape,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,    /// Instruction-level tile size (concept: GemmShape)
    int kStages,                  /// Stages in GEMM
    typename Operator_,           /// Operator performed by GEMM
    SharedMemoryClearOption SharedMemoryClear>  /// Use zfill or predicate for out-of-bound cp.async
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages,
    Operator_, SharedMemoryClear,
    typename platform::enable_if<(
        ArchTag::kMinComputeCapability >= 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
            || platform::is_same<ElementA, float_e4m3_t>::value,
        "Element A must be fp16, fp8 or bf16");

    using OperatorInfo = arch::DetagOperator<Operator_>;
    using Operator = typename OperatorInfo::Operator;
    static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
        "Mma multistage must dequantize after ldsm");

    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
        "Element B must be uint8 or uint4");

    static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
        ? cutlass::arch::CacheOperation::Global
        : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
        ? cutlass::arch::CacheOperation::Global
        : cutlass::arch::CacheOperation::Always;

    // Define the MmaCore components
    // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
        ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3),
        Operator, false, CacheOpA, CacheOpB>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>;

    using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
        OperatorInfo::QuantOp, kAlignmentScale>;

    // Define iterators over tiles from the scale operand
    using IteratorScale = typename ScaleIterators::IteratorScale;

    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

    using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
        MmaCore::MmaPolicy::Operator::FragmentB::kElements>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
        typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
};

// Specialization to handle column major interleave B
template <
    typename ElementA,            /// Type for element A
    typename LayoutA,             /// Layout type for A matrix operand
    int kAlignmentA,              /// Access granularity of A matrix in units of elements
    typename ElementB,            /// Type for element B
    typename LayoutB,             /// Layout type for B matrix operand
    int kAlignmentB,              /// Access granularity of B matrix in units of elements
    typename ElementScale,        /// Element type for the input scale
    typename LayoutScale,         /// Layout for the scale operand
    int kAlignmentScale,          /// Access granularity of Scales in unit of elements
    typename ElementAccumulator,  /// Element type for internal accumulation
    typename OperatorClass,       /// Operator class tag
    typename ArchTag,             /// Tag indicating architecture to tune for
    typename ThreadblockShape,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,    /// Instruction-level tile size (concept: GemmShape)
    int kStages,                  /// Stages in GEMM
    typename Operator_,           /// Operator performed by GEMM
    SharedMemoryClearOption SharedMemoryClear>  /// Use zfill or predicate for out-of-bound cp.async
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages,
    Operator_, SharedMemoryClear,
    typename platform::enable_if<(
        ArchTag::kMinComputeCapability >= 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
            || platform::is_same<ElementA, float_e4m3_t>::value,
        "Element A must be fp16, fp8 or bf16");

    using OperatorInfo = arch::DetagOperator<Operator_>;
    using Operator = typename OperatorInfo::Operator;
    static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
        "Mma multistage must dequantize after ldsm");

    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
        "Element B must be uint8 or uint4");

    static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
        ? cutlass::arch::CacheOperation::Global
        : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
        ? cutlass::arch::CacheOperation::Global
        : cutlass::arch::CacheOperation::Always;

    // Define the MmaCore components
    // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
        ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
        std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>;

private:
    static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
    static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
    static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
    static_assert(RowsPerTile == MmaCore::Shape::kK, "");

    using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
    using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
    static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");

    using GmemIteratorShape
        = MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
    using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
        layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
        layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
            OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
        MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;

public:
    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<GmemIteratorShape, ElementB,
        layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;

    using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
        OperatorInfo::QuantOp, kAlignmentScale>;

    // Define iterators over tiles from the scale operand
    using IteratorScale = typename ScaleIterators::IteratorScale;

    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

    using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
        MmaCore::MmaPolicy::Operator::FragmentB::kElements>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
        typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
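For the column-major-interleave specialization above, the B tile is re-shaped in global memory from kK x kN to (kK * ColumnsInterleaved) x (kN / ColumnsInterleaved), so the iterator walks exactly the same elements in a permuted order. A quick constexpr sanity check with illustrative values (not taken from the deleted file):

// Constexpr check of the interleaved-B re-shaping used by the specialization above; values are illustrative.
constexpr int kK = 64;
constexpr int kN = 128;
constexpr int ColumnsInterleaved = 4;

constexpr int GmemRows = kK * ColumnsInterleaved;     // 256
constexpr int GmemColumns = kN / ColumnsInterleaved;  // 32

static_assert(kN % ColumnsInterleaved == 0, "mirrors the static_assert in the specialization");
static_assert(GmemRows * GmemColumns == kK * kN, "interleaving permutes the tile, it does not change its size");

int main() { return 0; }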
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
    typename Enable = void>
struct DefaultScaleIteratorsPipelined;

// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment, std::enable_if_t<isFinegrained(QuantOp)>>
{
private:
    using SmemScaleType = half_t;

public:
    using IteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
        Element, Layout, 0, Alignment>;

    using SmemIteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator<
        cutlass::MatrixShape<1, MmaShape::kN>, SmemScaleType, Layout, 0, Alignment>;
};

// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment, std::enable_if_t<!isFinegrained(QuantOp)>>
{
    static_assert((MmaShape::kN % Alignment) == 0, "");

private:
    // ThreadMap for scale iterator
    using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
        MmaShape::kN / Alignment, Alignment>;

    using SmemScaleType = half_t;

public:
    // Define iterators over tiles from the scale operand
    using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
        Element, Layout, 0, IteratorScaleThreadMap, Alignment>;

    using SmemIteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
        SmemScaleType, Layout, 0, IteratorScaleThreadMap, Alignment>;
};

////////////////////////////////////////////////////////////////////////////////

template <
    typename ElementA,            /// Type for element A
    typename LayoutA,             /// Layout type for A matrix operand
    int kAlignmentA,              /// Access granularity of A matrix in units of elements
    typename ElementB,            /// Type for element B
    typename LayoutB,             /// Layout type for B matrix operand
    int kAlignmentB,              /// Access granularity of B matrix in units of elements
    typename ElementScale,        /// Element type for the input scale
    typename LayoutScale,         /// Layout for the scale operand
    int kAlignmentScale,          /// Access granularity of Scales in unit of elements
    typename ElementAccumulator,  /// Element type for internal accumulation
    typename OperatorClass,       /// Operator class tag
    typename ArchTag,             /// Tag indicating architecture to tune for
    typename ThreadblockShape,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,    /// Instruction-level tile size (concept: GemmShape)
    typename Operator_>           /// Operation performed by GEMM
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
    Operator_, SharedMemoryClearOption::kNone,
    typename platform::enable_if<(
        ArchTag::kMinComputeCapability < 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
        "Element A must be fp16 or bf16");

    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
        "Element B must be uint8 or uint4");

    using OperatorInfo = arch::DetagOperator<Operator_>;
    using Operator = typename OperatorInfo::Operator;
    static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");

    static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
    using MmaCoreElementA = half_t;
    using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;

    // Define the MmaCore components
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
        MmaCoreElementA, LayoutA, MmaCoreElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2,
        Operator>;

    // Define iterators over tiles from the A operand
    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
        typename MmaCore::IteratorThreadMapA, kAlignmentA>;

    // Define iterators over tiles from the B operand
    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB, LayoutB, 0,
        typename MmaCore::IteratorThreadMapB, kAlignmentB>;

    using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
        OperatorInfo::QuantOp, kAlignmentScale>;

    // Define iterators over tiles from the scale operand
    using IteratorScale = typename ScaleIterators::IteratorScale;

    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

    using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
        ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
        typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
};

// Specialization to handle column major interleave B
template <
    typename ElementA,            /// Type for element A
    typename LayoutA,             /// Layout type for A matrix operand
    int kAlignmentA,              /// Access granularity of A matrix in units of elements
    typename ElementB,            /// Type for element B
    typename LayoutB,             /// Layout type for B matrix operand
    int kAlignmentB,              /// Access granularity of B matrix in units of elements
    typename ElementScale,        /// Element type for the input scale
    typename LayoutScale,         /// Layout for the scale operand
    int kAlignmentScale,          /// Access granularity of Scales in unit of elements
    typename ElementAccumulator,  /// Element type for internal accumulation
    typename OperatorClass,       /// Operator class tag
    typename ArchTag,             /// Tag indicating architecture to tune for
    typename ThreadblockShape,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,    /// Instruction-level tile size (concept: GemmShape)
    typename Operator_>           /// Operation performed by GEMM
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
    Operator_, SharedMemoryClearOption::kNone,
    typename platform::enable_if<(
        ArchTag::kMinComputeCapability < 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
        "Element A must be fp16 or bf16");

    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
        "Element B must be uint8 or uint4");

    using OperatorInfo = arch::DetagOperator<Operator_>;
    using Operator = typename OperatorInfo::Operator;

    static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
    using MmaCoreElementA = half_t;
    using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;

    // Define the MmaCore components
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
        MmaCoreElementA, LayoutA, MmaCoreElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor,
        OperatorClass, 2, Operator>;

    // Define iterators over tiles from the A operand
    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
        typename MmaCore::IteratorThreadMapA, kAlignmentA>;

private:
    static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
    static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
    static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
    static_assert(RowsPerTile == MmaCore::Shape::kK, "");

    using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
    using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
    static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");

    using GmemIteratorShape
        = MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
    using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
        layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
        layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
            OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
        MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;

public:
    // Define iterators over tiles from the B operand
    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<GmemIteratorShape, ElementB,
        layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;

    // ThreadMap for scale iterator
    static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
    using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
        layout::PitchLinearShape<MmaCore::Shape::kN, 1>, MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;

    using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
        OperatorInfo::QuantOp, kAlignmentScale>;

    // Define iterators over tiles from the scale operand
    using IteratorScale = typename ScaleIterators::IteratorScale;

    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

    using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
        ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
        typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2)
template <
    typename LayoutA,             /// Layout type for A matrix operand
    int kAlignmentA,              /// Access granularity of A matrix in units of elements
    typename LayoutB,             /// Layout type for B matrix operand
    int kAlignmentB,              /// Access granularity of B matrix in units of elements
    typename ElementAccumulator,  /// Element type for internal accumulation
    typename ArchTag,             /// Tag indicating architecture to tune for
    typename ThreadblockShape,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,    /// Instruction-level tile size (concept: GemmShape)
    typename Operator>            /// Operation performed by GEMM
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
        kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, 2, Operator>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
template <
    typename LayoutA,             /// Layout type for A matrix operand
    int kAlignmentA,              /// Access granularity of A matrix in units of elements
    typename LayoutB,             /// Layout type for B matrix operand
    int kAlignmentB,              /// Access granularity of B matrix in units of elements
    typename ElementAccumulator,  /// Element type for internal accumulation
    typename ArchTag,             /// Tag indicating architecture to tune for
    typename ThreadblockShape,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,    /// Instruction-level tile size (concept: GemmShape)
    typename Operator>            /// Operation performed by GEMM
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
        kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, 2, Operator>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
template <
    typename LayoutA,             /// Layout type for A matrix operand
    int kAlignmentA,              /// Access granularity of A matrix in units of elements
    typename LayoutB,             /// Layout type for B matrix operand
    int kAlignmentB,              /// Access granularity of B matrix in units of elements
    typename ElementAccumulator,  /// Element type for internal accumulation
    typename ArchTag,             /// Tag indicating architecture to tune for
    typename ThreadblockShape,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,    /// Instruction-level tile size (concept: GemmShape)
    typename Operator,            /// Operation performed by GEMM
    int kStages,                  ///
    SharedMemoryClearOption SharedMemoryClear>  /// Shared memory clear option
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
    false, SharedMemoryClear>
{
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
        kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
/// (stage>=3)
template <
    typename LayoutA,             /// Layout type for A matrix operand
    int kAlignmentA,              /// Access granularity of A matrix in units of elements
    typename LayoutB,             /// Layout type for B matrix operand
    int kAlignmentB,              /// Access granularity of B matrix in units of elements
    typename ElementAccumulator,  /// Element type for internal accumulation
    typename ArchTag,             /// Tag indicating architecture to tune for
    typename ThreadblockShape,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,    /// Instruction-level tile size (concept: GemmShape)
    typename Operator,            /// Operation performed by GEMM
    int kStages,                  ///
    SharedMemoryClearOption SharedMemoryClear>  /// Shared memory clear option
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
    false, SharedMemoryClear>
{
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
        kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
/// (stage>=3)
template <
    typename LayoutA,             /// Layout type for A matrix operand
    int kAlignmentA,              /// Access granularity of A matrix in units of elements
    typename LayoutB,             /// Layout type for B matrix operand
    int kAlignmentB,              /// Access granularity of B matrix in units of elements
    typename ElementAccumulator,  /// Element type for internal accumulation
    typename ArchTag,             /// Tag indicating architecture to tune for
    typename ThreadblockShape,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,    /// Instruction-level tile size (concept: GemmShape)
    typename Operator,            /// Operation performed by GEMM
    int kStages,                  ///
    SharedMemoryClearOption SharedMemoryClear>  /// Shared memory clear option
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
    false, SharedMemoryClear>
{
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t,
        layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
        ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};
#endif

// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
// large tile when not enough shared mem is present to do 3+ stage
template <
    typename LayoutA,             /// Layout type for A matrix operand
    int kAlignmentA,              /// Access granularity of A matrix in units of elements
    typename LayoutB,             /// Layout type for B matrix operand
    int kAlignmentB,              /// Access granularity of B matrix in units of elements
    typename ElementAccumulator,  /// Element type for internal accumulation
    typename ThreadblockShape,    /// Threadblock-level tile size (concept: GemmShape)
    typename WarpShape,           /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,    /// Instruction-level tile size (concept: GemmShape)
    typename Operator,            /// Operation performed by GEMM
    SharedMemoryClearOption SharedMemoryClear,  /// Use zfill or predicate for out-of-bound cp.async
    bool GatherA,                 /// Gather operand A by using an index array
    bool GatherB>                 /// Gather operand B by using an index array
struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
    arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
    SharedMemoryClear, GatherA, GatherB>
{
    // Define the MmaCore components
    // 3 is used on purpose here to trigger components for mma multistage
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
        half_t, LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA, AccessTypeA,
        GatherA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB, AccessTypeB,
        GatherB>;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
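A side note on the kAlignmentScale constant that recurs in the specializations above: 128 / sizeof_bits<T>::value simply sizes scale-vector accesses to 128 bits, so 16-bit scale types give 8 elements per access. A minimal stand-in check, using 8 * sizeof(T) in place of cutlass::sizeof_bits for standard types:

// Stand-in check of the 128-bit scale-access sizing; std::uint16_t models a 16-bit scale type such as half_t.
#include <cstdint>

template <typename T>
constexpr int scale_alignment() { return 128 / (8 * static_cast<int>(sizeof(T))); }

static_assert(scale_alignment<std::uint16_t>() == 8, "fp16/bf16 scales: 8 elements per 128-bit access");
static_assert(scale_alignment<float>() == 4, "fp32 scales would give 4 elements per 128-bit access");

int main() { return 0; }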
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h
deleted 100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
namespace
cutlass
{
namespace
gemm
{
namespace
threadblock
{
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
template
<
/// Layout type for A matrix operand
typename
LayoutA
,
/// Access granularity of A matrix in units of elements
int
kAlignmentA
,
/// Layout type for B matrix operand
typename
LayoutB
,
/// Access granularity of B matrix in units of elements
int
kAlignmentB
,
/// Element type for internal accumulation
typename
ElementAccumulator
,
/// Tag indicating architecture to tune for
typename
ArchTag
,
/// Threadblock-level tile size (concept: GemmShape)
typename
ThreadblockShape
,
/// Warp-level tile size (concept: GemmShape)
typename
WarpShape
,
/// Instruction-level tile size (concept: GemmShape)
typename
InstructionShape
,
/// Operation performed by GEMM
typename
Operator
,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption
SharedMemoryClear
,
/// Gather operand A by using an index array
bool
GatherA
,
/// Gather operand B by using an index array
bool
GatherB
>
struct
DefaultMma
<
bfloat16_t
,
LayoutA
,
kAlignmentA
,
bfloat16_t
,
LayoutB
,
kAlignmentB
,
ElementAccumulator
,
layout
::
RowMajor
,
arch
::
OpClassTensorOp
,
ArchTag
,
ThreadblockShape
,
WarpShape
,
InstructionShape
,
2
,
Operator
,
false
,
SharedMemoryClear
,
GatherA
,
GatherB
>
{
private:
// Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS.
static
constexpr
bool
arch_has_bf16_mma
=
ArchTag
::
kMinComputeCapability
>=
80
;
using
MmaElementA
=
typename
platform
::
conditional
<
arch_has_bf16_mma
,
bfloat16_t
,
half_t
>::
type
;
using
MmaElementB
=
typename
platform
::
conditional
<
arch_has_bf16_mma
,
bfloat16_t
,
half_t
>::
type
;
public:
// Define the MmaCore components
using
MmaCore
=
typename
cutlass
::
gemm
::
threadblock
::
DefaultMmaCore
<
ThreadblockShape
,
WarpShape
,
InstructionShape
,
MmaElementA
,
LayoutA
,
MmaElementB
,
LayoutB
,
ElementAccumulator
,
layout
::
RowMajor
,
arch
::
OpClassTensorOp
,
2
,
Operator
>
;
using
IteratorA
=
cutlass
::
transform
::
threadblock
::
PredicatedTileIterator
<
cutlass
::
MatrixShape
<
MmaCore
::
Shape
::
kM
,
MmaCore
::
Shape
::
kK
>
,
bfloat16_t
,
LayoutA
,
1
,
typename
MmaCore
::
IteratorThreadMapA
,
kAlignmentA
,
GatherA
>
;
// Define iterators over tiles from the B operand
using
IteratorB
=
cutlass
::
transform
::
threadblock
::
PredicatedTileIterator
<
cutlass
::
MatrixShape
<
MmaCore
::
Shape
::
kK
,
MmaCore
::
Shape
::
kN
>
,
bfloat16_t
,
LayoutB
,
0
,
typename
MmaCore
::
IteratorThreadMapB
,
kAlignmentB
,
GatherB
>
;
// Define the threadblock-scoped pipelined matrix multiply
using
ThreadblockMma
=
cutlass
::
gemm
::
threadblock
::
MmaPipelined
<
typename
MmaCore
::
Shape
,
IteratorA
,
typename
MmaCore
::
SmemIteratorA
,
IteratorB
,
typename
MmaCore
::
SmemIteratorB
,
ElementAccumulator
,
layout
::
RowMajor
,
typename
MmaCore
::
MmaPolicy
>
;
};
// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
// large tile when not enough shared mem is present to do 3+ stage
template
<
/// Layout type for A matrix operand
typename
LayoutA
,
/// Access granularity of A matrix in units of elements
int
kAlignmentA
,
/// Layout type for B matrix operand
typename
LayoutB
,
/// Access granularity of B matrix in units of elements
int
kAlignmentB
,
/// Element type for internal accumulation
typename
ElementAccumulator
,
/// Threadblock-level tile size (concept: GemmShape)
typename
ThreadblockShape
,
/// Warp-level tile size (concept: GemmShape)
typename
WarpShape
,
/// Instruction-level tile size (concept: GemmShape)
typename
InstructionShape
,
/// Operation performed by GEMM
typename
Operator
,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption
SharedMemoryClear
,
/// Gather operand A by using an index array
bool
GatherA
,
/// Gather operand B by using an index array
bool
GatherB
>
struct
DefaultMma
<
bfloat16_t
,
LayoutA
,
kAlignmentA
,
bfloat16_t
,
LayoutB
,
kAlignmentB
,
ElementAccumulator
,
layout
::
RowMajor
,
arch
::
OpClassTensorOp
,
arch
::
Sm80
,
ThreadblockShape
,
WarpShape
,
InstructionShape
,
2
,
Operator
,
false
,
SharedMemoryClear
,
GatherA
,
GatherB
>
{
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using
MmaCore
=
typename
cutlass
::
gemm
::
threadblock
::
DefaultMmaCore
<
ThreadblockShape
,
WarpShape
,
InstructionShape
,
bfloat16_t
,
LayoutA
,
bfloat16_t
,
LayoutB
,
ElementAccumulator
,
layout
::
RowMajor
,
arch
::
OpClassTensorOp
,
3
,
Operator
>
;
// Define iterators over tiles from the A operand
using
ThreadMapA
=
typename
MmaCore
::
IteratorThreadMapA
;
using
AccessTypeA
=
cutlass
::
Array
<
bfloat16_t
,
kAlignmentA
>
;
using
IteratorA
=
cutlass
::
transform
::
threadblock
::
PredicatedTileAccessIterator
<
cutlass
::
MatrixShape
<
ThreadblockShape
::
kM
,
ThreadblockShape
::
kK
>
,
bfloat16_t
,
LayoutA
,
1
,
ThreadMapA
,
AccessTypeA
,
GatherA
>
;
// Define iterators over tiles from the B operand
using
ThreadMapB
=
typename
MmaCore
::
IteratorThreadMapB
;
using
AccessTypeB
=
cutlass
::
Array
<
bfloat16_t
,
kAlignmentB
>
;
using
IteratorB
=
cutlass
::
transform
::
threadblock
::
PredicatedTileAccessIterator
<
cutlass
::
MatrixShape
<
ThreadblockShape
::
kK
,
ThreadblockShape
::
kN
>
,
bfloat16_t
,
LayoutB
,
0
,
ThreadMapB
,
AccessTypeB
,
GatherB
>
;
// Define the threadblock-scoped multistage matrix multiply
using
ThreadblockMma
=
cutlass
::
gemm
::
threadblock
::
MmaMultistage
<
typename
MmaCore
::
Shape
,
IteratorA
,
typename
MmaCore
::
SmemIteratorA
,
MmaCore
::
kCacheOpA
,
IteratorB
,
typename
MmaCore
::
SmemIteratorB
,
MmaCore
::
kCacheOpB
,
ElementAccumulator
,
layout
::
RowMajor
,
typename
MmaCore
::
MmaPolicy
,
2
>
;
};
////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
template <
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
        kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, 2, Operator>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};
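////////////////////////////////////////////////////////////////////////////////

// Editor's illustrative note (not part of the original header): the kAlignmentScale constant
// used by the bf16/int8 specialization above works out to
// 128 / sizeof_bits<bfloat16_t>::value = 128 / 16 = 8, i.e. the dequantization scales are
// fetched with 128-bit (8 x bf16) vector accesses. A minimal compile-time check of that
// arithmetic:
static_assert(128 / sizeof_bits<bfloat16_t>::value == 8,
    "bf16 dequantization scales are expected to be accessed 8 elements (128 bits) at a time");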
////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template <
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
        kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, 2, Operator>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight, multi-stage
template <
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Number of stages used in the pipelined mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
    false, SharedMemoryClear>
{
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
        kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight, multi-stage
template <
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Number of stages used in the pipelined mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
    false, SharedMemoryClear>
{
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
        kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};
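////////////////////////////////////////////////////////////////////////////////

// Editor's illustrative note (not part of the original header): the int8 and int4 weight
// specializations above describe B-operand access granularity in units of elements, so a
// 128-bit vectorized load covers 128 / 8 = 16 uint8_t weights or 128 / 4 = 32 uint4b_t
// weights. A minimal compile-time check of that packing arithmetic:
static_assert(128 / sizeof_bits<uint8_t>::value == 16 && 128 / sizeof_bits<uint4b_t>::value == 32,
    "a 128-bit access covers 16 int8 weights or 32 int4 weights");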
} // namespace threadblock
} // namespace gemm
} // namespace cutlass