sglang, commit 3ee62235 (Unverified)
Authored Jan 31, 2025 by Yineng Zhang; committed by GitHub, Jan 31, 2025
Parent: 9829e77e

revert the MoE dependence (#3230)

Changes: 94. Showing 20 changed files with 0 additions and 5321 deletions (+0, -5321).
+0 -257   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h
+0 -110   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h
+0 -708   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
+0 -647   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h
+0 -106   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h
+0 -486   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h
+0 -399   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h
+0 -107   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h
+0 -306   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h
+0 -463   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h
+0 -224   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
+0 -447   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+0 -66    sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h
+0 -250   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h
+0 -181   sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp
+0 -58    sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h
+0 -25    sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h
+0 -96    sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
+0 -37    sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h
+0 -348   sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h (deleted, 100644 → 0)
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/weight_only_quant_op.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

// SFINAE trick so I can keep the same loop code for Volta and dispatch to the
// correct warp level mma. On volta, all data is stored to shared memory as FP16.
template <typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
    typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C,
    int const warp_tileB_k_offset)
{
    warp_mma(D, A, B, C);
}

template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
    typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B,
    typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset)
{
    warp_mma(D, A, B, C, warp_tileB_k_offset);
}

////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// The type of the scales
    typename ElementScale_,
    /// Number of stages,
    int Stages,
    /// The dequantizing op to be performed.
    WeightOnlyQuantOp DequantOp,
    /// Used for partial specialization,
    typename Enable = bool>
class DqMmaBase
{
public:
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;

    ///< Policy describing tuning details
    using Policy = Policy_;

    ///< Type of the scale to be loaded
    using ElementScale = ElementScale_;

    static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, "");

    // Finegrained scales get streamed in via cp.async
    static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1;
    // We always have scales.
    static constexpr int ScaleElementsPerStage = Shape::kN;
    // We sometimes have a bias
    static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0;

    //
    // Dependent types
    //

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Shape describing the overall GEMM computed from shared memory
    /// by each warp.
    using WarpGemm = typename Policy::Operator::Shape;

    /// Shape describing the number of warps filling the CTA
    using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;

    /// Number of warp-level GEMM operations
    static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);

    static constexpr int kNumKIterationsPerWarpBLoad
        = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;

    static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
    static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;

    /// Number of stages
    static int const kStages = Stages;

    /// Tensor reference to the A operand
    using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

    /// Tensor reference to the B operand
    using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

    //
    // Nested structs
    //

    /// Shared storage object needed by threadblock-scoped GEMM
    class SharedStorage
    {
    public:
        //
        // Type definitions
        //

        /// Shape of the A matrix operand in shared memory
        using ShapeA
            = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

        /// Shape of the B matrix operand in shared memory
        using ShapeB
            = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;

        /// Shape of the shared memory buffer for the scales for the B matrix.
        using ShapeScale = MatrixShape<ScalebiasStages, ScaleElementsPerStage>;
        /// Shape of the shared memory buffer for the biases of the B matrix.
        using ShapeZero = MatrixShape<ScalebiasStages, BiasElementsPerStage>;

    public:
        //
        // Data members
        //

        /// Buffer for A operand
        AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

        /// Buffer for B operand
        AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

        /// Buffer to hold scales for threadblock
        AlignedBuffer<ElementScale, ShapeScale::kCount> operand_scale;

        /// Buffer to hold scales for threadblock
        AlignedBuffer<ElementScale, ShapeZero::kCount> operand_zero;

    public:
        //
        // Methods
        //

        /// Returns a layout object for the A matrix
        CUTLASS_DEVICE
        static typename Operator::LayoutA LayoutA()
        {
            return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
        }

        /// Returns a layout object for the B matrix
        CUTLASS_HOST_DEVICE
        static typename Operator::LayoutB LayoutB()
        {
            return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
        }

        /// Returns a TensorRef to the A operand
        CUTLASS_HOST_DEVICE
        TensorRefA operand_A_ref()
        {
            return TensorRefA{operand_A.data(), LayoutA()};
        }

        /// Returns a TensorRef to the B operand
        CUTLASS_HOST_DEVICE
        TensorRefB operand_B_ref()
        {
            return TensorRefB{operand_B.data(), LayoutB()};
        }
    };

protected:
    //
    // Data members
    //

    /// Iterator to load a warp-scoped tile of A operand from shared memory
    typename Operator::IteratorA warp_tile_iterator_A_;

    /// Iterator to load a warp-scoped tile of B operand from shared memory
    typename Operator::IteratorB warp_tile_iterator_B_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaBase(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        SharedStorage& shared_storage,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx)
        : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx)
        , warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
    {
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
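Aside: the warp bookkeeping in DqMmaBase above is pure compile-time arithmetic. The standalone sketch below recomputes WarpCount and kWarpGemmIterations for one hypothetical configuration (a 128x128x64 threadblock tile, 64x64x64 warp tiles, and a 16-wide MMA K dimension); the concrete numbers are illustrative assumptions, not values taken from the deleted kernels.

// Standalone sketch (not part of the deleted file): shows how DqMmaBase's
// tile constants relate, using assumed tile sizes.
#include <cstdio>

int main()
{
    // Assumed threadblock tile (Shape), warp tile (WarpGemm), and tensor-core
    // MMA K extent (Operator::Policy::MmaShape::kK).
    constexpr int kShapeM = 128, kShapeN = 128, kShapeK = 64;
    constexpr int kWarpM = 64, kWarpN = 64, kWarpK = 64;
    constexpr int kMmaK = 16;

    // WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>
    constexpr int kWarpCountM = kShapeM / kWarpM; // 2
    constexpr int kWarpCountN = kShapeN / kWarpN; // 2
    constexpr int kWarpCountK = kShapeK / kWarpK; // 1

    // kWarpGemmIterations = WarpGemm::kK / Operator::Policy::MmaShape::kK
    constexpr int kWarpGemmIterations = kWarpK / kMmaK; // 4

    std::printf("warps per CTA: %d x %d x %d, warp-level MMA iterations per K tile: %d\n",
        kWarpCountM, kWarpCountN, kWarpCountK, kWarpGemmIterations);
    return 0;
}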
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h (deleted, 100644 → 0)
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Data type for the scales
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages,
    int Stages,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    /// Used for partial specialization
    typename Enable = void>
class DqMmaMultistage;

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h"
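Aside: this header only forward-declares DqMmaMultistage; the two includes above pull in partial specializations that are selected through std::enable_if_t over isFinegrained(QuantOp_). The standalone sketch below mirrors that dispatch with simplified, hypothetical names (Mainloop and a two-value WeightOnlyQuantOp) rather than the real CUTLASS extension types.

// Standalone analogue of the enable_if-based specialization selection used by
// the finegrained and per-column mainloop headers. Names are simplified.
#include <cstdio>
#include <type_traits>

enum class WeightOnlyQuantOp { PER_COLUMN_SCALE_ONLY, FINEGRAINED_SCALE_ONLY };

constexpr bool isFinegrained(WeightOnlyQuantOp op)
{
    return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
}

// Declaration only, mirroring dq_mma_multistage.h.
template <WeightOnlyQuantOp QuantOp, typename Enable = void>
struct Mainloop;

// Selected when the quant op is finegrained, mirroring dq_mma_multistage_finegrained.h.
template <WeightOnlyQuantOp QuantOp>
struct Mainloop<QuantOp, std::enable_if_t<isFinegrained(QuantOp)>>
{
    static constexpr char const* name = "finegrained (scales streamed per stage)";
};

// Selected otherwise, mirroring dq_mma_multistage_percol.h.
template <WeightOnlyQuantOp QuantOp>
struct Mainloop<QuantOp, std::enable_if_t<!isFinegrained(QuantOp)>>
{
    static constexpr char const* name = "per-column (scales loaded once in the prologue)";
};

int main()
{
    std::printf("%s\n", Mainloop<WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>::name);
    std::printf("%s\n", Mainloop<WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>::name);
    return 0;
}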
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h (deleted, 100644 → 0)
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Iterators over scales in global memory
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages,
    int Stages,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear>
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
    IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
    SharedMemoryClear, std::enable_if_t<isFinegrained(QuantOp_)>>
    : public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
{
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;
    ///< Iterates over tiles of A operand in global memory
    using IteratorA = IteratorA_;
    ///< Iterates over tiles of B operand in global memory
    using IteratorB = IteratorB_;
    ///< Data type of accumulator matrix
    using ElementC = ElementC_;
    ///< Layout of accumulator matrix
    using LayoutC = LayoutC_;
    ///< Policy describing tuning details
    using Policy = Policy_;

    using IteratorScale = IteratorScale_;
    using ElementScale = typename IteratorScale::Element;
    using LayoutScale = typename IteratorScale::Layout;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
    static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

    using TransformBAfterLDS = TransformBAfterLDS_;

    static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

    //
    // Dependent types
    //

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Minimum architecture is Sm80 to support cp.async
    using ArchTag = arch::Sm80;

    using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
        LayoutScale, 32, QuantOp>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, "");
    static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");

    /// Internal structure exposed for introspection.
    struct Detail
    {
        static_assert(Base::kWarpGemmIterations > 1,
            "The pipelined structure requires at least two warp-level "
            "GEMM operations.");

        /// Number of cp.async instructions to load one stage of operand A
        static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;

        /// Number of cp.async instructions to load one stage of operand B
        static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;

        /// Number of stages
        static int const kStages = Stages;

        /// Number of cp.async instructions to load on group of operand A
        static int const kAccessesPerGroupA
            = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;

        /// Number of cp.async instructions to load on group of operand B
        static int const kAccessesPerGroupB
            = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
    };

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;

    Dequantizer warp_dequantizer_;

    using ElementA = typename IteratorA::Element;
    using ElementB = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave
        = layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
        "Layout K must match threadblockK");

private:
    //
    // Data members
    //

    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaMultistage(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        typename Base::SharedStorage& shared_storage,
        /// The group size for quantization
        int const group_size,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx)
        : Base(shared_storage, thread_idx, warp_idx, lane_idx)
        , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
              {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
              (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
        , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
        , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
        , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
              shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
    {
        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }

    CUTLASS_DEVICE
    void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1)
    {
        static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1.");

        typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale();
        typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero();

        typename IteratorScale::AccessType* smem_scale_ptr
            = reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_scale());
        typename IteratorScale::AccessType* smem_zero_ptr
            = reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_zero());

        int const kSrcBytes = sizeof_bits<typename IteratorScale::Element>::value * IteratorScale::kAlignment / 8;

        cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid());

        if (gmem_zero_ptr != nullptr)
        {
            cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid());
        }

        if (iterator_scale.group_size_ == 64)
        {
            iterator_scale.add_tile_offset({1, 0});
        }
        else if (iterator_scale.group_size_ == 128)
        {
            if constexpr (Shape::kK == 128)
            {
                iterator_scale.add_tile_offset({1, 0});
            }
            else if constexpr (Shape::kK == 64)
            {
                if (iterator_scale.row_groupsize64_ & 0x1)
                {
                    iterator_scale.add_tile_offset({1, 0});
                }
            }
            else
            {
                static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
            }
        }

        iterator_scale.row_groupsize64_++;

        this->smem_iterator_scale_.add_tile_offset({1, 0});
    }

    CUTLASS_DEVICE
    void copy_tiles_and_advance(
        IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
    {
        iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
        this->smem_iterator_A_.set_iteration_index(group_start_A);

        // Async Copy for operand A
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
        {
            if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());

                int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
                    * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
                {
                    auto gmem_ptr = iterator_A.get();

                    if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
                    {
                        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
                    }
                    else
                    {
                        cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
                    }

                    ++iterator_A;
                }

                ++this->smem_iterator_A_;
            }
        }

        iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
        this->smem_iterator_B_.set_iteration_index(group_start_B);

        // Async Copy for operand B
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
        {
            if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());

                int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
                    * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
                {
                    auto gmem_ptr = iterator_B.get();

                    if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
                    {
                        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
                    }
                    else
                    {
                        cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
                    }

                    ++iterator_B;
                }

                ++this->smem_iterator_B_;
            }
        }
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(
        ///< problem size of GEMM
        int gemm_k_iterations,
        ///< destination accumulator tile
        FragmentC& accum,
        ///< iterator over A operand in global memory
        IteratorA iterator_A,
        ///< iterator over B operand in global memory
        IteratorB iterator_B,
        ///< iterator over scale operand in global memory
        IteratorScale iterator_scale,
        ///< initial value of accumulator
        FragmentC const& src_accum)
    {
        //
        // Prologue
        //

        TransformBAfterLDS lds_converter;

        // Issue several complete stages
        CUTLASS_PRAGMA_UNROLL
        for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
        {
            iterator_A.clear_mask(gemm_k_iterations == 0);
            iterator_B.clear_mask(gemm_k_iterations == 0);
            iterator_scale.clear_mask(gemm_k_iterations == 0);

            iterator_A.set_iteration_index(0);
            this->smem_iterator_A_.set_iteration_index(0);

            // Async Copy for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
                {
                    int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
                        * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                        dst_ptr + v, iterator_A.get(), iterator_A.valid());

                    ++iterator_A;
                }

                ++this->smem_iterator_A_;
            }

            iterator_B.set_iteration_index(0);
            this->smem_iterator_B_.set_iteration_index(0);

            // Async Copy for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
                {
                    int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
                        * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                        dst_ptr + v, iterator_B.get(), iterator_B.valid());

                    ++iterator_B;
                }

                ++this->smem_iterator_B_;
            }

            copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations);

            // Move to the next stage
            iterator_A.add_tile_offset({0, 1});
            iterator_B.add_tile_offset({1, 0});

            this->smem_iterator_A_.add_tile_offset({0, 1});
            this->smem_iterator_B_.add_tile_offset({1, 0});

            // Defines the boundary of a stage of cp.async.
            cutlass::arch::cp_async_fence();
        }

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        //
        // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
        // so that all accumulator elements outside the GEMM footprint are zero.
        //

        if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
        {
            /// Iterator to write threadblock-scoped tile of A operand to shared memory
            SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);

            typename IteratorA::AccessType zero_A;
            zero_A.clear();

            last_smem_iterator_A.set_iteration_index(0);

            // Async Copy for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());

                *dst_ptr = zero_A;

                ++last_smem_iterator_A;
            }

            /// Iterator to write threadblock-scoped tile of B operand to shared memory
            SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);

            typename IteratorB::AccessType zero_B;
            zero_B.clear();

            last_smem_iterator_B.set_iteration_index(0);

            // Async Copy for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());

                *dst_ptr = zero_B;

                ++last_smem_iterator_B;
            }
        }

        // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
        cutlass::arch::cp_async_wait<Base::kStages - 2>();
        __syncthreads();

        // Pair of fragments used to overlap shared memory loads and math
        // instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];
        typename Dequantizer::FragmentScale warp_frag_scales;
        typename Dequantizer::FragmentZero warp_frag_zeros;

        Operator warp_mma;

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);
        warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;
        warp_dequantizer_.add_pointer_offset(Shape::kN);

        iterator_A.clear_mask(gemm_k_iterations == 0);
        iterator_B.clear_mask(gemm_k_iterations == 0);
        iterator_scale.clear_mask(gemm_k_iterations == 0);

        int smem_write_stage_idx = Base::kStages - 1;
        int smem_read_stage_idx = 0;

        //
        // Mainloop
        //

        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > (-Base::kStages + 1);)
        {
            //
            // Loop over GEMM K dimension
            //

            // Computes a warp-level GEMM on data held in shared memory
            // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
            {
                // Load warp-level tiles from shared memory, wrapping to k offset if
                // this is the last group as the case may be.
                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
                {
                    this->warp_tile_iterator_B_.set_kgroup_index(
                        (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                typename TransformBAfterLDS::result_type converted_frag_B
                    = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);

                using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
                constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
                constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
                static_assert(ConversionVectorWidth == FragmentOperandB::kElements);

                using Converter
                    = cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;

                FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
                run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
                    warp_tileB_k_compute_offset);

                // Issue global->shared copies for this stage
                if (warp_mma_k < Base::kWarpGemmIterations - 1)
                {
                    int group_start_iteration_A, group_start_iteration_B;

                    group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
                    group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;

                    copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);

                    // This is the first group of a given stage, so we issue the loads for the B scales immediately.
                    if (group_start_iteration_B == 0)
                    {
                        copy_scales_and_advance(iterator_scale);
                    }
                }

                if (warp_mma_k + 2 == Base::kWarpGemmIterations)
                {
                    int group_start_iteration_A, group_start_iteration_B;
                    group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
                    group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;

                    copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);

                    // Inserts a memory fence between stages of cp.async instructions.
                    cutlass::arch::cp_async_fence();

                    // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
                    // #committed)
                    arch::cp_async_wait<Base::kStages - 2>();
                    __syncthreads();

                    // Move to the next stage
                    iterator_A.add_tile_offset({0, 1});
                    iterator_B.add_tile_offset({1, 0});

                    this->smem_iterator_A_.add_tile_offset({0, 1});
                    this->smem_iterator_B_.add_tile_offset({1, 0});

                    // Add negative offsets to return iterators to the 'start' of the
                    // circular buffer in shared memory
                    if (smem_write_stage_idx == (Base::kStages - 1))
                    {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                        this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
                        smem_write_stage_idx = 0;
                    }
                    else
                    {
                        ++smem_write_stage_idx;
                    }

                    if (smem_read_stage_idx == (Base::kStages - 1))
                    {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                        warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
                        smem_read_stage_idx = 0;
                    }
                    else
                    {
                        ++smem_read_stage_idx;
                    }

                    --gemm_k_iterations;
                    iterator_A.clear_mask(gemm_k_iterations == 0);
                    iterator_B.clear_mask(gemm_k_iterations == 0);
                    iterator_scale.clear_mask(gemm_k_iterations == 0);
                }
            }

            // Load the scale needed for the next tile iteration.
            warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
            // Update internal pointer to set of scales in shared memory.
            warp_dequantizer_.add_pointer_offset(Shape::kN);
        }

        if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
        {
            // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
            cutlass::arch::cp_async_fence();
            cutlass::arch::cp_async_wait<0>();
            __syncthreads();
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
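Aside: in copy_scales_and_advance above, the global scale pointer advances once per 64-wide group along K, so with group_size 128 and a 64-wide threadblock K tile it only advances on every second tile (tracked by the parity of row_groupsize64_). The standalone sketch below replays that rule with assumed tile counts; it is an illustration, not code from the deleted file.

// Standalone sketch of the scale-advance rule for the finegrained mainloop,
// assuming ThreadblockK = 64 and four K tiles.
#include <cstdio>
#include <initializer_list>

int main()
{
    constexpr int kThreadblockK = 64;
    for (int group_size : {64, 128})
    {
        int scale_row = 0;
        int row_groupsize64 = 0;
        std::printf("group_size=%d, ThreadblockK=%d:\n", group_size, kThreadblockK);
        for (int k_tile = 0; k_tile < 4; ++k_tile)
        {
            if (group_size == 64)
            {
                ++scale_row; // every 64-wide K tile starts a new scale row
            }
            else if (group_size == 128 && (row_groupsize64 & 0x1))
            {
                ++scale_row; // advance only after two 64-wide K tiles
            }
            ++row_groupsize64;
            std::printf("  after K tile %d -> scale row %d\n", k_tile, scale_row);
        }
    }
    return 0;
}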
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h (deleted, 100644 → 0)
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Iterators over scales in global memory
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages,
    int Stages,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear>
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
    IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
    SharedMemoryClear, std::enable_if_t<!isFinegrained(QuantOp_)>>
    : public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
{
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;
    ///< Iterates over tiles of A operand in global memory
    using IteratorA = IteratorA_;
    ///< Iterates over tiles of B operand in global memory
    using IteratorB = IteratorB_;
    ///< Data type of accumulator matrix
    using ElementC = ElementC_;
    ///< Layout of accumulator matrix
    using LayoutC = LayoutC_;
    ///< Policy describing tuning details
    using Policy = Policy_;

    using IteratorScale = IteratorScale_;
    using ElementScale = typename IteratorScale::Element;
    using LayoutScale = typename IteratorScale::Layout;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
    static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

    using TransformBAfterLDS = TransformBAfterLDS_;

    static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

    //
    // Dependent types
    //

    /// Fragment of operand Scale loaded from global memory;
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Minimum architecture is Sm80 to support cp.async
    using ArchTag = arch::Sm80;

    using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
        LayoutScale, 32, QuantOp>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    /// Internal structure exposed for introspection.
    struct Detail
    {
        static_assert(Base::kWarpGemmIterations > 1,
            "The pipelined structure requires at least two warp-level "
            "GEMM operations.");

        /// Number of cp.async instructions to load one stage of operand A
        static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;

        /// Number of cp.async instructions to load one stage of operand B
        static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;

        /// Number of stages
        static int const kStages = Stages;

        /// Number of cp.async instructions to load on group of operand A
        static int const kAccessesPerGroupA
            = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;

        /// Number of cp.async instructions to load on group of operand B
        static int const kAccessesPerGroupB
            = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
    };

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;

    Dequantizer warp_dequantizer_;

    using ElementA = typename IteratorA::Element;
    using ElementB = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave
        = layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
        "Layout K must match threadblockK");

private:
    //
    // Data members
    //

    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaMultistage(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        typename Base::SharedStorage& shared_storage,
        ///< Group size for quantization. Not used by this main loop since it assumes per-column
        int const group_size,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx)
        : Base(shared_storage, thread_idx, warp_idx, lane_idx)
        , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
              (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
        , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
        , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
        , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
    {
        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }

    CUTLASS_DEVICE
    void copy_tiles_and_advance(
        IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
    {
        iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
        this->smem_iterator_A_.set_iteration_index(group_start_A);

        // Async Copy for operand A
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
        {
            if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());

                int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
                    * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
                {
                    auto gmem_ptr = iterator_A.get();

                    if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
                    {
                        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
                    }
                    else
                    {
                        cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
                    }

                    ++iterator_A;
                }

                ++this->smem_iterator_A_;
            }
        }

        iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
        this->smem_iterator_B_.set_iteration_index(group_start_B);

        // Async Copy for operand B
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
        {
            if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());

                int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
                    * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
                {
                    auto gmem_ptr = iterator_B.get();

                    if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
                    {
                        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
                    }
                    else
                    {
                        cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
                    }

                    ++iterator_B;
                }

                ++this->smem_iterator_B_;
            }
        }
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(
        ///< problem size of GEMM
        int gemm_k_iterations,
        ///< destination accumulator tile
        FragmentC& accum,
        ///< iterator over A operand in global memory
        IteratorA iterator_A,
        ///< iterator over B operand in global memory
        IteratorB iterator_B,
        ///< iterator over scale operand in global memory
        IteratorScale iterator_scale,
        ///< initial value of accumulator
        FragmentC const& src_accum)
    {
        //
        // Prologue
        //

        TransformBAfterLDS lds_converter;

        // NOTE - switch to ldg.sts
        // Issue this first, so cp.async.commit_group will commit this load as well.
        // Note: we do not commit here and this load will commit in the same group as
        // the first load of A.
        FragmentScale tb_frag_scales;
        tb_frag_scales.clear();
        iterator_scale.load(tb_frag_scales);
        this->smem_iterator_scale_.store(tb_frag_scales);

        // Issue several complete stages
        CUTLASS_PRAGMA_UNROLL
        for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
        {
            iterator_A.clear_mask(gemm_k_iterations == 0);
            iterator_B.clear_mask(gemm_k_iterations == 0);

            iterator_A.set_iteration_index(0);
            this->smem_iterator_A_.set_iteration_index(0);

            // Async Copy for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
                {
                    int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
                        * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;

                    int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                        dst_ptr + v, iterator_A.get(), iterator_A.valid());

                    ++iterator_A;
                }

                ++this->smem_iterator_A_;
            }

            iterator_B.set_iteration_index(0);
            this->smem_iterator_B_.set_iteration_index(0);

            // Async Copy for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
                {
                    int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
                        * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                        dst_ptr + v, iterator_B.get(), iterator_B.valid());

                    ++iterator_B;
                }

                ++this->smem_iterator_B_;
            }

            // Move to the next stage
            iterator_A.add_tile_offset({0, 1});
            iterator_B.add_tile_offset({1, 0});

            this->smem_iterator_A_.add_tile_offset({0, 1});
            this->smem_iterator_B_.add_tile_offset({1, 0});

            // Defines the boundary of a stage of cp.async.
            cutlass::arch::cp_async_fence();
        }

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        //
        // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
        // so that all accumulator elements outside the GEMM footprint are zero.
        //

        if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
        {
            /// Iterator to write threadblock-scoped tile of A operand to shared memory
            SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);

            typename IteratorA::AccessType zero_A;
            zero_A.clear();

            last_smem_iterator_A.set_iteration_index(0);

            // Async Copy for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());

                *dst_ptr = zero_A;

                ++last_smem_iterator_A;
            }

            /// Iterator to write threadblock-scoped tile of B operand to shared memory
            SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);

            typename IteratorB::AccessType zero_B;
            zero_B.clear();

            last_smem_iterator_B.set_iteration_index(0);

            // Async Copy for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());

                *dst_ptr = zero_B;

                ++last_smem_iterator_B;
            }
        }

        // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
        cutlass::arch::cp_async_wait<Base::kStages - 2>();
::
kStages
-
2
>
();
__syncthreads
();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA
warp_frag_A
[
2
];
WarpFragmentB
warp_frag_B
[
2
];
typename
Dequantizer
::
FragmentScale
warp_frag_scales
;
Operator
warp_mma
;
this
->
warp_tile_iterator_A_
.
set_kgroup_index
(
0
);
this
->
warp_tile_iterator_B_
.
set_kgroup_index
(
0
);
this
->
warp_tile_iterator_A_
.
load
(
warp_frag_A
[
0
]);
this
->
warp_tile_iterator_B_
.
load
(
warp_frag_B
[
0
]);
warp_dequantizer_
.
load
(
warp_frag_scales
);
++
this
->
warp_tile_iterator_A_
;
++
this
->
warp_tile_iterator_B_
;
iterator_A
.
clear_mask
(
gemm_k_iterations
==
0
);
iterator_B
.
clear_mask
(
gemm_k_iterations
==
0
);
int
smem_write_stage_idx
=
Base
::
kStages
-
1
;
int
smem_read_stage_idx
=
0
;
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for
(;
gemm_k_iterations
>
(
-
Base
::
kStages
+
1
);)
{
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for
(
int
warp_mma_k
=
0
;
warp_mma_k
<
Base
::
kWarpGemmIterations
;
++
warp_mma_k
)
{
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
this
->
warp_tile_iterator_A_
.
set_kgroup_index
((
warp_mma_k
+
1
)
%
Base
::
kWarpGemmIterations
);
this
->
warp_tile_iterator_A_
.
load
(
warp_frag_A
[(
warp_mma_k
+
1
)
%
2
]);
++
this
->
warp_tile_iterator_A_
;
int
const
warp_tileB_k_compute_offset
=
warp_mma_k
%
Base
::
kNumKIterationsPerWarpBLoad
;
int
const
warp_tileB_k_load_offset
=
warp_mma_k
/
Base
::
kNumKIterationsPerWarpBLoad
;
if
(
warp_tileB_k_compute_offset
==
Base
::
kNumKIterationsPerWarpBLoad
-
1
)
{
this
->
warp_tile_iterator_B_
.
set_kgroup_index
(
(
warp_tileB_k_load_offset
+
1
)
%
Base
::
kWarpGemmIterationsForB
);
this
->
warp_tile_iterator_B_
.
load
(
warp_frag_B
[(
warp_tileB_k_load_offset
+
1
)
%
2
]);
++
this
->
warp_tile_iterator_B_
;
}
typename
TransformBAfterLDS
::
result_type
converted_frag_B
=
lds_converter
(
warp_frag_B
[
warp_tileB_k_load_offset
%
2
]);
warp_dequantizer_
.
dequantize
(
converted_frag_B
,
warp_frag_scales
);
using
FragmentOperandB
=
cutlass
::
Array
<
ElementA
,
Operator
::
FragmentB
::
kElements
>
;
constexpr
cutlass
::
FloatRoundStyle
RoundStyle
=
cutlass
::
FloatRoundStyle
::
round_to_nearest
;
constexpr
int
ConversionVectorWidth
=
TransformBAfterLDS
::
result_type
::
kElements
;
static_assert
(
ConversionVectorWidth
==
FragmentOperandB
::
kElements
);
using
Converter
=
cutlass
::
NumericArrayConverter
<
ElementA
,
ElementScale
,
ConversionVectorWidth
,
RoundStyle
>
;
FragmentOperandB
converted_frag_B_operand
=
Converter
::
convert
(
converted_frag_B
);
run_warp_mma
(
warp_mma
,
accum
,
warp_frag_A
[
warp_mma_k
%
2
],
converted_frag_B_operand
,
accum
,
warp_tileB_k_compute_offset
);
// Issue global->shared copies for the this stage
if
(
warp_mma_k
<
Base
::
kWarpGemmIterations
-
1
)
{
int
group_start_iteration_A
,
group_start_iteration_B
;
group_start_iteration_A
=
warp_mma_k
*
Detail
::
kAccessesPerGroupA
;
group_start_iteration_B
=
warp_mma_k
*
Detail
::
kAccessesPerGroupB
;
copy_tiles_and_advance
(
iterator_A
,
iterator_B
,
group_start_iteration_A
,
group_start_iteration_B
);
}
if
(
warp_mma_k
+
2
==
Base
::
kWarpGemmIterations
)
{
int
group_start_iteration_A
,
group_start_iteration_B
;
group_start_iteration_A
=
(
warp_mma_k
+
1
)
*
Detail
::
kAccessesPerGroupA
;
group_start_iteration_B
=
(
warp_mma_k
+
1
)
*
Detail
::
kAccessesPerGroupB
;
copy_tiles_and_advance
(
iterator_A
,
iterator_B
,
group_start_iteration_A
,
group_start_iteration_B
);
// Inserts a memory fence between stages of cp.async instructions.
cutlass
::
arch
::
cp_async_fence
();
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
// #committed)
arch
::
cp_async_wait
<
Base
::
kStages
-
2
>
();
__syncthreads
();
// Move to the next stage
iterator_A
.
add_tile_offset
({
0
,
1
});
iterator_B
.
add_tile_offset
({
1
,
0
});
this
->
smem_iterator_A_
.
add_tile_offset
({
0
,
1
});
this
->
smem_iterator_B_
.
add_tile_offset
({
1
,
0
});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if
(
smem_write_stage_idx
==
(
Base
::
kStages
-
1
))
{
this
->
smem_iterator_A_
.
add_tile_offset
({
0
,
-
Base
::
kStages
});
this
->
smem_iterator_B_
.
add_tile_offset
({
-
Base
::
kStages
,
0
});
smem_write_stage_idx
=
0
;
}
else
{
++
smem_write_stage_idx
;
}
if
(
smem_read_stage_idx
==
(
Base
::
kStages
-
1
))
{
this
->
warp_tile_iterator_A_
.
add_tile_offset
(
{
0
,
-
Base
::
kStages
*
Policy
::
kPartitionsK
*
Base
::
kWarpGemmIterations
});
this
->
warp_tile_iterator_B_
.
add_tile_offset
(
{
-
Base
::
kStages
*
Policy
::
kPartitionsK
*
Base
::
kWarpGemmIterationsForB
,
0
});
smem_read_stage_idx
=
0
;
}
else
{
++
smem_read_stage_idx
;
}
--
gemm_k_iterations
;
iterator_A
.
clear_mask
(
gemm_k_iterations
==
0
);
iterator_B
.
clear_mask
(
gemm_k_iterations
==
0
);
}
}
}
if
(
SharedMemoryClear
==
SharedMemoryClearOption
::
kZfill
)
{
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
cutlass
::
arch
::
cp_async_fence
();
cutlass
::
arch
::
cp_async_wait
<
0
>
();
__syncthreads
();
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace threadblock
}
// namespace gemm
}
// namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h
deleted 100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type for the scales
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Converter for B matrix applied immediately after the LDG (before STS)
    typename TransformBAfterLDG_,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_,
    /// Used for partial specialization
    typename Enable = void>
class DqMmaPipelined;

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h"
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h
deleted 100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Iterators over scales in global memory
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Converter for B matrix applied immediately after the LDG (before STS)
    typename TransformBAfterLDG_,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_,
    SmemIteratorScale_, ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
    std::enable_if_t<isFinegrained(QuantOp_)>>
    : public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;

    using Shape = Shape_;         ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
    using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
    using ElementC = ElementC_;   ///< Data type of accumulator matrix
    using LayoutC = LayoutC_;     ///< Layout of accumulator matrix
    using Policy = Policy_;       ///< Policy describing tuning details

    using IteratorScale = IteratorScale_;
    using ElementScale = typename IteratorScale::Element;
    using LayoutScale = typename IteratorScale::Layout;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    using TransformBAfterLDG = TransformBAfterLDG_;
    using TransformBAfterLDS = TransformBAfterLDS_;

    static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

    //
    // Dependent types
    //

    /// Fragment of operand A loaded from global memory
    using FragmentA = typename IteratorA::Fragment;

    /// Fragment of operand B loaded from global memory
    using FragmentB = typename IteratorB::Fragment;

    /// Fragment of operand Scale loaded from global memory;
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Obtain the arch tag from the warp-level operator
    using ArchTag = typename Policy::Operator::ArchTag;

    using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
        typename SmemIteratorScale::Element, LayoutScale, 32, QuantOp>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    // statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
    static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");

    static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
    static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;

    Dequantizer warp_dequantizer_;

    using WarpFragmentScale = typename Dequantizer::FragmentScale;
    using WarpFragmentZero = typename Dequantizer::FragmentZero;

    using ElementA = typename IteratorA::Element;
    using ElementB = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave
        = layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
        "Layout K must match threadblockK");

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaPipelined(
        typename Base::SharedStorage& shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
        int const group_size,                         ///< The group size for quantization
        int thread_idx,                               ///< ID within the threadblock
        int warp_idx,                                 ///< ID of warp
        int lane_idx                                  ///< ID of each thread within a warp
        )
        : Base(shared_storage, thread_idx, warp_idx, lane_idx)
        , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
              {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
              (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
        , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
        , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
        , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
              shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
    {
        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }

    CUTLASS_DEVICE
    void copy_scales_and_advance(IteratorScale& iterator_scale)
    {
        using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Element,
            typename FragmentScale::Element, FragmentScale::kElements>;

        FragmentScale tb_frag_scales;
        FragmentScale tb_frag_zeros;
        tb_frag_scales.clear();
        tb_frag_zeros.clear();

        TransformScale transformScale;

        using FragmentElement = typename FragmentScale::Element;

        auto gmem_scale_ptr = iterator_scale.get_scale();
        auto gmem_zero_ptr = iterator_scale.get_zero();

        arch::global_load<FragmentScale, sizeof(FragmentScale)>(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());

        if (gmem_zero_ptr != nullptr)
        {
            arch::global_load<FragmentScale, sizeof(FragmentScale)>(
                tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
        }

        typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales);
        typename TransformScale::result_type tb_frag_zeros_fp16;
        if (gmem_zero_ptr != nullptr)
            tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);

        auto frag_scale_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_scales_fp16);
        auto frag_zero_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_zeros_fp16);

        auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
        auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();

        if (iterator_scale.valid())
        {
            auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
            arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_scale_ptr_fp16);

            if (gmem_zero_ptr != nullptr)
            {
                smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
                arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_zero_ptr_fp16);
            }
        }

        if (iterator_scale.group_size_ == 64)
        {
            iterator_scale.add_tile_offset({1, 0});
        }
        else if (iterator_scale.group_size_ == 128)
        {
            if constexpr (Shape::kK == 128)
            {
                iterator_scale.add_tile_offset({1, 0});
            }
            else if constexpr (Shape::kK == 64)
            {
                if (iterator_scale.row_groupsize64_ & 0x1)
                {
                    iterator_scale.add_tile_offset({1, 0});
                }
            }
            else
            {
                static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
            }
        }

        iterator_scale.row_groupsize64_++;

        this->smem_iterator_scale_.add_tile_offset({1, 0});
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(
        int gemm_k_iterations,        ///< number of iterations of the mainloop
        FragmentC& accum,             ///< destination accumulator tile
        IteratorA iterator_A,         ///< iterator over A operand in global memory
        IteratorB iterator_B,         ///< iterator over B operand in global memory
        IteratorScale iterator_scale, ///< iterator over scale operand in global memory
        FragmentC const& src_accum)   ///< source accumulator tile
    {

        //
        // Prologue
        //
        TransformBAfterLDG ldg_converter;
        TransformBAfterLDS lds_converter;

        using TransformA
            = NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;

        // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
        // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
        TransformA transformA;

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        FragmentA tb_frag_A;
        FragmentB tb_frag_B;

        tb_frag_A.clear();
        tb_frag_B.clear();

        // The last kblock is loaded in the prolog
        iterator_A.load(tb_frag_A);
        iterator_B.load(tb_frag_B);

        ++iterator_A;
        ++iterator_B;

        this->smem_iterator_A_.store(transformA(tb_frag_A));
        this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

        ++this->smem_iterator_A_;
        ++this->smem_iterator_B_;

        copy_scales_and_advance(iterator_scale);

        __syncthreads();

        // Pair of fragments used to overlap shared memory loads and math instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];
        WarpFragmentScale warp_frag_scales;
        WarpFragmentZero warp_frag_zero;

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);
        warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;
        warp_dequantizer_.add_pointer_offset(Shape::kN);

        Operator warp_mma;

        int smem_write_stage_idx = 1;

        // Avoid reading out of bounds
        iterator_A.clear_mask(gemm_k_iterations <= 1);
        iterator_B.clear_mask(gemm_k_iterations <= 1);
        iterator_scale.clear_mask(gemm_k_iterations <= 1);

        // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
        // shared memory loads (which have the tightest latency requirement).

        //
        // Mainloop
        //

        // Note: The main loop does not support Base::kWarpGemmIterations == 2.
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > 0; --gemm_k_iterations)
        {
            //
            // Loop over GEMM K dimension
            //

            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
            {
                // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
                // as the case may be.
                if (warp_mma_k == Base::kWarpGemmIterations - 1)
                {
                    // Write fragments to shared memory
                    this->smem_iterator_A_.store(transformA(tb_frag_A));
                    this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

                    __syncthreads();

                    ++this->smem_iterator_A_;
                    ++this->smem_iterator_B_;

                    // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
                    if (smem_write_stage_idx == 1)
                    {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                        this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
                    }
                    else
                    {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                        warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
                    }

                    smem_write_stage_idx ^= 1;
                }

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
                {
                    this->warp_tile_iterator_B_.set_kgroup_index(
                        (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                if (warp_mma_k == 0)
                {
                    iterator_A.load(tb_frag_A);
                    iterator_B.load(tb_frag_B);

                    ++iterator_A;
                    ++iterator_B;

                    copy_scales_and_advance(iterator_scale);

                    // Avoid reading out of bounds if this was the last loop iteration
                    iterator_A.clear_mask(gemm_k_iterations <= 2);
                    iterator_B.clear_mask(gemm_k_iterations <= 2);
                    iterator_scale.clear_mask(gemm_k_iterations <= 2);
                }

                typename TransformBAfterLDS::result_type converted_frag_B
                    = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero);
                run_warp_mma(
                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
            }

            // Load the scales needed for the next tile iteration
            warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
            // Update internal pointer to the set of scales in shared memory
            warp_dequantizer_.add_pointer_offset(Shape::kN);
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h
deleted 100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Iterators over scales in global memory
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Converter for B matrix applied immediately after the LDG (before STS)
    typename TransformBAfterLDG_,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_,
    SmemIteratorScale_, ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
    std::enable_if_t<!isFinegrained(QuantOp_)>>
    : public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;

    using Shape = Shape_;         ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
    using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
    using ElementC = ElementC_;   ///< Data type of accumulator matrix
    using LayoutC = LayoutC_;     ///< Layout of accumulator matrix
    using Policy = Policy_;       ///< Policy describing tuning details

    using IteratorScale = IteratorScale_;
    using ElementScale = typename IteratorScale::Element;
    using LayoutScale = typename IteratorScale::Layout;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    using TransformBAfterLDG = TransformBAfterLDG_;
    using TransformBAfterLDS = TransformBAfterLDS_;

    static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

    //
    // Dependent types
    //

    /// Fragment of operand A loaded from global memory
    using FragmentA = typename IteratorA::Fragment;

    /// Fragment of operand B loaded from global memory
    using FragmentB = typename IteratorB::Fragment;

    /// Fragment of operand Scale loaded from global memory;
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Obtain the arch tag from the warp-level operator
    using ArchTag = typename Policy::Operator::ArchTag;

    using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
        typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    // statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
    static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;

    Dequantizer warp_dequantizer_;

    using ElementA = typename IteratorA::Element;
    using ElementB = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave
        = layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
        "Layout K must match threadblockK");

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaPipelined(
        typename Base::SharedStorage& shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
        int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation
                              ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this
                              ///< argument is not added, it does not affect compilation for sm>=80.
        int thread_idx,       ///< ID within the threadblock
        int warp_idx,         ///< ID of warp
        int lane_idx          ///< ID of each thread within a warp
        )
        : Base(shared_storage, thread_idx, warp_idx, lane_idx)
        , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
              (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
        , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
        , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
        , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
    {
        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(
        int gemm_k_iterations,        ///< number of iterations of the mainloop
        FragmentC& accum,             ///< destination accumulator tile
        IteratorA iterator_A,         ///< iterator over A operand in global memory
        IteratorB iterator_B,         ///< iterator over B operand in global memory
        IteratorScale iterator_scale, ///< iterator over scale operand in global memory
        FragmentC const& src_accum)   ///< source accumulator tile
    {

        //
        // Prologue
        //
        TransformBAfterLDG ldg_converter;
        TransformBAfterLDS lds_converter;

        using TransformA
            = NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;

        using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
            typename FragmentScale::Element, FragmentScale::kElements>;

        // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
        // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
        TransformA transformA;
        TransformScale transformScale;

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        FragmentA tb_frag_A;
        FragmentB tb_frag_B;
        FragmentScale tb_frag_scales;

        using WarpFragmentScale = typename Dequantizer::FragmentScale;
        WarpFragmentScale warp_frag_scales;

        tb_frag_A.clear();
        tb_frag_B.clear();
        tb_frag_scales.clear();

        // The last kblock is loaded in the prolog
        iterator_A.load(tb_frag_A);
        iterator_B.load(tb_frag_B);
        iterator_scale.load(tb_frag_scales);

        ++iterator_A;
        ++iterator_B;

        this->smem_iterator_A_.store(transformA(tb_frag_A));
        this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
        this->smem_iterator_scale_.store(transformScale(tb_frag_scales));

        ++this->smem_iterator_A_;
        ++this->smem_iterator_B_;

        __syncthreads();

        warp_dequantizer_.load(warp_frag_scales);

        // Pair of fragments used to overlap shared memory loads and math instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        Operator warp_mma;

        int smem_write_stage_idx = 1;

        // Avoid reading out of bounds
        iterator_A.clear_mask(gemm_k_iterations <= 1);
        iterator_B.clear_mask(gemm_k_iterations <= 1);

        // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
        // shared memory loads (which have the tightest latency requirement).

        //
        // Mainloop
        //

        // Note: The main loop does not support Base::kWarpGemmIterations == 2.
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > 0; --gemm_k_iterations)
        {
            //
            // Loop over GEMM K dimension
            //

            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
            {
                // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
                // as the case may be.
                if (warp_mma_k == Base::kWarpGemmIterations - 1)
                {
                    // Write fragments to shared memory
                    this->smem_iterator_A_.store(transformA(tb_frag_A));
                    this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

                    __syncthreads();

                    ++this->smem_iterator_A_;
                    ++this->smem_iterator_B_;

                    // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
                    if (smem_write_stage_idx == 1)
                    {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                    }
                    else
                    {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                    }

                    smem_write_stage_idx ^= 1;
                }

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
                {
                    this->warp_tile_iterator_B_.set_kgroup_index(
                        (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                if (warp_mma_k == 0)
                {
                    iterator_A.load(tb_frag_A);
                    iterator_B.load(tb_frag_B);

                    ++iterator_A;
                    ++iterator_B;

                    // Avoid reading out of bounds if this was the last loop iteration
                    iterator_A.clear_mask(gemm_k_iterations <= 2);
                    iterator_B.clear_mask(gemm_k_iterations <= 2);
                }

                typename TransformBAfterLDS::result_type converted_frag_B
                    = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
                run_warp_mma(
                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
            }
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h
deleted 100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
namespace cutlass
{
namespace gemm
{
namespace warp
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for m-by-n-by-kgroup
template <
    /// Shape of one matrix production operation (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A elements,
    typename ElementA,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA,
    /// Data type of B elements
    typename ElementB,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB,
    /// Element type of C matrix
    typename ElementC,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Number of partitions along K dimension
    int PartitionsK,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<WarpShape_, InstructionShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
    arch::OpMultiplyAddDequantizeInterleavedBToA, PartitionsK, AccumulatorsInRowMajor>
{

private:
    // Shape for computing the FP16s
    using ComputeInstructionShape = InstructionShape_;

    // Chosen so we get K=16 for int8 and K=32 for int4.
    static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;

    // Shape for loading the narrow data type from shared memory
    using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;

public:
    using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
        cutlass::arch::Mma<InstructionShape_, 32, ElementA, cutlass::layout::RowMajor, ElementA,
            cutlass::layout::ColumnMajor, ElementC, cutlass::layout::RowMajor, arch::OpMultiplyAdd>,
        cutlass::MatrixShape<1, 1>>;

    // Define the warp-level tensor op
    using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_, ElementA, LayoutA, ElementB, LayoutB,
        ElementC, LayoutC, Policy, LoadInstructionShape, PartitionsK, AccumulatorsInRowMajor>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h
deleted 100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate operations targeting
Tensor Cores.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/platform/platform.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/mma_sm89.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace
cutlass
{
namespace
gemm
{
namespace
warp
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template
<
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename
Shape_
,
/// Data type of A elements
typename
ElementA_
,
/// Layout of A matrix (concept: MatrixLayout)
typename
LayoutA_
,
/// Data type of B elements
typename
ElementB_
,
/// Layout of B matrix (concept: MatrixLayout)
typename
LayoutB_
,
/// Element type of C matrix
typename
ElementC_
,
/// Layout of C matrix (concept: MatrixLayout)
typename
LayoutC_
,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename
Policy_
,
/// Instruction shape to override shared memory iterators with
typename
SharedMemoryInstructionShape_
,
/// Number of partitions along K dimension
int
PartitionsK_
=
1
,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool
AccumulatorsInRowMajor
=
false
,
/// Used for partial specialization
typename
Enable
=
bool
>
class
MmaTensorOpComputeBWithF16
{
public:
    /// Shape of warp-level matrix operation (concept: GemmShape)
    using Shape = Shape_;

    /// Data type of multiplicand A
    using ElementA = ElementA_;

    /// Layout of multiplicand A
    using LayoutA = LayoutA_;

    /// Data type of multiplicand B
    using ElementB = ElementB_;

    /// Layout of multiplicand B
    using LayoutB = LayoutB_;

    /// Data type of accumulator matrix C
    using ElementC = ElementC_;

    /// Layout of accumulator matrix C
    using LayoutC = LayoutC_;

    /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
    using Policy = Policy_;

    /// Underlying matrix multiply operator (concept: arch::Mma)
    using ArchMmaOperator = typename Policy::Operator;

    /// Indicates math operator
    using MathOperator = typename ArchMmaOperator::Operator;

    /// Architecture tag from underlying instruction
    using ArchTag = typename ArchMmaOperator::ArchTag;
    static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
                      && platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
            || (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
                && platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
                && ArchTag::kMinComputeCapability >= 80)
            || (platform::is_same<typename ArchMmaOperator::ElementA, float_e4m3_t>::value
                && platform::is_same<typename ArchMmaOperator::ElementB, float_e4m3_t>::value
                && ArchTag::kMinComputeCapability >= 89),
        "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");

    static_assert(platform::is_same<ElementA, half_t>::value
            || (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80)
            || (platform::is_same<ElementA, float_e4m3_t>::value && ArchTag::kMinComputeCapability >= 89),
        "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada");
    /// Indicates class of matrix operator
    using OperatorClass = arch::OpClassTensorOp;

    /// Shape of underlying instruction
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Instruction shape to override shared memory iterators with
    using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;

    static_assert(
        SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
    static_assert(
        SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");

    static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;

    static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = ComplexTransform::kNone;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = ComplexTransform::kNone;

    /// Number of threads participating in warp-level matrix product
    static int const kThreadCount = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;
public:
    /// Iterates over the A operand in memory
    using IteratorA = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA,
        LayoutA, MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount,
        kPartitionsK>;

    /// Storage for A tile
    using FragmentA = typename IteratorA::Fragment;

    /// Storage for transformed A tile
    using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;

    /// Iterates over the B operand in memory
    using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
        LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
        kThreadCount, kPartitionsK>;

    /// Storage for B tile
    using FragmentB = typename IteratorB::Fragment;

    /// Storage for transformed B tile
    using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;

    /// Iterates over the C operand in memory
    using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
        typename ArchMmaOperator::Shape, typename Policy::OpDelta>;

    /// Storage for C tile
    using FragmentC = typename IteratorC::Fragment;

    /// Number of mma operations performed
    using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
        (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;

public:
    /// Underlying matrix multiply operator (concept: arch::Mma)
    ArchMmaOperator mma;

public:
    //
    // Methods
    //

    /// Ctor
    CUTLASS_DEVICE
    MmaTensorOpComputeBWithF16() {}
    /// Performs a warp-level matrix multiply-accumulate operation
    CUTLASS_DEVICE
    void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C,
        int const warp_tileB_k_offset) const
    {
        using MmaOperandA = typename ArchMmaOperator::FragmentA;
        using MmaOperandB = typename ArchMmaOperator::FragmentB;
        using MmaOperandC = typename ArchMmaOperator::FragmentC;

        static_assert(
            TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
            "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of "
            "B");

        D = C;

        MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
        MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
        MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
        // Serpentine visitation order maximizing reuse of Rb
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < MmaIterations::kColumn; ++n)
        {
            CUTLASS_PRAGMA_UNROLL
            for (int m = 0; m < MmaIterations::kRow; ++m)
            {
                int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);

                int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
                if (AccumulatorsInRowMajor)
                { // matrix B is reordered
                    mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB],
                        ptr_D[n + m_serpentine * MmaIterations::kColumn]);
                }
                else
                {
                    mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB],
                        ptr_D[m_serpentine + n * MmaIterations::kRow]);
                }
            }
        }
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
        // Serpentine visitation order maximizing reuse of Ra
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < MmaIterations::kRow; ++m)
        {
            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < MmaIterations::kColumn; ++n)
            {
                int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);

                int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
                if (AccumulatorsInRowMajor)
                { // matrix B is reordered
                    mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB],
                        ptr_D[n_serpentine + m * MmaIterations::kColumn]);
                }
                else
                {
                    mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB],
                        ptr_D[m + n_serpentine * MmaIterations::kRow]);
                }
            }
        }
#else
        assert(0);
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
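The kExpansionFactor constant above is the ratio between the K extent of the shared-memory load instruction and the K extent of the compute instruction, and it drives the ptr_B indexing in the serpentine loops. A small host-side sketch of that arithmetic, with 64 and 16 used as hypothetical shape values (the concrete shapes come from the instantiating kernel, not from this header):

#include <cassert>

// Illustration only: mirrors kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK
// and the n_offsetB indexing used in operator() above. The shape values 64 and 16 are assumptions.
int main()
{
    int const shared_memory_k = 64; // SharedMemoryInstructionShape::kK (assumed)
    int const instruction_k = 16;   // InstructionShape::kK (assumed)
    int const kExpansionFactor = shared_memory_k / instruction_k; // == 4

    // Each column iteration n consumes one pack of B operands starting at
    // warp_tileB_k_offset + kExpansionFactor * n, exactly as in the serpentine loops.
    int const warp_tileB_k_offset = 0;
    int const n = 2;
    assert(warp_tileB_k_offset + kExpansionFactor * n == 8);
    return 0;
}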
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h
deleted
100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/functional.h"
#include "cutlass/platform/platform.h"
#include "cutlass_extensions/weight_only_quant_op.h"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace warp
{

////////////////////////////////////////////////////////////////////////////////

template <
    /// Matrix multiply operator
    typename MmaOperator_,
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Operand identity
    Operand Operand,
    /// Data type of Scale elements
    typename Element_,
    /// Layout of operand
    typename Layout_,
    /// Number of threads participating in one matrix operation
    int Threads,
    ///
    WeightOnlyQuantOp QuantOp_,
    ///
    typename Enable = void>
class MmaTensorOpDequantizer;
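Before the specializations that follow, it may help to see the call pattern this class is built around; a hedged sketch written as comments, using placeholder names (Dequantizer, smem_scales_ref, frag_B, warp_idx_n, lane_idx are assumptions, not identifiers from this header), based on the load()/dequantize() methods defined below:

// Illustration only: how a mainloop typically drives the dequantizer declared above.
//
//   Dequantizer dequantizer(smem_scales_ref, warp_idx_n, lane_idx);
//
//   typename Dequantizer::FragmentScale scale_frag;
//   dequantizer.load(scale_frag);                // one scale per mma column iteration
//   dequantizer.dequantize(frag_B, scale_frag);  // rescale the B fragment in place before the mma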
////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
template
<
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename
MmaOperator_
,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename
Shape_
,
///
WeightOnlyQuantOp
QuantOp_
>
class
MmaTensorOpDequantizer
<
MmaOperator_
,
Shape_
,
Operand
::
kB
,
bfloat16_t
,
layout
::
RowMajor
,
32
,
QuantOp_
,
typename
platform
::
enable_if
<
MmaOperator_
::
ArchTag
::
kMinComputeCapability
>=
80
&&
platform
::
is_same
<
typename
MmaOperator_
::
ArchMmaOperator
::
LayoutB
,
layout
::
ColumnMajor
>::
value
>::
type
>
{
public:
/// Mma Operator
using
MmaOperator
=
MmaOperator_
;
// The architecture specific mma operator being used
using
ArchMmaOperator
=
typename
MmaOperator
::
ArchMmaOperator
;
// Mma Instruction Shape
using
InstructionShape
=
typename
ArchMmaOperator
::
Shape
;
// This is the ratio of the load instruction vs the compute instruction.
static
constexpr
int
kExpansionFactor
=
MmaOperator
::
IteratorB
::
InstructionShape
::
kRow
/
InstructionShape
::
kK
;
/// Type of the scales
using
ElementScale
=
bfloat16_t
;
/// Fragment to hold B data before Mma
using
FragmentDequantizedOperand
=
Array
<
ElementScale
,
MmaOperator
::
FragmentB
::
kElements
>
;
// Fragment to hold scale data to apply to B before mma
// We need 1 fp16 per matrix iteration in the N dimension
static
constexpr
int
kColsPerMmaPerThread
=
1
;
using
FragmentScale
=
Array
<
ElementScale
,
kColsPerMmaPerThread
*
MmaOperator
::
MmaIterations
::
kColumn
>
;
using
FragmentZero
=
Array
<
ElementScale
,
kColsPerMmaPerThread
*
MmaOperator
::
MmaIterations
::
kColumn
>
;
/// Warp mma shape
using
Shape
=
Shape_
;
/// Layout of the scales in shared memory
using
Layout
=
layout
::
RowMajor
;
/// TensorRef type for loading element from a tensor
using
TensorRef
=
TensorRef
<
ElementScale
,
Layout
>
;
static
constexpr
WeightOnlyQuantOp
QuantOp
=
QuantOp_
;
CUTLASS_DEVICE
MmaTensorOpDequantizer
(
TensorRef
smem_scales
,
TensorRef
smem_zeros
,
int
const
warp_idx_n
,
int
const
lane_idx
)
{
int
const
warp_offset
=
warp_idx_n
*
Shape
::
kN
;
int
const
quad
=
lane_idx
/
4
;
int
const
thread_offset
=
warp_offset
+
quad
;
pointer_scale_
=
smem_scales
.
data
()
+
thread_offset
;
if
constexpr
(
hasZero
(
QuantOp
))
{
pointer_zero_
=
smem_zeros
.
data
()
+
thread_offset
;
}
}
CUTLASS_DEVICE
MmaTensorOpDequantizer
(
TensorRef
smem_scales
,
int
const
warp_idx_n
,
int
const
lane_idx
)
:
MmaTensorOpDequantizer
(
smem_scales
,
TensorRef
(),
warp_idx_n
,
lane_idx
)
{
}
CUTLASS_DEVICE
void
load
(
FragmentScale
&
scale_frag
)
{
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
scale_frag
[
mma_n_iter
]
=
pointer_scale_
[
mma_n_iter
*
InstructionShape
::
kN
];
}
}
CUTLASS_DEVICE
void
dequantize
(
FragmentDequantizedOperand
&
operand_frag
,
FragmentScale
const
&
scale_frag
)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
using
_MmaOperandB
=
typename
ArchMmaOperator
::
FragmentB
;
using
ExpandedMmaOperandB
=
Array
<
typename
_MmaOperandB
::
Element
,
kExpansionFactor
*
_MmaOperandB
::
kElements
>
;
static_assert
(
ExpandedMmaOperandB
::
kElements
*
MmaOperator
::
MmaIterations
::
kColumn
==
FragmentDequantizedOperand
::
kElements
,
""
);
__nv_bfloat16
const
*
scale_ptr
=
reinterpret_cast
<
__nv_bfloat16
const
*>
(
&
scale_frag
);
ExpandedMmaOperandB
*
operand_frag_ptr
=
reinterpret_cast
<
ExpandedMmaOperandB
*>
(
&
operand_frag
);
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
static_assert
(
ExpandedMmaOperandB
::
kElements
%
2
==
0
,
""
);
__nv_bfloat162
scalex2
=
__bfloat162bfloat162
(
scale_ptr
[
mma_n_iter
]);
__nv_bfloat162
*
operand_bf16x2_ptr
=
reinterpret_cast
<
__nv_bfloat162
*>
(
&
operand_frag_ptr
[
mma_n_iter
]);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
ExpandedMmaOperandB
::
kElements
/
2
;
++
ii
)
{
operand_bf16x2_ptr
[
ii
]
=
__hmul2
(
operand_bf16x2_ptr
[
ii
],
scalex2
);
}
}
#else
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
// numerous conversion instructions in GEMM main loop.
arch
::
device_breakpoint
();
#endif
}
CUTLASS_DEVICE
void
load
(
FragmentScale
&
scale_frag
,
FragmentScale
&
zero_frag
)
{
if
constexpr
(
hasZero
(
QuantOp
))
{
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
scale_frag
[
mma_n_iter
]
=
pointer_scale_
[
mma_n_iter
*
InstructionShape
::
kN
];
zero_frag
[
mma_n_iter
]
=
pointer_zero_
[
mma_n_iter
*
InstructionShape
::
kN
];
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
scale_frag
[
mma_n_iter
]
=
pointer_scale_
[
mma_n_iter
*
InstructionShape
::
kN
];
}
}
}
CUTLASS_DEVICE
void
dequantize
(
FragmentDequantizedOperand
&
operand_frag
,
FragmentScale
const
&
scale_frag
,
FragmentScale
const
&
zero_frag
)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
using
_MmaOperandB
=
typename
ArchMmaOperator
::
FragmentB
;
using
ExpandedMmaOperandB
=
Array
<
typename
_MmaOperandB
::
Element
,
kExpansionFactor
*
_MmaOperandB
::
kElements
>
;
static_assert
(
ExpandedMmaOperandB
::
kElements
*
MmaOperator
::
MmaIterations
::
kColumn
==
FragmentDequantizedOperand
::
kElements
,
""
);
__nv_bfloat16
const
*
scale_ptr
=
reinterpret_cast
<
__nv_bfloat16
const
*>
(
&
scale_frag
);
__nv_bfloat16
const
*
zero_ptr
=
reinterpret_cast
<
__nv_bfloat16
const
*>
(
&
zero_frag
);
ExpandedMmaOperandB
*
operand_frag_ptr
=
reinterpret_cast
<
ExpandedMmaOperandB
*>
(
&
operand_frag
);
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
static_assert
(
ExpandedMmaOperandB
::
kElements
%
2
==
0
,
""
);
__nv_bfloat162
scalex2
=
__bfloat162bfloat162
(
scale_ptr
[
mma_n_iter
]);
__nv_bfloat162
zerox2
=
__bfloat162bfloat162
(
zero_ptr
[
mma_n_iter
]);
__nv_bfloat162
*
operand_bf16x2_ptr
=
reinterpret_cast
<
__nv_bfloat162
*>
(
&
operand_frag_ptr
[
mma_n_iter
]);
if
constexpr
(
hasZero
(
QuantOp
))
{
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
ExpandedMmaOperandB
::
kElements
/
2
;
++
ii
)
{
operand_bf16x2_ptr
[
ii
]
=
__hfma2
(
operand_bf16x2_ptr
[
ii
],
scalex2
,
zerox2
);
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
ExpandedMmaOperandB
::
kElements
/
2
;
++
ii
)
{
operand_bf16x2_ptr
[
ii
]
=
__hmul2
(
operand_bf16x2_ptr
[
ii
],
scalex2
);
}
}
}
#else
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
// numerous conversion instructions in GEMM main loop.
arch
::
device_breakpoint
();
#endif
}
// Adds a pointer offset in units of elements.
CUTLASS_DEVICE
void
add_pointer_offset
(
int64_t
const
&
offset
)
{
static_assert
(
sizeof
(
ElementScale
)
>
1
,
""
);
pointer_scale_
+=
offset
;
pointer_zero_
+=
offset
;
}
private:
ElementScale
const
*
pointer_scale_
;
ElementScale
const
*
pointer_zero_
;
};
////////////////////////////////////////////////////////////////////////////////
// Specialization for Turing & Ampere
template
<
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename
MmaOperator_
,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename
Shape_
,
///
WeightOnlyQuantOp
QuantOp_
>
class
MmaTensorOpDequantizer
<
MmaOperator_
,
Shape_
,
Operand
::
kB
,
half_t
,
layout
::
RowMajor
,
32
,
QuantOp_
,
typename
platform
::
enable_if
<
MmaOperator_
::
ArchTag
::
kMinComputeCapability
>=
75
&&
platform
::
is_same
<
typename
MmaOperator_
::
ArchMmaOperator
::
LayoutB
,
layout
::
ColumnMajor
>::
value
>::
type
>
{
public:
/// Mma Operator
using
MmaOperator
=
MmaOperator_
;
// The architecture specific mma operator being used
using
ArchMmaOperator
=
typename
MmaOperator
::
ArchMmaOperator
;
// Mma Instruction Shape
using
InstructionShape
=
typename
ArchMmaOperator
::
Shape
;
// This is the ratio of the load instruction vs the compute instruction.
static
constexpr
int
kExpansionFactor
=
MmaOperator
::
IteratorB
::
InstructionShape
::
kRow
/
InstructionShape
::
kK
;
/// Type of the scales
using
ElementScale
=
half_t
;
/// Fragment to hold B data before Mma
using
FragmentDequantizedOperand
=
Array
<
ElementScale
,
MmaOperator
::
FragmentB
::
kElements
>
;
// Fragment to hold scale data to apply to B before mma
// We need 1 fp16 per matrix iteration in the N dimension
static
constexpr
int
kColsPerMmaPerThread
=
1
;
using
FragmentScale
=
Array
<
ElementScale
,
kColsPerMmaPerThread
*
MmaOperator
::
MmaIterations
::
kColumn
>
;
using
FragmentZero
=
Array
<
ElementScale
,
kColsPerMmaPerThread
*
MmaOperator
::
MmaIterations
::
kColumn
>
;
/// Warp mma shape
using
Shape
=
Shape_
;
/// Layout of the scales in shared memory
using
Layout
=
layout
::
RowMajor
;
/// TensorRef type for loading element from a tensor
using
TensorRef
=
TensorRef
<
ElementScale
,
Layout
>
;
static
constexpr
WeightOnlyQuantOp
QuantOp
=
QuantOp_
;
CUTLASS_DEVICE
MmaTensorOpDequantizer
(
TensorRef
smem_scales
,
TensorRef
smem_zeros
,
int
const
warp_idx_n
,
int
const
lane_idx
)
{
int
const
warp_offset
=
warp_idx_n
*
Shape
::
kN
;
int
const
quad
=
lane_idx
/
4
;
int
const
thread_offset
=
warp_offset
+
quad
;
pointer_scale_
=
smem_scales
.
data
()
+
thread_offset
;
if
constexpr
(
hasZero
(
QuantOp
))
{
pointer_zero_
=
smem_zeros
.
data
()
+
thread_offset
;
}
}
CUTLASS_DEVICE
MmaTensorOpDequantizer
(
TensorRef
smem_scales
,
int
const
warp_idx_n
,
int
const
lane_idx
)
:
MmaTensorOpDequantizer
(
smem_scales
,
TensorRef
(),
warp_idx_n
,
lane_idx
)
{
}
CUTLASS_DEVICE
void
load
(
FragmentScale
&
scale_frag
)
{
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
scale_frag
[
mma_n_iter
]
=
pointer_scale_
[
mma_n_iter
*
InstructionShape
::
kN
];
}
}
CUTLASS_DEVICE
void
dequantize
(
FragmentDequantizedOperand
&
operand_frag
,
FragmentScale
const
&
scale_frag
)
{
using
_MmaOperandB
=
typename
ArchMmaOperator
::
FragmentB
;
using
ExpandedMmaOperandB
=
Array
<
typename
FragmentDequantizedOperand
::
Element
,
kExpansionFactor
*
_MmaOperandB
::
kElements
>
;
static_assert
(
ExpandedMmaOperandB
::
kElements
*
MmaOperator
::
MmaIterations
::
kColumn
==
FragmentDequantizedOperand
::
kElements
,
""
);
multiplies
<
ExpandedMmaOperandB
>
mul_op
;
ExpandedMmaOperandB
*
operand_frag_ptr
=
reinterpret_cast
<
ExpandedMmaOperandB
*>
(
&
operand_frag
);
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
operand_frag_ptr
[
mma_n_iter
]
=
mul_op
(
operand_frag_ptr
[
mma_n_iter
],
scale_frag
[
mma_n_iter
]);
}
}
CUTLASS_DEVICE
void
load
(
FragmentScale
&
scale_frag
,
FragmentScale
&
zero_frag
)
{
if
constexpr
(
hasZero
(
QuantOp
))
{
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
scale_frag
[
mma_n_iter
]
=
pointer_scale_
[
mma_n_iter
*
InstructionShape
::
kN
];
zero_frag
[
mma_n_iter
]
=
pointer_zero_
[
mma_n_iter
*
InstructionShape
::
kN
];
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
scale_frag
[
mma_n_iter
]
=
pointer_scale_
[
mma_n_iter
*
InstructionShape
::
kN
];
}
}
}
CUTLASS_DEVICE
void
dequantize
(
FragmentDequantizedOperand
&
operand_frag
,
FragmentScale
const
&
scale_frag
,
FragmentScale
const
&
zero_frag
)
{
using
_MmaOperandB
=
typename
ArchMmaOperator
::
FragmentB
;
using
ExpandedMmaOperandB
=
Array
<
typename
FragmentDequantizedOperand
::
Element
,
kExpansionFactor
*
_MmaOperandB
::
kElements
>
;
static_assert
(
ExpandedMmaOperandB
::
kElements
*
MmaOperator
::
MmaIterations
::
kColumn
==
FragmentDequantizedOperand
::
kElements
,
""
);
multiplies
<
ExpandedMmaOperandB
>
mul_op
;
ExpandedMmaOperandB
*
operand_frag_ptr
=
reinterpret_cast
<
ExpandedMmaOperandB
*>
(
&
operand_frag
);
if
constexpr
(
hasZero
(
QuantOp
))
{
plus
<
ExpandedMmaOperandB
>
plus_op
;
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
operand_frag_ptr
[
mma_n_iter
]
=
plus_op
(
mul_op
(
operand_frag_ptr
[
mma_n_iter
],
scale_frag
[
mma_n_iter
]),
zero_frag
[
mma_n_iter
]);
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for
(
int
mma_n_iter
=
0
;
mma_n_iter
<
MmaOperator
::
MmaIterations
::
kColumn
;
++
mma_n_iter
)
{
operand_frag_ptr
[
mma_n_iter
]
=
mul_op
(
operand_frag_ptr
[
mma_n_iter
],
scale_frag
[
mma_n_iter
]);
}
}
}
// Adds a pointer offset in units of elements.
CUTLASS_DEVICE
void
add_pointer_offset
(
int64_t
const
&
offset
)
{
static_assert
(
sizeof
(
ElementScale
)
>
1
,
""
);
pointer_scale_
+=
offset
;
pointer_zero_
+=
offset
;
}
private:
ElementScale
const
*
pointer_scale_
;
ElementScale
const
*
pointer_zero_
;
};
////////////////////////////////////////////////////////////////////////////////
}
// namespace warp
}
// namespace gemm
}
// namespace cutlass
////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
deleted
100644 → 0
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
namespace tensorrt_llm
{
namespace cutlass_extensions
{

// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
// in the kernel layout details when doing weight only quantization.
enum class CutlassTileConfig
{
    // Signals that we should run heuristics to choose a config
    Undefined,

    // Signals that we should run heuristics to choose a config
    ChooseWithHeuristic,

    // SIMT config
    CtaShape128x128x8_WarpShape64x64x8,

    // TensorCore configs CTA_N = 128, CTA_K = 64
    // Warp configs for M=16
    CtaShape16x128x64_WarpShape16x32x64,
    // Warp configs for M=32
    CtaShape32x128x64_WarpShape32x32x64,
    // Warp configs for M=64
    CtaShape64x128x64_WarpShape32x64x64,
    CtaShape64x64x128_WarpShape32x64x64,
    CtaShape64x128x64_WarpShape64x32x64,
    // Warp configs for M=128
    CtaShape128x64x64_WarpShape64x32x64,
    CtaShape128x128x64_WarpShape64x32x64,
    CtaShape128x128x64_WarpShape64x64x64,
    CtaShape128x128x64_WarpShape128x32x64,
    CtaShape128x256x64_WarpShape64x64x64,
    // Warp configs for M=256
    CtaShape256x128x64_WarpShape64x64x64,

    // TensorCore config CTA_N = 64, CTA_K = 128
    CtaShape128x64x128_WarpShape64x32x128,

    // TensorCore config CTA_N = 256, CTA_K = 64
    CtaShape16x256x64_WarpShape16x64x64,

    // TensorCore config CTA_N = 256, CTA_K = 128
    CtaShape16x256x128_WarpShape16x64x128
};
enum class SplitKStyle
{
    NO_SPLIT_K,
    SPLIT_K_SERIAL,
    STREAM_K, // Sm80+
    // SPLIT_K_PARALLEL // Not supported yet
};

enum class CutlassTileConfigSM90
{
    // Signals that we should run heuristics to choose a config
    Undefined,

    // Signals that we should run heuristics to choose a config
    ChooseWithHeuristic,

    // CTA configs for M=64
    CtaShape64x16x128B,
    CtaShape64x32x128B,
    CtaShape64x64x128B,
    CtaShape64x128x128B,
    CtaShape64x256x128B,

    // CTA configs for M=128
    CtaShape128x16x128B,
    CtaShape128x32x128B,
    CtaShape128x64x128B,
    CtaShape128x128x128B,
    CtaShape128x256x128B,

    // CTA configs for M=256
    CtaShape256x128x128B,
};
enum class MainloopScheduleType
{
    AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this
         // defaults to the "legacy" main loop schedule.
};

enum class EpilogueScheduleType
{
    AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
         // architectures older than hopper, the epilogue is always performed by the same thread block as the main loop.
};

enum class ClusterShape
{
    ClusterShape_1x1x1,
    ClusterShape_2x1x1,
    ClusterShape_1x2x1,
    ClusterShape_2x2x1,
    ClusterShape_1x8x1,
    ClusterShape_8x1x1
};
struct CutlassGemmConfig
{
    enum CandidateConfigTypeParam : int
    {
        NONE = 0,
        WEIGHT_ONLY = 1u << 0,
        SIMT_ONLY = 1u << 1,
        INT8_ONLY = 1u << 2,
        HOPPER = 1u << 3,
        GROUPED_GEMM = 1u << 4,
        FP8_ONLY = 1u << 5,
    };

    CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
    SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
    int split_k_factor = -1;
    int stages = -1;

    // config options for sm90
    CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic;
    MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
    EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
    ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
    bool is_sm90 = false;

    CutlassGemmConfig() {}

    CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
        : tile_config(tile_config)
        , split_k_style(split_k_style)
        , split_k_factor(split_k_factor)
        , stages(stages)
        , is_sm90(false)
    {
    }

    CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
        EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
        : tile_config_sm90(tile_config_sm90)
        , mainloop_schedule(mainloop_schedule)
        , epilogue_schedule(epilogue_schedule)
        , cluster_shape(cluster_shape)
        , is_sm90(true)
    {
    }
std
::
string
toString
()
const
{
std
::
stringstream
tactic
;
tactic
<<
"Cutlass GEMM Tactic"
;
if
(
tile_config_sm90
!=
tensorrt_llm
::
cutlass_extensions
::
CutlassTileConfigSM90
::
ChooseWithHeuristic
)
{
assert
(
is_sm90
&&
"Invalid cutlass GEMM config"
);
tactic
<<
"
\n\t
style=TMA"
<<
"
\n\t
tile shape ID: "
<<
(
int
)
tile_config_sm90
<<
"
\n\t
cluster shape ID: "
<<
(
int
)
cluster_shape
<<
"
\n\t
mainloop sched: "
<<
(
int
)
mainloop_schedule
<<
"
\n\t
epi sched: "
<<
(
int
)
epilogue_schedule
;
}
else
if
(
tile_config
!=
tensorrt_llm
::
cutlass_extensions
::
CutlassTileConfig
::
ChooseWithHeuristic
)
{
assert
(
!
is_sm90
&&
"Invalid cutlass GEMM config"
);
tactic
<<
"
\n\t
style=compatible"
<<
"
\n\t
tile shape ID: "
<<
(
int
)
tile_config
<<
"
\n\t
stages: "
<<
(
int
)
stages
<<
"
\n\t
split k: "
<<
(
int
)
split_k_factor
;
}
else
{
tactic
<<
"
\n\t
undefined"
;
}
tactic
<<
"
\n
"
;
return
tactic
.
str
();
}
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
CutlassGemmConfig
const
&
config
)
{
// clang-format off
if
(
config
.
is_sm90
)
{
out
<<
"tile_config_sm90_enum: "
<<
int
(
config
.
tile_config_sm90
)
<<
", mainloop_schedule_enum: "
<<
int
(
config
.
mainloop_schedule
)
<<
", epilogue_schedule_enum: "
<<
int
(
config
.
epilogue_schedule
)
<<
", cluster_shape_enum: "
<<
int
(
config
.
cluster_shape
);
}
else
{
out
<<
"tile_config_enum: "
<<
int
(
config
.
tile_config
)
<<
", split_k_style_enum: "
<<
int
(
config
.
split_k_style
)
<<
", split_k_factor: "
<<
config
.
split_k_factor
<<
", stages: "
<<
config
.
stages
;
}
// clang-format on
return
out
;
}
}
// namespace cutlass_extensions
}
// namespace tensorrt_llm
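For readers skimming the diff, a minimal sketch of how this (now-removed) config struct was typically constructed and printed, using the constructors and operator<< defined above; the include path and the chosen tile/split-K/stage values are assumptions for illustration only:

#include <iostream>
#include "cutlass_extensions/gemm_configs.h" // path assumed; this header is deleted by this commit

int main()
{
    using namespace tensorrt_llm::cutlass_extensions;

    // Pre-Hopper style config: explicit tile shape, serial split-K with factor 2, 3 stages (assumed values).
    CutlassGemmConfig config(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64, SplitKStyle::SPLIT_K_SERIAL,
        /*split_k_factor=*/2, /*stages=*/3);

    std::cout << config << std::endl;           // stream operator defined above
    std::cout << config.toString() << std::endl; // human-readable tactic description
    return 0;
}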
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
deleted
100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"
namespace
cutlass
{
// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
template <typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter
{
};
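As a plain-C++ aid to reading the PTX below: for the 8-bit case, a stored byte u encodes the signed value u - 128, and the byte-permute trick builds the fp16 value 1024 + u before subtracting the magic constant 1152 (= 1024 + 128), exactly as the inline comment in the first specialization states. A host-side reference of that arithmetic (illustration only, not part of the original header):

#include <cassert>
#include <cstdint>

// Reference for the fp16 fast path below: (1024 + u) - 1152 == u - 128, the original signed value.
inline int reference_unbias_int8(uint8_t u)
{
    return (1024 + static_cast<int>(u)) - 1152;
}

int main()
{
    assert(reference_unbias_int8(0x80) == 0);    // the bias point decodes to zero
    assert(reference_unbias_int8(0xFF) == 127);  // max byte decodes to +127
    assert(reference_unbias_int8(0x00) == -128); // min byte decodes to -128
    return 0;
}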
template
<
>
struct
FastInterleavedAndBiasedNumericArrayConverter
<
half_t
,
uint8_t
,
4
>
{
using
result_type
=
Array
<
half_t
,
4
>
;
using
source_type
=
Array
<
uint8_t
,
4
>
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
result_type
result
;
uint32_t
*
h
=
reinterpret_cast
<
uint32_t
*>
(
&
result
);
uint32_t
const
i8s
=
reinterpret_cast
<
uint32_t
const
&>
(
source
);
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
asm
volatile
(
"prmt.b32 %0,%1,%2,%3;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
i8s
),
"n"
(
start_byte_for_fp16
),
"n"
(
mask_for_elt_01
));
asm
volatile
(
"prmt.b32 %0,%1,%2,%3;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
i8s
),
"n"
(
start_byte_for_fp16
),
"n"
(
mask_for_elt_23
));
// Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
h
[
0
]),
"r"
(
I8s_TO_F16s_MAGIC_NUM
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
h
[
1
]),
"r"
(
I8s_TO_F16s_MAGIC_NUM
));
return
result
;
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
{
return
convert
(
s
);
}
};
template
<
int
N
>
struct
FastInterleavedAndBiasedNumericArrayConverter
<
half_t
,
uint8_t
,
N
>
{
static
constexpr
int
VEC_WIDTH
=
4
;
static_assert
(
!
(
N
%
VEC_WIDTH
),
"N must be multiple of 4."
);
using
result_type
=
Array
<
half_t
,
N
>
;
using
source_type
=
Array
<
uint8_t
,
N
>
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
using
scalar_result_type
=
typename
result_type
::
Element
;
using
scalar_source_type
=
typename
source_type
::
Element
;
FastInterleavedAndBiasedNumericArrayConverter
<
scalar_result_type
,
scalar_source_type
,
VEC_WIDTH
>
convert_vector_
;
result_type
result
;
using
vec_result
=
Array
<
scalar_result_type
,
VEC_WIDTH
>
;
using
vec_source
=
Array
<
scalar_source_type
,
VEC_WIDTH
>
;
vec_result
*
result_ptr
=
reinterpret_cast
<
vec_result
*>
(
&
result
);
vec_source
const
*
source_ptr
=
reinterpret_cast
<
vec_source
const
*>
(
&
source
);
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
N
/
VEC_WIDTH
;
++
i
)
{
result_ptr
[
i
]
=
convert_vector_
(
source_ptr
[
i
]);
}
return
result
;
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
{
return
convert
(
s
);
}
};
template
<
>
struct
FastInterleavedAndBiasedNumericArrayConverter
<
bfloat16_t
,
uint8_t
,
4
>
{
using
result_type
=
Array
<
bfloat16_t
,
4
>
;
using
source_type
=
Array
<
uint8_t
,
4
>
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
result_type
result
;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t
*
bf16_result_ptr
=
reinterpret_cast
<
uint32_t
*>
(
&
result
);
uint32_t
const
i8s
=
reinterpret_cast
<
uint32_t
const
&>
(
source
);
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
float
fp32_intermediates
[
4
];
// Construct FP32s, bfloat does not have enough mantissa for IADD trick
uint32_t
*
fp32_intermediates_casted
=
reinterpret_cast
<
uint32_t
*>
(
fp32_intermediates
);
fp32_intermediates_casted
[
0
]
=
__byte_perm
(
i8s
,
fp32_base
,
0x7650
);
fp32_intermediates_casted
[
1
]
=
__byte_perm
(
i8s
,
fp32_base
,
0x7652
);
fp32_intermediates_casted
[
2
]
=
__byte_perm
(
i8s
,
fp32_base
,
0x7651
);
fp32_intermediates_casted
[
3
]
=
__byte_perm
(
i8s
,
fp32_base
,
0x7653
);
// Subtract out fp32_base + 128 to make the unsigned integer signed.
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
4
;
++
ii
)
{
fp32_intermediates
[
ii
]
-=
8388736.
f
;
}
// Truncate the fp32 representation and pack up as bfloat16s.
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
2
;
++
ii
)
{
bf16_result_ptr
[
ii
]
=
__byte_perm
(
fp32_intermediates_casted
[
2
*
ii
+
0
],
fp32_intermediates_casted
[
2
*
ii
+
1
],
0x7632
);
}
#else
// Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
// HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
result
.
clear
();
// Suppress compiler warning
arch
::
device_breakpoint
();
#endif
return
result
;
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
{
return
convert
(
s
);
}
};
template
<
int
N
>
struct
FastInterleavedAndBiasedNumericArrayConverter
<
bfloat16_t
,
uint8_t
,
N
>
{
static
constexpr
int
VEC_WIDTH
=
4
;
static_assert
(
!
(
N
%
VEC_WIDTH
),
"N must be multiple of 4."
);
using
result_type
=
Array
<
bfloat16_t
,
N
>
;
using
source_type
=
Array
<
uint8_t
,
N
>
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
using
scalar_result_type
=
typename
result_type
::
Element
;
using
scalar_source_type
=
typename
source_type
::
Element
;
FastInterleavedAndBiasedNumericArrayConverter
<
scalar_result_type
,
scalar_source_type
,
VEC_WIDTH
>
convert_vector_
;
result_type
result
;
using
vec_result
=
Array
<
scalar_result_type
,
VEC_WIDTH
>
;
using
vec_source
=
Array
<
scalar_source_type
,
VEC_WIDTH
>
;
vec_result
*
result_ptr
=
reinterpret_cast
<
vec_result
*>
(
&
result
);
vec_source
const
*
source_ptr
=
reinterpret_cast
<
vec_source
const
*>
(
&
source
);
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
N
/
VEC_WIDTH
;
++
i
)
{
result_ptr
[
i
]
=
convert_vector_
(
source_ptr
[
i
]);
}
return
result
;
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
{
return
convert
(
s
);
}
};
template
<
>
struct
FastInterleavedAndBiasedNumericArrayConverter
<
half_t
,
uint4b_t
,
8
>
{
using
result_type
=
Array
<
half_t
,
8
>
;
using
source_type
=
Array
<
uint4b_t
,
8
>
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
result_type
result
;
uint32_t
*
h
=
reinterpret_cast
<
uint32_t
*>
(
&
result
);
uint32_t
const
i4s
=
reinterpret_cast
<
uint32_t
const
&>
(
source
);
// First, we extract the i4s and construct an intermediate fp16 number.
static
constexpr
uint32_t
immLut
=
(
0xf0
&
0xcc
)
|
0xaa
;
static
constexpr
uint32_t
BOTTOM_MASK
=
0x000f000f
;
static
constexpr
uint32_t
TOP_MASK
=
0x00f000f0
;
static
constexpr
uint32_t
I4s_TO_F16s_MAGIC_NUM
=
0x64006400
;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits beforehand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const
uint32_t
top_i4s
=
i4s
>>
8
;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
i4s
),
"n"
(
BOTTOM_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
i4s
),
"n"
(
TOP_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
2
])
:
"r"
(
top_i4s
),
"n"
(
BOTTOM_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
3
])
:
"r"
(
top_i4s
),
"n"
(
TOP_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
static
constexpr
uint32_t
FP16_TOP_MAGIC_NUM
=
0x64086408
;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static
constexpr
uint32_t
ONE_SIXTEENTH
=
0x2c002c00
;
// This is the half2 {-72, -72} represented as an integer.
static
constexpr
uint32_t
NEG_72
=
0xd480d480
;
// Finally, we construct the output numbers.
// Convert elt_01
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
h
[
0
]),
"r"
(
FP16_TOP_MAGIC_NUM
));
// Convert elt_23
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
h
[
1
]),
"r"
(
ONE_SIXTEENTH
),
"r"
(
NEG_72
));
// Convert elt_45
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
2
])
:
"r"
(
h
[
2
]),
"r"
(
FP16_TOP_MAGIC_NUM
));
// Convert elt_67
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
3
])
:
"r"
(
h
[
3
]),
"r"
(
ONE_SIXTEENTH
),
"r"
(
NEG_72
));
return
result
;
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
{
return
convert
(
s
);
}
};
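The same de-biasing can be checked for the int4 path just shown: a low nibble is built into fp16 as 1024 + e and 1032 is subtracted, while a high nibble is built as 1024 + 16*e and folded through an fma with 1/16 and -72; both yield e - 8, the signed value for a 4-bit bias of 2^3 = 8. A host-side check of that arithmetic (illustration only, not part of the original header):

#include <cassert>

int main()
{
    for (int e = 0; e < 16; ++e) // e is the stored (biased) 4-bit value
    {
        int const low_nibble_path = (1024 + e) - 1032;          // sub.f16x2 with 0x6408 ({1032, 1032})
        int const high_nibble_path = (1024 + 16 * e) / 16 - 72; // fma.rn.f16x2 with 1/16 and -72
        assert(low_nibble_path == e - 8);
        assert(high_nibble_path == e - 8);
    }
    return 0;
}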
template
<
int
N
>
struct
FastInterleavedAndBiasedNumericArrayConverter
<
half_t
,
uint4b_t
,
N
>
{
static
constexpr
int
VEC_WIDTH
=
8
;
static_assert
(
!
(
N
%
VEC_WIDTH
),
"N must be multiple of 8."
);
using
result_type
=
Array
<
half_t
,
N
>
;
using
source_type
=
Array
<
uint4b_t
,
N
>
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
using
scalar_result_type
=
typename
result_type
::
Element
;
using
scalar_source_type
=
typename
source_type
::
Element
;
FastInterleavedAndBiasedNumericArrayConverter
<
scalar_result_type
,
scalar_source_type
,
VEC_WIDTH
>
convert_vector_
;
result_type
result
;
using
vec_result
=
Array
<
scalar_result_type
,
VEC_WIDTH
>
;
using
vec_source
=
Array
<
scalar_source_type
,
VEC_WIDTH
>
;
vec_result
*
result_ptr
=
reinterpret_cast
<
vec_result
*>
(
&
result
);
vec_source
const
*
source_ptr
=
reinterpret_cast
<
vec_source
const
*>
(
&
source
);
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
N
/
VEC_WIDTH
;
++
i
)
{
result_ptr
[
i
]
=
convert_vector_
(
source_ptr
[
i
]);
}
return
result
;
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
{
return
convert
(
s
);
}
};
template
<
>
struct
FastInterleavedAndBiasedNumericArrayConverter
<
bfloat16_t
,
uint4b_t
,
8
>
{
using
result_type
=
Array
<
bfloat16_t
,
8
>
;
using
source_type
=
Array
<
uint4b_t
,
8
>
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
result_type
result
;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t
*
h
=
reinterpret_cast
<
uint32_t
*>
(
&
result
);
uint32_t
const
source_i4s
=
reinterpret_cast
<
uint32_t
const
&>
(
source
);
// First, we extract the i4s and construct an intermediate fp16 number.
static
constexpr
uint32_t
immLut
=
(
0xf0
&
0xcc
)
|
0xaa
;
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
I4s_TO_BF16s_MAGIC_NUM
=
0x43004300
;
// We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
// No shift needed for first item.
uint32_t
i4s
=
source_i4s
;
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
i4s
),
"n"
(
MASK
),
"n"
(
I4s_TO_BF16s_MAGIC_NUM
),
"n"
(
immLut
));
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
1
;
ii
<
result_type
::
kElements
/
2
;
++
ii
)
{
i4s
>>=
sizeof_bits
<
typename
source_type
::
Element
>::
value
;
// (i4s & 0x000f000f) | 0x43004300
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
ii
])
:
"r"
(
i4s
),
"n"
(
MASK
),
"n"
(
I4s_TO_BF16s_MAGIC_NUM
),
"n"
(
immLut
));
}
// This is the BF16 {-136, -136} represented as an integer.
static
constexpr
uint32_t
BF16_BIAS
=
0xC308C308
;
static
constexpr
uint32_t
BF16_ONE
=
0x3F803F80
;
// Finally, we construct the output numbers.
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
result_type
::
kElements
/
2
;
++
ii
)
{
// Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
asm
(
"fma.rn.bf16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
ii
])
:
"r"
(
h
[
ii
]),
"r"
(
BF16_ONE
),
"r"
(
BF16_BIAS
));
}
#else
// Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
// HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
arch
::
device_breakpoint
();
result
.
clear
();
// Suppress compiler warning.
#endif
return
result
;
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
{
return
convert
(
s
);
}
};
template
<
int
N
>
struct
FastInterleavedAndBiasedNumericArrayConverter
<
bfloat16_t
,
uint4b_t
,
N
>
{
static
constexpr
int
VEC_WIDTH
=
8
;
static_assert
(
!
(
N
%
VEC_WIDTH
),
"N must be multiple of 8."
);
using
result_type
=
Array
<
bfloat16_t
,
N
>
;
using
source_type
=
Array
<
uint4b_t
,
N
>
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
using
scalar_result_type
=
typename
result_type
::
Element
;
using
scalar_source_type
=
typename
source_type
::
Element
;
FastInterleavedAndBiasedNumericArrayConverter
<
scalar_result_type
,
scalar_source_type
,
VEC_WIDTH
>
convert_vector_
;
result_type
result
;
using
vec_result
=
Array
<
scalar_result_type
,
VEC_WIDTH
>
;
using
vec_source
=
Array
<
scalar_source_type
,
VEC_WIDTH
>
;
vec_result
*
result_ptr
=
reinterpret_cast
<
vec_result
*>
(
&
result
);
vec_source
const
*
source_ptr
=
reinterpret_cast
<
vec_source
const
*>
(
&
source
);
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
N
/
VEC_WIDTH
;
++
i
)
{
result_ptr
[
i
]
=
convert_vector_
(
source_ptr
[
i
]);
}
return
result
;
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
{
return
convert
(
s
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h
deleted
100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines new layouts needed for MoE
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/pitch_linear_coord.h"
namespace cutlass
{
namespace layout
{

template <int RowsPerTile, int ColumnsInterleaved>
struct ColumnMajorTileInterleave
{
    static constexpr int kRowsPerTile = RowsPerTile;
    static constexpr int kColumnsInterleaved = ColumnsInterleaved;
};

template <class T>
struct IsColumnMajorTileInterleave
{
    static constexpr bool value = false;
};

template <int U, int V>
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>>
{
    static constexpr bool value = true;
};

} // namespace layout
} // namespace cutlass
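A small compile-time sketch of how the trait above is meant to be queried; the 4x2 template arguments and the include paths are assumptions, and the header itself is removed by this commit:

#include "cutlass/layout/matrix.h"
#include "cutlass_extensions/tile_interleaved_layout.h" // path assumed

using Interleaved = cutlass::layout::ColumnMajorTileInterleave</*RowsPerTile=*/4, /*ColumnsInterleaved=*/2>;

static_assert(cutlass::layout::IsColumnMajorTileInterleave<Interleaved>::value,
    "interleaved layouts are detected by the partial specialization");
static_assert(!cutlass::layout::IsColumnMajorTileInterleave<cutlass::layout::RowMajor>::value,
    "ordinary layouts fall back to the primary template");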
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h
deleted
100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM
quantization.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace transform
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

template <typename Shape, typename Element, typename Layout, int AdvanceRank, int Alignment>
class FineGrainedScaleZeroIterator;

template <typename Shape_, typename Element_, int Alignment_>
class FineGrainedScaleZeroIterator<Shape_, Element_, layout::RowMajor, 0, Alignment_>
{
public:
using
Shape
=
Shape_
;
using
Element
=
Element_
;
using
Layout
=
layout
::
RowMajor
;
static
int
const
kAdvanceRank
=
0
;
static
int
const
kAlignment
=
Alignment_
;
static
int
const
kAccessesPerVector
=
1
;
/// Row index of scales corresponding to the groupsize of 64
int
row_groupsize64_
;
int
group_size_
;
using
Index
=
typename
Layout
::
Index
;
using
LongIndex
=
typename
Layout
::
LongIndex
;
using
TensorRef
=
TensorRef
<
Element
,
Layout
>
;
using
TensorView
=
TensorView
<
Element
,
Layout
>
;
using
TensorCoord
=
typename
Layout
::
TensorCoord
;
using
Pointer
=
Element
*
;
using
NonConstPointer
=
typename
platform
::
remove_const
<
Element
>::
type
*
;
using
AccessType
=
AlignedArray
<
Element
,
kAlignment
>
;
using
Fragment
=
cutlass
::
Array
<
Element
,
kAlignment
>
;
    // For compatibility with existing iterator interface
    struct Params
    {
        LongIndex stride_ = 0;

        /// amount (in byte) to increment pointer from first access of current tile
        /// to first access of next tile
        LongIndex inc_advance_ = 0;

        // Default ctor
        CUTLASS_HOST_DEVICE
        Params() {}

        /// Construct the Params object given a pitch-linear tensor's layout
        CUTLASS_HOST_DEVICE
        Params(Layout const& layout)
            : stride_(layout.stride(0))
        {
            inc_advance_ = Shape::kRow * stride_ * sizeof_bits<Element>::value / 8;
        }
    };
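To make the byte arithmetic in Params concrete, a hedged example with assumed values (Shape::kRow = 64, a stride of 4096 elements, half-precision scales; none of these are fixed by this header):

#include <cassert>
#include <cstdint>

int main()
{
    // Assumed values for illustration only.
    int64_t const shape_k_row = 64;  // Shape::kRow
    int64_t const stride = 4096;     // layout.stride(0), in elements
    int64_t const element_bits = 16; // sizeof_bits<half_t>::value

    // Mirrors: inc_advance_ = Shape::kRow * stride_ * sizeof_bits<Element>::value / 8
    int64_t const inc_advance = shape_k_row * stride * element_bits / 8;
    assert(inc_advance == 524288); // 512 KiB between consecutive tiles in the advance direction
    return 0;
}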
private:
/// Internal pointer type permits fast address arithmetic
using
BytePointer
=
char
*
;
private:
//
// Data members
//
/// Parameters object with precomputed internal state
Params
const
params_
;
/// Internal pointer to first access of tile
BytePointer
pointer_scale_
;
BytePointer
pointer_zero_
;
bool
is_valid_
=
false
;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_DEVICE
FineGrainedScaleZeroIterator
(
///< Precomputed parameters object
Params
const
&
params
,
///< Pointer to start of scale tensor
Pointer
pointer_scale
,
///< Pointer to start of zero tensor
Pointer
pointer_zero
,
///< Extent of the scale and bias
TensorCoord
extent
,
///< ID of each participating thread
int
thread_id
,
///< Initial offset of threadblock
TensorCoord
const
&
threadblock_offset
,
///< Group size
int
group_size
)
:
params_
(
params
)
,
pointer_scale_
(
reinterpret_cast
<
BytePointer
>
(
const_cast
<
NonConstPointer
>
(
pointer_scale
)))
,
pointer_zero_
(
reinterpret_cast
<
BytePointer
>
(
const_cast
<
NonConstPointer
>
(
pointer_zero
)))
{
row_groupsize64_
=
threadblock_offset
.
row
();
group_size_
=
group_size
;
const
LongIndex
tb_row_byte_offset
=
threadblock_offset
.
row
()
/
(
group_size
/
64
)
*
params_
.
stride_
*
sizeof_bits
<
Element
>::
value
/
8
;
const
LongIndex
tb_col_byte_offset
=
threadblock_offset
.
column
()
*
sizeof_bits
<
Element
>::
value
/
8
;
pointer_scale_
+=
(
tb_row_byte_offset
+
tb_col_byte_offset
);
if
(
pointer_zero_
!=
nullptr
)
{
pointer_zero_
+=
(
tb_row_byte_offset
+
tb_col_byte_offset
);
}
static
constexpr
int
THREADS_PER_ROW
=
Shape
::
kColumn
/
kAlignment
;
int
const
thread_row
=
thread_id
/
THREADS_PER_ROW
;
int
const
thread_col
=
thread_id
%
THREADS_PER_ROW
;
const
LongIndex
thread_row_byte_offset
=
thread_row
*
params_
.
stride_
*
sizeof_bits
<
Element
>::
value
/
8
;
const
LongIndex
thread_col_byte_offset
=
thread_col
*
kAlignment
*
sizeof_bits
<
Element
>::
value
/
8
;
pointer_scale_
+=
(
thread_row_byte_offset
+
thread_col_byte_offset
);
if
(
pointer_zero_
!=
nullptr
)
{
pointer_zero_
+=
(
thread_row_byte_offset
+
thread_col_byte_offset
);
}
// For the rows, we must check that we are within the extent AND the tile to avoid extra reads on
// a given iteration. The same threads will be responsible for issues reads since the number of scales
// read in a given iteration is a constant. Therefore, we should never have to update is_valid_
// outside of the constructor.
int
const
global_row
=
threadblock_offset
.
row
()
+
thread_row
;
int
const
global_col
=
threadblock_offset
.
column
()
+
thread_col
*
kAlignment
;
bool
const
row_in_bounds
=
global_row
<
extent
.
row
()
&&
thread_row
<
Shape
::
kRow
;
bool
const
col_in_bounds
=
global_col
<
extent
.
column
();
is_valid_
=
row_in_bounds
&&
col_in_bounds
;
}
    /// Construct a PredicatedTileAccessIterator with zero threadblock offset
    CUTLASS_HOST_DEVICE
    FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object
        Pointer pointer_scale,                         ///< Pointer to start of scale tensor
        Pointer pointer_zero,                          ///< Pointer to start of zero tensor
        TensorCoord extent,                            ///< Extent of tensor
        int thread_id,                                 ///< ID of each participating thread
        int group_size)
        : FineGrainedScaleZeroIterator(
            params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size)
    {
    }

    CUTLASS_DEVICE
    void add_tile_offset(TensorCoord const& tile_offset)
    {
        const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_;
        const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits<Element>::value / 8;
        pointer_scale_ += row_byte_offset + col_byte_offset;
        if (pointer_zero_ != nullptr)
        {
            pointer_zero_ += row_byte_offset + col_byte_offset;
        }
    }

    /// Clears the predicate set efficiently
    CUTLASS_HOST_DEVICE
    void clear_mask(bool enable = true)
    {
        is_valid_ &= (!enable);
    }

    /// Returns whether access is valid or not
    CUTLASS_HOST_DEVICE
    bool valid() const
    {
        return is_valid_;
    }

    /// Returns a scale pointer
    CUTLASS_HOST_DEVICE
    AccessType* get_scale() const
    {
        return reinterpret_cast<AccessType*>(pointer_scale_);
    }

    /// Returns a zero pointer
    CUTLASS_HOST_DEVICE
    AccessType* get_zero() const
    {
        return reinterpret_cast<AccessType*>(pointer_zero_);
    }
};

} // namespace threadblock
} // namespace transform
} // namespace cutlass
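The constructor above advances the scale and zero pointers with plain byte arithmetic: the threadblock's row offset is first divided by (group_size / 64), because one row of scales covers an entire quantization group, and is then scaled by the tensor stride and the element width in bits. The short host-side sketch below is added here for illustration only; the element width, stride, group size, and offsets are assumed values, not taken from the kernel.

// Editor's sketch: reproduces the byte-offset arithmetic of the constructor above
// on the host. All concrete numbers are assumptions chosen for illustration.
#include <cstdint>
#include <cstdio>

int main()
{
    int const bits_per_element = 16;  // e.g. half-precision scales
    int64_t const stride = 4096;      // leading dimension of the scale tensor, in elements
    int const group_size = 128;       // quantization group size
    int const tb_row = 256;           // threadblock_offset.row()
    int const tb_col = 64;            // threadblock_offset.column()

    // One row of scales serves group_size weight rows; rows are counted in units of 64.
    int64_t const tb_row_byte_offset = tb_row / (group_size / 64) * stride * bits_per_element / 8;
    int64_t const tb_col_byte_offset = static_cast<int64_t>(tb_col) * bits_per_element / 8;

    std::printf("scale pointer advanced by %lld bytes\n",
        static_cast<long long>(tb_row_byte_offset + tb_col_byte_offset));
    return 0;
}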
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp
deleted
100644 → 0
View file @
9829e77e
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"
using namespace cute;

/// Function object that applies an index to its argument
template <class Iter>
struct IndexedGather
{
    CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {})
        : indices_(indices)
    {
    }

    template <typename I>
    CUTE_HOST_DEVICE constexpr auto operator()(I i) const
    {
        return indices_[i];
    }

    CUTE_HOST_DEVICE friend void print(IndexedGather const& s)
    {
        cute::print("Indexed{");
        print(s.indices_);
        print("}");
    }

    Iter indices_;
};

/// Custom stride object that applies a function followed by a stride
template <class Func, class Stride>
struct CustomStride
{
    CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride)
        : func_(func)
        , stride_(stride)
    {
    }

    template <class I>
    CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s)
    {
        return s.func_(i) * s.stride_;
    }

    template <class I>
    CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i)
    {
        return s.func_(i) * s.stride_;
    }

    CUTE_HOST_DEVICE friend void print(CustomStride const& s)
    {
        cute::print("Custom{");
        print(s.func_);
        cute::print(",");
        print(s.stride_);
        cute::print("}");
    }

    template <class Div>
    CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
    {
        return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
    }

    // Circumvent the requirement on make_layout that shape and stride are integral
    template <class Shape>
    CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride)
    {
        return Layout<Shape, CustomStride>(shape, stride);
    }

    Func func_;
    Stride stride_;
};

template <class Stride, class Func>
CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func)
{
    // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride
    auto idx = find_if(
        stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; });
    constexpr int I = decltype(idx)::value;
    return make_layout(
        repeat_like(stride, _1{}), replace<I>(stride, CustomStride{static_cast<Func&&>(func), get<I>(stride)}));
}

/// Helper function to optionally create a gather tensor
template <class Iterator, class Shape, class Stride, class Func>
CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func)
{
    Layout matrix_layout = make_identity_layout(shape);
    auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
    Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
    return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
}

namespace cute
{

template <int N, int I, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride)
{
    if constexpr (is_tuple<Shape>::value)
    {
        return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N, I>(s, d); });
    }
    else if constexpr (is_scaled_basis<Stride>::value)
    {
        if constexpr (Stride::mode() == I)
        {
            return make_layout(shape_div(shape, Int<N>{}), shape_div(stride, Int<N>{}));
        }
        else
        {
            return make_layout(shape, stride);
        }
    }
    else
    {
        return upcast<N>(shape, stride);
    }

    CUTE_GCC_UNREACHABLE;
}

template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr auto upcast(
    ComposedLayout<Layout<OuterShape, OuterStride>, Offset, Layout<Shape, Stride>> const& layout)
{
    // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset
    auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; });
    constexpr int I = decltype(idx)::value;

    // Upcast the outer layout (works as expected)
    auto outer = upcast<N>(layout.layout_a());

    // Upcast the accumulated offset along stride-1 mode
    auto offset = as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));

    // Upcast the inner layout's shape along stride-1 mode
    auto inner = upcast<N, I>(layout.layout_b().shape(), layout.layout_b().stride());

    return composition(outer, offset, inner);
}

} // namespace cute
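The pieces above compose as follows: IndexedGather remaps a logical coordinate through an index buffer, CustomStride multiplies the remapped coordinate by the real stride, and make_gather_tensor wires the two into a ComposedLayout so a tensor can be addressed through an arbitrary row permutation. The standalone sketch below is added for illustration only; it deliberately avoids CuTe and uses made-up names (IndexedGatherSketch, CustomStrideSketch) to show the same offset computation in plain C++.

// Editor's sketch: the gather-then-stride indirection in plain C++.
#include <cstdio>

struct IndexedGatherSketch
{
    int const* indices;
    int operator()(int i) const { return indices[i]; }
};

template <class Func>
struct CustomStrideSketch
{
    Func func;
    long stride;
    // offset(i) = func(i) * stride, exactly as in CustomStride::operator* above
    friend long operator*(int i, CustomStrideSketch const& s) { return s.func(i) * s.stride; }
};

int main()
{
    int gather_rows[] = {3, 0, 2}; // logical row i reads physical row gather_rows[i]
    CustomStrideSketch<IndexedGatherSketch> row_stride{{gather_rows}, /*leading dimension*/ 8};

    for (int i = 0; i < 3; ++i)
    {
        std::printf("logical row %d -> element offset %ld\n", i, i * row_stride);
    }
    return 0;
}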
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h
deleted
100644 → 0
View file @
9829e77e
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines the weight-only quantization operation modes and related helpers used by warp-level matrix multiply operations targeting Tensor Cores.
*/
#pragma once
namespace cutlass
{

enum class WeightOnlyQuantOp
{
    UNDEFINED,
    PER_COLUMN_SCALE_ONLY,
    FINEGRAINED_SCALE_ONLY,
    FINEGRAINED_SCALE_AND_ZEROS
};

constexpr bool isFinegrained(WeightOnlyQuantOp op)
{
    return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
}

constexpr bool hasZero(WeightOnlyQuantOp op)
{
    return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
}

} // namespace cutlass
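Because isFinegrained and hasZero are constexpr, callers can branch on the quantization scheme at compile time. The sketch below is added for illustration only; it assumes the header above is still available on the include path, and the DequantizeTraits name is an invented example, not part of the repository.

// Editor's sketch: compile-time dispatch on the quantization scheme.
#include <cutlass_extensions/weight_only_quant_op.h>

template <cutlass::WeightOnlyQuantOp QuantOp>
struct DequantizeTraits
{
    // Fine-grained schemes need a scale per (column, group); per-column needs one scale per column.
    static constexpr bool kLoadScalePerGroup = cutlass::isFinegrained(QuantOp);
    // Only FINEGRAINED_SCALE_AND_ZEROS carries a second (zero-point) tensor.
    static constexpr bool kLoadZeros = cutlass::hasZero(QuantOp);
};

static_assert(!DequantizeTraits<cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>::kLoadScalePerGroup, "");
static_assert(!DequantizeTraits<cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>::kLoadZeros, "");
static_assert(DequantizeTraits<cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>::kLoadZeros, "");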
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
    typename EpilogueTag>
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
    ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
    int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
    int* kernel_occupancy);
} // namespace tensorrt_llm::kernels::cutlass_kernels
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include <cutlass_extensions/epilogue_helpers.h>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh>
#include <tensorrt_llm/common/cudaUtils.h>
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
typename EpilogueTag>
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
int* kernel_occupancy)
{
constexpr auto activation_type = fused_moe::EpilogueRouting<EpilogueTag>(true);
using GemmType = fused_moe::Fused_Moe_Kernel_sm80<ElementType_, CutlassWeightType_, ElementType_, MaxTileM_, TileN_,
TileK_, Stages_, activation_type>;
// Make sure the GPU has enough resources.
if (kernel_occupancy != nullptr)
{
constexpr int smem_size = GemmType::kSmemSize;
if (smem_size > (48 << 10))
{
cudaFuncAttributes attr{};
int device = 0;
int max_smem_per_block = 0;
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global<GemmType>));
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize,
// smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the
// heuristic to ignore this configuration.
*kernel_occupancy = 0;
return;
}
}
int max_active_blocks = -1;
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, fused_moe::run_global<GemmType>, GemmType::kThreadCount, smem_size));
*kernel_occupancy = max_active_blocks;
return;
}
int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks<GemmType>());
int const threadblock_count = multi_processor_count * occupancy;
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
using Arguments = typename GemmType::Arguments;
Arguments args{{const_cast<ElementType_*>(A), const_cast<CutlassWeightType_*>(B), const_cast<ElementType_*>(biases),
reinterpret_cast<ElementType_*>(C), total_tokens_including_expert, static_cast<int>(gemm_n),
static_cast<int>(gemm_k), num_experts, bias_is_broadcast},
num_experts, threadblock_count};
auto params = GemmType::to_underlying_arguments(args);
if (GemmType::kSmemSize >= (48 << 10))
{
cudaError_t result = cudaFuncSetAttribute(
fused_moe::run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize);
TLLM_CHECK_WITH_INFO(result == cudaSuccess,
"Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel");
}
dim3 grid(params.threadblock_count, 1, 1);
dim3 block(GemmType::kThreadCount);
fused_moe::run_global<GemmType><<<grid, block, GemmType::kSmemSize, stream>>>(params);
auto result = cudaGetLastError();
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result));
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
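The occupancy path above follows a standard CUDA pattern: query the opt-in per-block shared-memory limit, raise the kernel's dynamic shared-memory cap with cudaFuncSetAttribute when the requirement exceeds 48 KiB, and then ask cudaOccupancyMaxActiveBlocksPerMultiprocessor how many blocks fit per SM. The standalone sketch below is added for illustration only; the dummy kernel, block size, and shared-memory size are assumptions, not taken from the launcher.

// Editor's sketch: the opt-in shared-memory / occupancy-query pattern in isolation.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float* out)
{
    extern __shared__ float smem[];
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}

int main()
{
    int const smem_size = 64 * 1024; // > 48 KiB, so the opt-in attribute is required
    int device = 0;
    int max_smem_per_block = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (smem_size > max_smem_per_block)
    {
        // Mirrors the launcher: a heuristic would treat this configuration as occupancy 0 and skip it.
        std::printf("configuration does not fit on this device\n");
        return 0;
    }
    cudaFuncSetAttribute(dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

    int max_active_blocks = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, dummy_kernel, /*blockSize=*/128, smem_size);
    std::printf("max active blocks per SM: %d\n", max_active_blocks);
    return 0;
}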
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h
deleted
100644 → 0
View file @
9829e77e
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{

// Keep in sync with the signature generated by generate_kernels.py
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag,
    HopperGroupedGemmInput::EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
    int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);

} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl
deleted
100644 → 0
View file @
9829e77e
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
// Hopper helper class for defining all the cutlass helper types
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, typename TileShape,
typename ClusterShape, bool BIAS, EpilogueFusion FUSION>
struct HopperGroupedGemmInfo
{
using Arch = cutlass::arch::Sm90;
// TODO Update once mixed input support is added
static_assert(cutlass::platform::is_same<T, WeightType>::value,
"CUTLASS does not currently have specialised SM90 support for quantized operations");
#ifdef ENABLE_FP8
constexpr static bool IsFP8
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
#else
constexpr static bool IsFP8 = false;
#endif
#ifdef ENABLE_BF16
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|| cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for bfloat16, half, float, fp8");
#else
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for half, float, fp8");
#endif
static_assert(cutlass::platform::is_same<T, WeightType>::value
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e4m3_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e5m2_t>::value,
"Unexpected quantization type");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter<WeightType>::type;
// For legacy reasons we convert unsigned 8-bit to signed
using CutlassWeightTypeMaybeUint8
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint4, cutlass::uint4b_t>, cutlass::int4b_t,
CutlassWeightTypeMaybeUint4>;
using CutlassWeightType
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint8, uint8_t>, int8_t, CutlassWeightTypeMaybeUint8>;
using ElementA = ElementType;
using ElementB = CutlassWeightType;
using ElementD = typename TllmToCutlassTypeAdapter<HopperGroupedGemmInput::OutputTypeAdaptor_t<OutputType>>::type;
using ElementFinalOutput = typename TllmToCutlassTypeAdapter<OutputType>::type;
// using ElementC = std::conditional_t<BIAS, ElementType, void>;
// using ElementCNoVoid = std::conditional_t<BIAS, ElementType, ElementD>;
using ElementC = void;
using ElementCNoVoid = ElementD;
using ElementAccumulator = float;
using ElementBias = ElementFinalOutput;
using ElementRouterScales = float;
// A matrix configuration - this is transposed and swapped with B
using LayoutA = HopperGroupedGemmInput::LayoutA;
constexpr static int AlignmentA
= 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units
// of elements (up to 16 bytes)
// B matrix configuration - this is transposed and swapped with A
using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand
constexpr static int AlignmentB
= 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units
// of elements (up to 16 bytes)
// C matrix configuration
using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand
using StrideC = HopperGroupedGemmInput::StrideC;
// Note we use ElementType here deliberately, so we don't break when BIAS is disabled
constexpr static int AlignmentC
= 128 / cutlass::sizeof_bits<ElementType>::value; // Memory access granularity/alignment of C matrix in units
// of elements (up to 16 bytes)
// D matrix configuration
using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD;
using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD;
constexpr static int AlignmentD
= 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix
// in units of elements (up to 16 bytes)
static_assert(cutlass::platform::is_same<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpDefault>::value,
"Hopper Grouped GEMM specialisation doesn't support fused activation");
using EpilogueOp
= cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
// TODO Add mode for fused activation once CUTLASS adds support
// using EpilogueSchedule = cutlass::platform::conditional_t<
// cutlass::platform::is_same<EpilogueOp, EpilogueOpDefault>::value,
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized,
// cutlass::epilogue::?????????????????? /// <<<<<< what supports activations
// >;
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;
// Epilogue For Default Finalize
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
TileShape, ClusterShape, //
cutlass::epilogue::collective::EpilogueTileAuto, //
ElementAccumulator, ElementAccumulator, //
ElementC, LayoutC*, AlignmentC, //
ElementD, LayoutD*, AlignmentD, //
EpilogueSchedule>::CollectiveOp;
// Epilogue For Fused Finalize
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< //
TileShape, //
ElementCNoVoid, StrideC*, //
ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, //
ElementAccumulator, //
ElementAccumulator, //
ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, //
ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales //
>::CollectiveOp;
using CollectiveEpilogue
= std::conditional_t<FUSION == EpilogueFusion::FINALIZE, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>;
using KernelSchedule
= std::conditional_t<IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here
ElementType, LayoutA*, AlignmentA, //
ElementAccumulator, //
TileShape, ClusterShape, //
StageCountAutoCarveout, KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<HopperGroupedGemmInput::ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
// Hopper specialised version
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size)
{
#ifdef COMPILE_HOPPER_TMA_GEMMS
using namespace cute;
if constexpr (!should_filter_sm90_gemm_problem_shape_v<TileShape, ClusterShape, T>)
{
using GemmInfo
= HopperGroupedGemmInfo<T, WeightType, OutputType, EpilogueTag, TileShape, ClusterShape, BIAS, FUSION>;
using ElementAccumulator = typename GemmInfo::ElementAccumulator;
using ElementA = typename GemmInfo::ElementA;
using ElementB = typename GemmInfo::ElementB;
using ElementC = typename GemmInfo::ElementC;
using ElementCNoVoid = typename GemmInfo::ElementCNoVoid;
using ElementD = typename GemmInfo::ElementD;
using CollectiveMainloop = typename GemmInfo::CollectiveMainloop;
using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue;
using GemmKernel = typename GemmInfo::GemmKernel;
using GemmGrouped = typename GemmInfo::GemmGrouped;
if (kernel_occupancy != nullptr)
{
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
return;
}
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = multi_processor_count;
GemmGrouped gemm;
if (workspace_size != nullptr)
{
// Make a mock problem shape with just the minimal information actually required to get the workspace size
// This makes some assumptions about CUTLASS's implementation which is suboptimal. We have a check later to
// catch future cutlass updates causing silent breakages, but that is not foolproof.
// The alternative is to wait until we have data and then dynamically allocate the workspace
typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr};
typename GemmGrouped::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info};
*workspace_size = gemm.get_workspace_size(args);
return;
}
using MainloopArguments = typename CollectiveMainloop::Arguments;
TLLM_CHECK(hopper_input.stride_a);
TLLM_CHECK(hopper_input.stride_b);
TLLM_CHECK(hopper_input.ptr_a);
TLLM_CHECK(hopper_input.ptr_b);
MainloopArguments const mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};
typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)};
epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
// TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
auto make_epi_args = [&]()
{
if constexpr (FUSION == EpilogueFusion::NONE)
{
auto epi_params = hopper_input.default_epilogue;
return EpilogueArguments{epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c),
hopper_input.stride_c, reinterpret_cast<ElementD**>(epi_params.ptr_d), epi_params.stride_d};
}
else if constexpr (FUSION == EpilogueFusion::FINALIZE)
{
// Parameters for fused finalize
auto epi_params = hopper_input.fused_finalize_epilogue;
return EpilogueArguments{
epilogue_scalars, // Parameters to underlying epilogue
reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c, // C params
reinterpret_cast<typename GemmInfo::ElementFinalOutput*>(epi_params.ptr_final_output),
epi_params.stride_final_output, // D (output) params
reinterpret_cast<typename GemmInfo::ElementBias const*>(epi_params.ptr_bias),
epi_params.stride_bias, // Bias params
epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales
epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales
epi_params.ptr_source_token_index, // Index of the source token to sum into
epi_params.num_rows_in_final_output // Number of tokens in the output buffer
};
}
else
{
static_assert(
sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher");
}
};
EpilogueArguments const epilogue_params = make_epi_args();
typename GemmKernel::TileScheduler::Arguments scheduler_args{
1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN};
typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
mainloop_params, epilogue_params, hw_info, scheduler_args};
size_t calculated_ws_size = gemm.get_workspace_size(args);
TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,
"Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size);
auto can_implement = gemm.can_implement(args);
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
"Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
auto init_status = gemm.initialize(args, hopper_input.gemm_workspace);
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
"Failed to initialize cutlass SM90 grouped gemm. Error: "
+ std::string(cutlassGetStatusString(init_status)));
auto run_status = gemm.run(stream);
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
"Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
sync_check_cuda_error();
}
else
{
TLLM_THROW("Configuration was disabled by FAST_BUILD");
}
#else // COMPILE_HOPPER_TMA_GEMMS
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
#endif // COMPILE_HOPPER_TMA_GEMMS
}
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
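For reference, the launcher above follows the usual CUTLASS 3.x device-adapter sequence: can_implement, get_workspace_size, initialize, run. The templated sketch below condenses that flow; it is added for illustration only, launch_grouped_gemm is a made-up helper, and Gemm is assumed to be a cutlass::gemm::device::GemmUniversalAdapter instantiation such as GemmGrouped above.

// Editor's sketch: the generic CUTLASS device-adapter call sequence.
#include <cuda_runtime_api.h>
#include "cutlass/cutlass.h"

template <typename Gemm>
cutlass::Status launch_grouped_gemm(
    typename Gemm::Arguments const& args, void* workspace, size_t workspace_bytes, cudaStream_t stream)
{
    Gemm gemm;

    // Reject argument combinations the kernel cannot handle (alignment, shapes, ...).
    if (gemm.can_implement(args) != cutlass::Status::kSuccess)
    {
        return cutlass::Status::kErrorNotSupported;
    }

    // The caller must have allocated at least this much device workspace.
    if (gemm.get_workspace_size(args) > workspace_bytes)
    {
        return cutlass::Status::kErrorWorkspaceNull;
    }

    // Set up internal state (scheduler / problem-visitor data) in the workspace.
    cutlass::Status status = gemm.initialize(args, workspace);
    if (status != cutlass::Status::kSuccess)
    {
        return status;
    }

    // Enqueue the kernel on the provided stream.
    return gemm.run(stream);
}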