Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
3ee62235
"configs/_base_/datasets/sunrgbd-3d.py" did not exist on "fc9e0d9dab0c0cfc14241b990c1cce6ea7a750c0"
Unverified
Commit
3ee62235
authored
Jan 31, 2025
by
Yineng Zhang
Committed by
GitHub
Jan 31, 2025
Browse files
revert the MoE dependence (#3230)
parent
9829e77e
Changes
94
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
0 additions
and
1642 deletions
+0
-1642
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu
...kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu
+0
-131
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h
...t_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h
+0
-230
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu
...ls/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu
+0
-24
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu
...s/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu
+0
-24
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu
...s/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu
+0
-24
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
...ls/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
+0
-22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
...s/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
+0
-22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
...s/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
+0
-22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu
...ls/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu
+0
-22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu
...nels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu
+0
-28
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
...nels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
+0
-823
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h
...cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h
+0
-222
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h
...rt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h
+0
-44
sgl-kernel/setup.py
sgl-kernel/setup.py
+0
-4
No files found.
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu
deleted
100644 → 0
View file @
9829e77e
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/conv/convolution.h"
// Order matters here, packed_stride.hpp is missing cute and convolution includes
#include "cutlass/util/packed_stride.hpp"
#include "tensorrt_llm/common/logger.h"
namespace
tensorrt_llm
{
// Returns the byte size of each of the 10 per-expert workspace regions, in the
// order they are carved out by configureWorkspace():
// [0] problem shapes, [1-4] strides A/B/C/D, [5-8] pointer arrays A/B/C/D, [9] alpha scales.
std::array<size_t, 10> HopperGroupedGemmInput::workspaceBuffers(int num_experts)
{
    size_t const count = static_cast<size_t>(num_experts);

    size_t const problem_shape_bytes = sizeof(ProblemShape::UnderlyingProblemShape) * count;
    size_t const stride_a_bytes = sizeof(StrideA) * count;
    size_t const stride_b_bytes = sizeof(StrideB) * count;
    size_t const stride_c_bytes = sizeof(StrideC) * count;
    size_t const stride_d_bytes = sizeof(DefaultEpilogue::StrideD) * count;
    size_t const pointer_array_bytes = sizeof(void*) * count;    // shared size for the A/B/C/D pointer arrays
    size_t const scale_array_bytes = sizeof(float*) * count;

    return std::array{problem_shape_bytes, stride_a_bytes, stride_b_bytes, stride_c_bytes, stride_d_bytes,
        pointer_array_bytes, pointer_array_bytes, pointer_array_bytes, pointer_array_bytes, scale_array_bytes};
}
// Total bytes needed for the Hopper grouped-GEMM bookkeeping workspace: the sum of
// all per-expert buffers (with whatever alignment calculateTotalWorkspaceSize applies).
size_t HopperGroupedGemmInput::workspaceSize(int num_experts)
{
    auto sizes = workspaceBuffers(num_experts);
    return tensorrt_llm::common::calculateTotalWorkspaceSize(sizes.data(), sizes.size());
}
// Carves the caller-provided buffer into the 10 regions reported by workspaceBuffers()
// and wires the member pointers to them. The separate `gemm_workspace` is stored as-is
// for the kernel itself. `start_ptr` must be at least workspaceSize(num_experts) bytes.
void HopperGroupedGemmInput::configureWorkspace(
    int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size)
{
    auto const buffer_sizes = workspaceBuffers(num_experts);
    std::array<int8_t*, 10> region_ptrs{};
    TLLM_CHECK_WITH_INFO(region_ptrs.size() == buffer_sizes.size(), "Mismatching workspace size and number of buffers");

    // Walk the buffer, recording the (aligned) start of each region.
    for (size_t i = 0; i < buffer_sizes.size(); ++i)
    {
        region_ptrs[i] = start_ptr;
        start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffer_sizes[i]);
    }

    shape_info.num_groups = num_experts;
    shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(region_ptrs[0]);
    shape_info.host_problem_shapes = nullptr;

    stride_a = reinterpret_cast<StrideA*>(region_ptrs[1]);
    stride_b = reinterpret_cast<StrideB*>(region_ptrs[2]);
    stride_c = reinterpret_cast<StrideC*>(region_ptrs[3]);
    default_epilogue.stride_d = reinterpret_cast<DefaultEpilogue::StrideD*>(region_ptrs[4]);

    ptr_a = reinterpret_cast<void const**>(region_ptrs[5]);
    ptr_b = reinterpret_cast<void const**>(region_ptrs[6]);
    ptr_c = reinterpret_cast<void const**>(region_ptrs[7]);
    default_epilogue.ptr_d = reinterpret_cast<void**>(region_ptrs[8]);

    alpha_scale_ptr_array = reinterpret_cast<float const**>(region_ptrs[9]);

    this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace);
    this->gemm_workspace_size = gemm_workspace_size;
}
// Fills in the fused-finalize epilogue parameters. The output/bias strides are built
// with transpose_stride to match the swapped B^T * A^T formulation used by the GEMM.
void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales,
    int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
    int num_output_tokens)
{
    auto& epilogue = fused_finalize_epilogue;

    epilogue.ptr_final_output = final_output;
    epilogue.ptr_router_scales = router_scales;
    epilogue.ptr_bias = bias;
    epilogue.ptr_expert_first_token_offset = expert_first_token_offset;
    epilogue.ptr_source_token_index = source_token_index;

    // Packed (num_output_tokens x hidden_size) output, transposed to the swapped layout.
    epilogue.stride_final_output = cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{},
        transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1)));
    // Bias is broadcast across rows (stride 0 in the token dimension).
    epilogue.stride_bias = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size));
    epilogue.stride_router_scales = {};

    epilogue.num_rows_in_final_output = num_output_tokens;
}
// Debug pretty-printer for the Hopper grouped-GEMM input state.
// Fix: the "Final Output ... with Stride:" line previously printed
// fused_finalize_epilogue.stride_router_scales (already printed on the
// "Router Scales" line below), so the final-output stride was never shown.
// It now prints stride_final_output.
std::string HopperGroupedGemmInput::toString() const
{
    std::stringstream ss;
    ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n";
    if (isValid())
    {
        ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n";
        ss << "Epilogue Fusion: " << (int) fusion;
        if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE)
        {
            ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output;
            ss << " with Stride: " << fused_finalize_epilogue.stride_final_output;
            ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias;
            ss << " with Stride: " << fused_finalize_epilogue.stride_bias;
            ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales;
            ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
            ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset;
            ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index;
        }
        else
        {
            ss << ", Ptr D: " << default_epilogue.ptr_d;
        }
        ss << '\n';
        ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n";
    }
    return ss.str();
}
}
// namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h
deleted
100644 → 0
View file @
9829e77e
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h"
#include <array>
#include <cuda_runtime_api.h>
#include <optional>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/layout/layout.h"
namespace
tensorrt_llm
{
// Rebuilds a cute stride/shape tuple with its first two modes swapped, keeping any
// trailing modes (rank >= 2 assumed by the take<2, rank> below).
template <class T>
constexpr auto transpose_stride(T const& t)
{
    auto trailing = cute::take<2, cute::rank_v<T>>(t);
    auto with_mode0 = cute::prepend(trailing, cute::get<0>(t));
    return cute::prepend(with_mode0, cute::get<1>(t));
}
// Aggregate of all device-side state for a Hopper (SM90) grouped GEMM over MoE experts:
// per-expert problem shapes, stride arrays, operand pointer arrays, epilogue parameters,
// and the raw kernel workspace. Filled in by configureWorkspace()/setFinalizeFusionParams().
struct HopperGroupedGemmInput
{
    template <class T>
    using TransposeStride = decltype(transpose_stride<T>(T{}));

    // Maps RowMajor <-> ColumnMajor.
    template <class Tag>
    using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
        cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;

    static_assert(std::is_same_v<cutlass::layout::RowMajor, TransposeLayoutTag<cutlass::layout::ColumnMajor>>);
    static_assert(std::is_same_v<cutlass::layout::ColumnMajor, TransposeLayoutTag<cutlass::layout::RowMajor>>);

    // Layout for A and B is transposed and then swapped in the implementation
    // This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM
    using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>;    // Layout type for A matrix operand
    using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
    using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>;    // Layout type for C matrix operand

    using StrideA = std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutA*>>; // Use B because they will be swapped
    using StrideB = std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
    using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;

    template <class T>
    constexpr static bool IsFP8_v = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;

    // Currently this should always just be T
    template <class T>
    using OutputTypeAdaptor_t = std::conditional_t<IsFP8_v<T>, nv_bfloat16, T>;

    using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int64_t, int64_t, int64_t>>;

    ProblemShape shape_info{};
    StrideA* stride_a = nullptr;
    StrideB* stride_b = nullptr;
    void const** ptr_a = nullptr;
    void const** ptr_b = nullptr;

    // C is currently the same in both epilogues
    StrideC* stride_c = nullptr;
    void const** ptr_c = nullptr;

    // Plain GEMM epilogue: writes D directly.
    struct DefaultEpilogue
    {
        using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
        using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;

        StrideD* stride_d = nullptr;
        void** ptr_d = nullptr;
    };

    // Epilogue that fuses the MoE finalize step (router scaling, bias, scatter into
    // the final output) into the GEMM.
    struct FusedFinalizeEpilogue
    {
        using StrideFinalOutput = DefaultEpilogue::StrideD;
        using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
        using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;

        void* ptr_final_output = nullptr;
        StrideFinalOutput stride_final_output{};

        void const* ptr_bias = nullptr;
        StrideBias stride_bias{};

        float const* ptr_router_scales = nullptr;
        StrideRouterScales stride_router_scales{};

        int64_t const* ptr_expert_first_token_offset = nullptr;
        int const* ptr_source_token_index = nullptr;

        size_t num_rows_in_final_output = 0;
    };

    DefaultEpilogue default_epilogue;
    FusedFinalizeEpilogue fused_finalize_epilogue;

    enum class EpilogueFusion
    {
        NONE,
        ACTIVATION,
        GATED_ACTIVATION,
        FINALIZE
    };
    EpilogueFusion fusion = EpilogueFusion::NONE;

    float const** alpha_scale_ptr_array = nullptr;

    uint8_t* gemm_workspace = nullptr;
    size_t gemm_workspace_size = 0;

    // Per-expert buffer sizes / total size for the bookkeeping workspace.
    static std::array<size_t, 10> workspaceBuffers(int num_experts);
    static size_t workspaceSize(int num_experts);

    // Carves `start_ptr` into the buffers above and wires up the member pointers.
    void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size);

    bool isValid() const
    {
        return stride_a != nullptr && ptr_a != nullptr;
    }

    void setFinalizeFusionParams(void* final_output, float const* router_scales,
        int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
        int num_output_tokens);

    std::string toString() const;
};
// Note update moe.py to match
enum class ActivationType
{
    Gelu = 0,
    Relu,
    Silu,
    Swiglu,
    Geglu,
    Identity,
    InvalidType
};

// Swiglu and Geglu are the gated activations.
constexpr bool isGatedActivation(ActivationType activation_type)
{
    switch (activation_type)
    {
    case ActivationType::Swiglu:
    case ActivationType::Geglu: return true;
    default: return false;
    }
}
// Dispatcher for grouped MoE GEMMs. Selects and launches the appropriate CUTLASS
// kernel for the detected SM architecture, optionally fusing bias + activation.
template <typename T,                           /*The type used for activations/scales/compute*/
    typename WeightType,                        /* The type for the MoE weights */
    typename OutputType,                        /* The output type for the GEMM */
    typename ScaleBiasType = OutputType         /* The type for the scales/bias */
    >
class MoeGemmRunner
{
public:
    MoeGemmRunner();

#if defined(ENABLE_FP8)
    static constexpr bool use_fp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
#else
    static constexpr bool use_fp8 = false;
#endif

    // Grouped GEMM with fused bias + activation epilogue.
    void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
        ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
        HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
        ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
        cutlass_extensions::CutlassGemmConfig chosen_conf);

    // Grouped GEMM without bias/activation.
    void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C,
        int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows,
        int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
        cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf);

    // Candidate tile/stage configurations for autotuning.
    std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
    static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm);
    static std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs(int sm);
    static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);

    [[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const;
    [[nodiscard]] bool supportsHopperSpecialisation() const;
    [[nodiscard]] bool isFusedGatedActivation(
        cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const;
    [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const;

    size_t getMaxWorkspaceSize(int num_experts) const;

    [[nodiscard]] int getSM() const;

private:
    // Routes to the SM-specific kernel launcher for the given epilogue tag.
    template <typename EpilogueTag>
    void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
        ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
        HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
        cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array,
        cudaStream_t stream, int* occupancy = nullptr);

    template <typename EpilogueTag>
    void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases,
        bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
        HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
        bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
        cutlass_extensions::CutlassGemmConfig chosen_conf);

private:
    int sm_{};
    int multi_processor_count_{};
    // Cached values from the last workspace-size query (hence mutable on const paths).
    mutable int num_experts_ = 0;
    mutable size_t gemm_workspace_size_ = 0;

    size_t calcMaxWorkspaceSize(int num_experts) const;
};
}
// namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
// Explicit instantiation: bf16 activations, bf16 weights, bf16 output.
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
// Explicit instantiation: bf16 activations, packed uint4 weights, bf16 output.
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
// Explicit instantiation: bf16 activations, uint8 weights, bf16 output.
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
// Explicit instantiation: fp16 activations, fp16 weights, fp16 output.
template class MoeGemmRunner<half, half, half>;
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
// Explicit instantiation: fp16 activations, packed uint4 weights, fp16 output.
template class MoeGemmRunner<half, cutlass::uint4b_t, half>;
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
// Explicit instantiation: fp16 activations, uint8 weights, fp16 output.
template class MoeGemmRunner<half, uint8_t, half>;
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
// Explicit instantiation: fp32 activations, fp32 weights, fp32 output.
template class MoeGemmRunner<float, float, float>;
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_FP8
// Explicit instantiation: fp8 e4m3 activations/weights, fp16 output.
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>;
#ifdef ENABLE_BF16
// Explicit instantiation: fp8 e4m3 activations/weights, bf16 output.
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
#endif
// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>;
#endif
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#ifdef __GNUC__ // Restore GCC-specific diagnostics
#pragma GCC diagnostic pop
#endif
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "moe_gemm_kernels_template_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include <tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace
tensorrt_llm
{
namespace
kernels
::
cutlass_kernels
{
// ============================= Variable batched Gemm things ===========================
template
<
typename
T
,
typename
WeightType
,
typename
GemmOutputType
,
typename
arch
,
typename
EpilogueTag
,
typename
ThreadblockShape
,
typename
WarpShape
,
int
Stages
>
void
genericMoeGemmKernelLauncher
(
T
const
*
A
,
WeightType
const
*
B
,
GemmOutputType
const
*
weight_scales
,
GemmOutputType
const
*
biases
,
bool
bias_is_broadcast
,
GemmOutputType
*
C
,
int64_t
const
*
total_tokens_including_expert
,
int64_t
num_rows
,
int64_t
gemm_n
,
int64_t
gemm_k
,
int
num_experts
,
cutlass_extensions
::
CutlassGemmConfig
gemm_config
,
int
const
multi_processor_count
,
bool
use_fused_moe
,
float
const
**
alpha_scale_ptr_array
,
cudaStream_t
stream
,
int
*
kernel_occupancy
=
nullptr
)
{
#if defined(ENABLE_FP8)
static_assert
(
cutlass
::
platform
::
is_same
<
T
,
__nv_bfloat16
>::
value
||
cutlass
::
platform
::
is_same
<
T
,
half
>::
value
||
cutlass
::
platform
::
is_same
<
T
,
__nv_fp8_e4m3
>::
value
||
cutlass
::
platform
::
is_same
<
T
,
__nv_fp8_e5m2
>::
value
||
cutlass
::
platform
::
is_same
<
T
,
float
>::
value
,
"Specialized for fp8, bfloat16, half, float"
);
#elif defined(ENABLE_BF16)
static_assert
(
cutlass
::
platform
::
is_same
<
T
,
__nv_bfloat16
>::
value
||
cutlass
::
platform
::
is_same
<
T
,
half
>::
value
||
cutlass
::
platform
::
is_same
<
T
,
float
>::
value
,
"Specialized for bfloat16, half, float"
);
#else
static_assert
(
cutlass
::
platform
::
is_same
<
T
,
half
>::
value
||
cutlass
::
platform
::
is_same
<
T
,
float
>::
value
,
"Specialized for half, float"
);
#endif
static_assert
(
cutlass
::
platform
::
is_same
<
T
,
WeightType
>::
value
||
cutlass
::
platform
::
is_same
<
WeightType
,
uint8_t
>::
value
||
cutlass
::
platform
::
is_same
<
WeightType
,
cutlass
::
uint4b_t
>::
value
,
""
);
static_assert
(
!
cutlass
::
platform
::
is_same
<
arch
,
cutlass
::
arch
::
Sm90
>::
value
,
"Sm90 architecture should use specialised kernels"
);
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using
ElementType
=
typename
TllmToCutlassTypeAdapter
<
T
>::
type
;
using
CutlassGemmOutputType
=
typename
TllmToCutlassTypeAdapter
<
GemmOutputType
>::
type
;
using
CutlassWeightType
=
typename
TllmToCutlassTypeAdapter
<
WeightType
>::
type
;
if
(
!
use_fused_moe
)
{
// We need separate config for each architecture since we will target different tensorcore instructions. For
// float, we do not target TCs.
using
MixedGemmArchTraits
=
cutlass
::
gemm
::
kernel
::
MixedGemmArchTraits
<
ElementType
,
CutlassWeightType
,
arch
>
;
using
ElementAccumulator
=
typename
MixedGemmArchTraits
::
AccType
;
using
EpilogueOp
=
typename
tensorrt_llm
::
cutlass_extensions
::
Epilogue
<
CutlassGemmOutputType
,
MixedGemmArchTraits
::
ElementsPerAccessC
,
ElementAccumulator
,
EpilogueTag
>::
Op
;
typename
EpilogueOp
::
Params
epilogue_op
(
ElementAccumulator
(
1.
f
),
biases
?
ElementAccumulator
(
1.
f
)
:
ElementAccumulator
(
0.
f
));
#if defined(ENABLE_FP8)
if
constexpr
((
std
::
is_same_v
<
T
,
__nv_fp8_e4m3
>
||
std
::
is_same_v
<
T
,
__nv_fp8_e5m2
>
)
&&
std
::
is_same_v
<
EpilogueTag
,
cutlass_extensions
::
EpilogueOpDefault
>
)
{
TLLM_CHECK_WITH_INFO
(
weight_scales
==
nullptr
&&
biases
==
nullptr
&&
alpha_scale_ptr_array
,
"weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 "
"Ada"
);
epilogue_op
.
alpha_ptr_array
=
alpha_scale_ptr_array
;
}
#endif
// Finally, set up the kernel.
using
GemmKernel_
=
typename
cutlass
::
gemm
::
kernel
::
DefaultGemmGrouped
<
ElementType
,
cutlass
::
layout
::
RowMajor
,
cutlass
::
ComplexTransform
::
kNone
,
MixedGemmArchTraits
::
ElementsPerAccessA
,
CutlassWeightType
,
typename
MixedGemmArchTraits
::
LayoutB
,
cutlass
::
ComplexTransform
::
kNone
,
MixedGemmArchTraits
::
ElementsPerAccessB
,
CutlassGemmOutputType
,
cutlass
::
layout
::
RowMajor
,
ElementAccumulator
,
typename
MixedGemmArchTraits
::
OperatorClass
,
arch
,
ThreadblockShape
,
WarpShape
,
typename
MixedGemmArchTraits
::
InstructionShape
,
EpilogueOp
,
cutlass
::
gemm
::
threadblock
::
GemmBatchedIdentityThreadblockSwizzle
,
Stages
,
cutlass
::
gemm
::
kernel
::
GroupScheduleMode
::
kDeviceOnly
,
typename
MixedGemmArchTraits
::
Operator
>::
GemmKernel
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
MoeFCGemm
<
typename
GemmKernel_
::
Mma
,
typename
GemmKernel_
::
Epilogue
,
typename
GemmKernel_
::
ThreadblockSwizzle
,
arch
,
// Ensure top level arch is used for dispatch
GemmKernel_
::
kGroupScheduleMode
>
;
using
GemmGrouped
=
cutlass
::
gemm
::
device
::
GemmGrouped
<
GemmKernel
>
;
if
(
kernel_occupancy
!=
nullptr
)
{
*
kernel_occupancy
=
tensorrt_llm
::
cutlass_extensions
::
compute_occupancy_for_kernel
<
GemmKernel
>
();
return
;
}
int
occupancy
=
std
::
min
(
2
,
GemmGrouped
::
maximum_active_blocks
());
TLLM_CHECK_WITH_INFO
(
occupancy
>
0
,
"GPU lacks the shared memory resources to run GroupedGEMM kernel"
);
int
const
threadblock_count
=
multi_processor_count
*
occupancy
;
int
const
group_size
=
gemm_k
;
typename
GemmGrouped
::
Arguments
args
(
num_experts
,
threadblock_count
,
group_size
,
epilogue_op
,
reinterpret_cast
<
ElementType
const
*>
(
A
),
reinterpret_cast
<
CutlassWeightType
const
*>
(
B
),
reinterpret_cast
<
CutlassGemmOutputType
const
*>
(
weight_scales
),
reinterpret_cast
<
CutlassGemmOutputType
const
*>
(
biases
),
bias_is_broadcast
,
reinterpret_cast
<
CutlassGemmOutputType
*>
(
C
),
total_tokens_including_expert
,
gemm_n
,
gemm_k
);
GemmGrouped
gemm
;
auto
can_implement
=
gemm
.
can_implement
(
args
);
TLLM_CHECK_WITH_INFO
(
can_implement
==
cutlass
::
Status
::
kSuccess
,
"MoE FC kernel will fail for params. Error: "
+
std
::
string
(
cutlassGetStatusString
(
can_implement
)));
auto
init_status
=
gemm
.
initialize
(
args
);
TLLM_CHECK_WITH_INFO
(
init_status
==
cutlass
::
Status
::
kSuccess
,
"Failed to initialize cutlass grouped gemm. Error: "
+
std
::
string
(
cutlassGetStatusString
(
init_status
)));
auto
run_status
=
gemm
.
run
(
stream
);
TLLM_CHECK_WITH_INFO
(
run_status
==
cutlass
::
Status
::
kSuccess
,
"Failed to run cutlass grouped gemm. Error: "
+
std
::
string
(
cutlassGetStatusString
(
run_status
)));
}
else
if
constexpr
(
sizeof
(
ElementType
)
==
2
&&
sizeof
(
CutlassWeightType
)
==
2
&&
(
std
::
is_same_v
<
EpilogueTag
,
cutlass_extensions
::
EpilogueOpDefaultSilu
>
||
std
::
is_same_v
<
EpilogueTag
,
cutlass_extensions
::
EpilogueOpDefaultFtGelu
>
)
)
// use fused moe gemm
// kernel.. (only support
// fp16 or bf16)
{
sm80_generic_fused_moe_gemm_kernelLauncher
<
ElementType
,
CutlassWeightType
,
ThreadblockShape
::
kM
,
ThreadblockShape
::
kN
,
ThreadblockShape
::
kK
,
Stages
,
EpilogueTag
>
(
reinterpret_cast
<
ElementType
const
*>
(
A
),
reinterpret_cast
<
CutlassWeightType
const
*>
(
B
),
reinterpret_cast
<
ElementType
const
*>
(
biases
),
bias_is_broadcast
,
reinterpret_cast
<
ElementType
*>
(
C
),
total_tokens_including_expert
,
num_rows
,
gemm_n
,
gemm_k
,
num_experts
,
multi_processor_count
,
stream
,
kernel_occupancy
);
}
}
}
// namespace kernels::cutlass_kernels
// Dispatches a grouped MoE GEMM for a fixed (arch, tile, warp, stage-count) combination.
// Validates at compile time that the requested Arch/Stages pairing was actually instantiated,
// then forwards to genericMoeGemmKernelLauncher; throws at runtime otherwise.
// Note: SM90 must not reach this path — it has dedicated TMA-specialised kernels.
template <typename T, typename WeightType, typename GemmOutputType, typename Arch, typename EpilogueTag,
    typename ThreadblockShape, typename WarpShape, int Stages>
static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    static_assert(!std::is_same_v<Arch, cutlass::arch::Sm90>, "Use TMA specialised functions for arch SM90");
#if defined(ENABLE_FP8)
    constexpr bool isFp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
#else
    constexpr bool isFp8 = false;
#endif
    // Multi-stage (>2) pipelines require SM80+; FP8 inputs are only instantiated for SM89 (Ada).
    if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80)
        && (!isFp8 || std::is_same_v<Arch, cutlass::arch::Sm89>))
    {
        kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, GemmOutputType, Arch, EpilogueTag,
            ThreadblockShape, WarpShape, Stages>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
            use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
    }
    else
    {
        // This (Arch, Stages) combination was excluded from instantiation above.
        TLLM_THROW(
            "Cutlass gemm. Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages);
    }
}
// Translates the runtime stage count from gemm_config into a compile-time Stages template
// argument and forwards to dispatch<>. Only pipeline depths 2, 3 and 4 are instantiated.
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename ThreadblockShape, typename WarpShape>
void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    switch (gemm_config.stages)
    {
    case 2:
        dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B,
            weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k,
            num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case 3:
        dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B,
            weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k,
            num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case 4:
        dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B,
            weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k,
            num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    default:
        // Any other stage count was never compiled in.
        TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages);
        break;
    }
}
// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
// This overload is only enabled when T == WeightType.
// Maps the runtime tile_config onto compile-time threadblock/warp GemmShapes and forwards to
// dispatchGemmConfig. The 16xN tiles require Turing (SM75) or newer, hence the runtime check
// plus the matching `if constexpr` guard that prevents instantiating them for Volta.
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename std::enable_if<!std::is_same<T, float>::value
#if defined(ENABLE_FP8)
        && !std::is_same<T, __nv_fp8_e4m3>::value && !std::is_same<T, __nv_fp8_e5m2>::value
#endif
        && std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    switch (gemm_config.tile_config)
    {
    case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
        TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
        if constexpr (arch::kMinComputeCapability >= 75)
        {
            dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
                cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
                total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
                multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        }
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
        TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
        if constexpr (arch::kMinComputeCapability >= 75)
        {
            dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
                cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
                total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
                multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        }
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
            cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
            cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
            cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::Undefined:
        TLLM_THROW("GEMM config undefined.");
        break;
    case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default:
        TLLM_THROW("Config is invalid for same type tensorop GEMM.");
        break;
    }
}
// Tensorop GEMM overload
// Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve
// compile time
// Enabled only when T != WeightType (mixed-precision / weight-only quantized path) and T is not float.
// Same structure as the same-type overload: maps tile_config to compile-time GemmShapes, guarding
// the 16xN tiles against pre-Turing instantiation.
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    switch (gemm_config.tile_config)
    {
    case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
        TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
        if constexpr (arch::kMinComputeCapability >= 75)
        {
            dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
                cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
                total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
                multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        }
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
        TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
        if constexpr (arch::kMinComputeCapability >= 75)
        {
            dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
                cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
                total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
                multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        }
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
            cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
            cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
            cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::Undefined:
        TLLM_THROW("GEMM config undefined.");
        break;
    case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default:
        TLLM_THROW("Config is invalid for mixed type tensorop GEMM.");
        break;
    }
}
// This overload will handle tensorop gemms.
// This overload is only enabled when T == WeightType and T == __nv_fp8_e4m3 or __nv_fp8_e5m2
#if defined(ENABLE_FP8)
template
<
typename
T
,
typename
WeightType
,
typename
GemmOutputType
,
typename
arch
,
typename
EpilogueTag
,
typename
std
::
enable_if
<
(
std
::
is_same
<
T
,
__nv_fp8_e4m3
>
::
value
||
std
::
is_same
<
T
,
__nv_fp8_e5m2
>::
value
)
&&
std
::
is_same
<
T
,
WeightType
>::
value
>::
type
*
=
nullptr
>
void
dispatchMoeGemmToCutlass
(
T
const
*
A
,
WeightType
const
*
B
,
GemmOutputType
const
*
weight_scales
,
GemmOutputType
const
*
biases
,
bool
bias_is_broadcast
,
GemmOutputType
*
C
,
int64_t
const
*
total_tokens_including_expert
,
int64_t
total_rows
,
int64_t
gemm_n
,
int64_t
gemm_k
,
int
num_experts
,
cutlass_extensions
::
CutlassGemmConfig
gemm_config
,
int
multi_processor_count
,
bool
use_fused_moe
,
float
const
**
alpha_scale_ptr_array
,
cudaStream_t
stream
,
int
*
occupancy
=
nullptr
)
{
switch
(
gemm_config
.
tile_config
)
{
case
cutlass_extensions
::
CutlassTileConfig
::
CtaShape16x256x128_WarpShape16x64x128
:
dispatchGemmConfig
<
T
,
WeightType
,
GemmOutputType
,
arch
,
EpilogueTag
,
cutlass
::
gemm
::
GemmShape
<
16
,
256
,
128
>
,
cutlass
::
gemm
::
GemmShape
<
16
,
64
,
128
>>
(
A
,
B
,
weight_scales
,
biases
,
bias_is_broadcast
,
C
,
total_tokens_including_expert
,
total_rows
,
gemm_n
,
gemm_k
,
num_experts
,
gemm_config
,
multi_processor_count
,
use_fused_moe
,
alpha_scale_ptr_array
,
stream
,
occupancy
);
break
;
case
cutlass_extensions
::
CutlassTileConfig
::
CtaShape32x128x64_WarpShape32x32x64
:
dispatchGemmConfig
<
T
,
WeightType
,
GemmOutputType
,
arch
,
EpilogueTag
,
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
64
>
,
cutlass
::
gemm
::
GemmShape
<
32
,
32
,
64
>>
(
A
,
B
,
weight_scales
,
biases
,
bias_is_broadcast
,
C
,
total_tokens_including_expert
,
total_rows
,
gemm_n
,
gemm_k
,
num_experts
,
gemm_config
,
multi_processor_count
,
use_fused_moe
,
alpha_scale_ptr_array
,
stream
,
occupancy
);
break
;
case
cutlass_extensions
::
CutlassTileConfig
::
CtaShape64x128x64_WarpShape64x32x64
:
dispatchGemmConfig
<
T
,
WeightType
,
GemmOutputType
,
arch
,
EpilogueTag
,
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
64
>
,
cutlass
::
gemm
::
GemmShape
<
64
,
32
,
64
>>
(
A
,
B
,
weight_scales
,
biases
,
bias_is_broadcast
,
C
,
total_tokens_including_expert
,
total_rows
,
gemm_n
,
gemm_k
,
num_experts
,
gemm_config
,
multi_processor_count
,
use_fused_moe
,
alpha_scale_ptr_array
,
stream
,
occupancy
);
break
;
case
cutlass_extensions
::
CutlassTileConfig
::
CtaShape64x64x128_WarpShape32x64x64
:
dispatchGemmConfig
<
T
,
WeightType
,
GemmOutputType
,
arch
,
EpilogueTag
,
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
128
>
,
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>>
(
A
,
B
,
weight_scales
,
biases
,
bias_is_broadcast
,
C
,
total_tokens_including_expert
,
total_rows
,
gemm_n
,
gemm_k
,
num_experts
,
gemm_config
,
multi_processor_count
,
use_fused_moe
,
alpha_scale_ptr_array
,
stream
,
occupancy
);
break
;
case
cutlass_extensions
::
CutlassTileConfig
::
CtaShape128x64x64_WarpShape64x32x64
:
dispatchGemmConfig
<
T
,
WeightType
,
GemmOutputType
,
arch
,
EpilogueTag
,
cutlass
::
gemm
::
GemmShape
<
128
,
64
,
64
>
,
cutlass
::
gemm
::
GemmShape
<
64
,
32
,
64
>>
(
A
,
B
,
weight_scales
,
biases
,
bias_is_broadcast
,
C
,
total_tokens_including_expert
,
total_rows
,
gemm_n
,
gemm_k
,
num_experts
,
gemm_config
,
multi_processor_count
,
use_fused_moe
,
alpha_scale_ptr_array
,
stream
,
occupancy
);
break
;
case
cutlass_extensions
::
CutlassTileConfig
::
CtaShape128x256x64_WarpShape64x64x64
:
dispatchGemmConfig
<
T
,
WeightType
,
GemmOutputType
,
arch
,
EpilogueTag
,
cutlass
::
gemm
::
GemmShape
<
128
,
256
,
64
>
,
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>>
(
A
,
B
,
weight_scales
,
biases
,
bias_is_broadcast
,
C
,
total_tokens_including_expert
,
total_rows
,
gemm_n
,
gemm_k
,
num_experts
,
gemm_config
,
multi_processor_count
,
use_fused_moe
,
alpha_scale_ptr_array
,
stream
,
occupancy
);
break
;
case
cutlass_extensions
::
CutlassTileConfig
::
CtaShape256x128x64_WarpShape64x64x64
:
dispatchGemmConfig
<
T
,
WeightType
,
GemmOutputType
,
arch
,
EpilogueTag
,
cutlass
::
gemm
::
GemmShape
<
256
,
128
,
64
>
,
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>>
(
A
,
B
,
weight_scales
,
biases
,
bias_is_broadcast
,
C
,
total_tokens_including_expert
,
total_rows
,
gemm_n
,
gemm_k
,
num_experts
,
gemm_config
,
multi_processor_count
,
use_fused_moe
,
alpha_scale_ptr_array
,
stream
,
occupancy
);
break
;
case
cutlass_extensions
::
CutlassTileConfig
::
Undefined
:
TLLM_THROW
(
"GEMM config undefined."
);
break
;
case
cutlass_extensions
::
CutlassTileConfig
::
ChooseWithHeuristic
:
TLLM_THROW
(
"GEMM config should have already been set by heuristic."
);
break
;
default:
TLLM_THROW
(
"Config is invalid for same type tensorop GEMM."
);
break
;
}
}
#endif
// This overload will handle simt gemms. It is disabled via SFINAE for tensorop.
// Enabled only for T == float: fp32 does not target tensor cores, so a single SIMT tile shape
// (128x128x8 threadblock / 64x64x8 warp) is the only supported configuration.
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    switch (gemm_config.tile_config)
    {
    case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 8>,
            cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::Undefined:
        TLLM_THROW("GEMM config undefined.");
        break;
    case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default:
        TLLM_THROW("Unsupported config for float MoE gemm.");
        break;
    }
}
// Convenience overload: returns the candidate GEMM configurations for the SM version
// detected when this runner was constructed.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs() const
{
    return getConfigs(sm_);
}
// Returns all candidate GEMM configurations for the given SM version: the Hopper-specific
// configs first, followed by the Ampere(-compatible) configs appended to the same list.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(int sm)
{
    auto configs = getHopperConfigs(sm);
    auto const ampere_configs = getAmpereConfigs(sm);
    configs.insert(configs.end(), ampere_configs.begin(), ampere_configs.end());
    return configs;
}
// Builds the list of candidate non-Hopper (Ampere-path) grouped-GEMM configurations for the
// given SM version. Returns an empty list if this T/WeightType pair has no valid Ampere MoE
// specialisation compiled in.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getAmpereConfigs(int sm)
{
    using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
    // Weight-only quantisation applies only when the activation and weight types differ.
    static constexpr auto weight_only_flag
        = std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
    // fp32 has no tensorcore path, so restrict to SIMT configs.
    static constexpr auto simt_only_flag
        = std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
    static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
    int const max_split_k = 1;
    int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
    // Deliberately NONE: Hopper configs come from getHopperConfigs().
    int const enable_hopper = CutlassGemmConfig::NONE;
    auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
        weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);
    if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>())
    {
        return {};
    }
    std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs
        = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
    return ampere_configs;
}
// Builds the list of candidate Hopper (SM90) grouped-GEMM configurations for the given SM
// version. Returns an empty list if this T/WeightType pair has no valid Hopper MoE
// specialisation compiled in.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getHopperConfigs(int sm)
{
    using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
    // Weight-only quantisation applies only when the activation and weight types differ.
    static constexpr auto weight_only_flag
        = std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
    // fp32 has no tensorcore path, so restrict to SIMT configs.
    static constexpr auto simt_only_flag
        = std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
    int const max_split_k = 1;
    int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
    // HOPPER flag selects the SM90-specialised candidate set (contrast with getAmpereConfigs).
    int const enable_hopper = CutlassGemmConfig::HOPPER;
    static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
    auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
        weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);
    if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
    {
        return {};
    }
    std::vector<cutlass_extensions::CutlassGemmConfig> hopper_configs
        = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
    return hopper_configs;
}
// Returns true when the given config selects the SM90 path AND this runner can actually use
// the Hopper-specialised kernels.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isHopperSpecialised(
    cutlass_extensions::CutlassGemmConfig gemm_config) const
{
    if (!supportsHopperSpecialisation())
    {
        return false;
    }
    return gemm_config.is_sm90;
}
// True only on SM90 hardware with a compiled-in Hopper MoE specialisation for T/WeightType.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsHopperSpecialisation() const
{
    bool const is_sm90_device = (sm_ == 90);
    return is_sm90_device && kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>();
}
// Accessor for the SM (compute capability) version detected at construction time.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
int MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getSM() const
{
    return sm_;
}
// currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction
// Fused gated activation requires: a gated activation, non-quantised non-fp32 non-fp8 types,
// SM80+, and both GEMM dimensions divisible by 64. (All sub-conditions are pure, so evaluating
// them eagerly here is observationally identical to the original short-circuit chain.)
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsFusedGatedActivation(
    bool is_gated_activation, int gemm_n, int gemm_k) const
{
    // Master switch for the fused path; kept for easy disabling during debugging.
    constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true;
    bool const type_supported = std::is_same_v<T, WeightType> && !std::is_same_v<T, float> && !use_fp8;
    bool const arch_supported = (this->getSM() >= 80);
    bool const shape_supported = (gemm_k % 64 == 0) && (gemm_n % 64 == 0);
    return is_gated_activation && type_supported && arch_supported && shape_supported
        && ENABLE_FUSED_GATED_ACTIVATION;
}
// A config uses the fused gated-activation kernel iff fusion is supported for these
// types/shapes and the config is NOT the SM90 (Hopper) path.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isFusedGatedActivation(
    cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const
{
    bool const fusion_supported = supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k);
    return fusion_supported && !gemm_config.is_sm90;
}
// Queries the current CUDA device once at construction and caches its SM version and
// multiprocessor count for later kernel-launch decisions.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::MoeGemmRunner()
{
    int device_id = -1;
    tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device_id));
    sm_ = tensorrt_llm::common::getSMVersion();
    tensorrt_llm::common::check_cuda_error(
        cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device_id));
}
// Top-level architecture dispatcher for the MoE grouped GEMM.
// Routes to the SM75 / SM80 / SM89(FP8) / SM90(Hopper) implementations based on the detected
// SM version and the selected gemm_config, after validating that the Hopper-specific inputs
// are only supplied when a specialised path can consume them.
//
// Fix: corrected the misspelled runtime error message "not valid for Hppper" -> "Hopper".
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch<EpilogueTag>(T const* A,
    WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast,
    void* C_void, int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input,
    int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array,
    cudaStream_t stream, int* occupancy)
{
    static_assert(std::is_same_v<ScaleBiasType, OutputType>,
        "Separate Scale/Bias type is not supported. This is assumed to be the gemm output type");
    // For now we always cast this to output type.
    // In the future this will vary based on what fusions are applied for FP8
    auto* C = reinterpret_cast<OutputType*>(C_void);

    TLLM_CHECK_WITH_INFO(
        sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation");
    TLLM_CHECK_WITH_INFO(
        sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture");

    if (sm_ >= 75 && sm_ < 80)
    {
        // Turing path.
        dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm75, EpilogueTag>(A, B, weight_scales,
            biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts,
            gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
    }
    else if (sm_ >= 80 && sm_ < 90)
    {
        if constexpr (use_fp8)
        {
#if defined(ENABLE_FP8)
            static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
                "FP8 GEMM Output not supported");
#endif
            // FP8 tensorcores only exist on Ada within this SM range.
            TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
            dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(A, B,
                weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
                num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
                occupancy);
        }
        else
        {
            // Standard Ampere path.
            dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
                weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
                num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
                occupancy);
        }
    }
    else if (sm_ >= 90)
    {
        if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>())
        {
            // We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens
            // SM80 is faster. We check here to see which is selected
            if (gemm_config.is_sm90)
            {
                TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr,
                    "Input biases and hopper input disagree if bias is enabled");
                TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config");

                // Select the appropriate fusion function
                auto select_function = [&]()
                {
                    switch (hopper_input.fusion)
                    {
                    case HopperGroupedGemmInput::EpilogueFusion::FINALIZE:
                        return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
                            HopperGroupedGemmInput::EpilogueFusion::FINALIZE>;
                    case HopperGroupedGemmInput::EpilogueFusion::NONE:
                        return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
                            HopperGroupedGemmInput::EpilogueFusion::NONE>;
                    case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION:
                    case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION:
                    default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion);
                    };
                };
                auto selected_func = select_function();
                selected_func(hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy,
                    nullptr);
                return;
            }
            // Fallthrough to SM80 impl below
        }

        // Do Ampere case instead
        if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType, EpilogueTag>())
        {
            TLLM_CHECK_WITH_INFO(!hopper_input.isValid(),
                "Non-specialised Hopper implementation is being rerouted to fallback implementation so input "
                "information is not required");
            TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90,
                "GEMM config is for SM90 configuration, but this configuration is not valid for Hopper");
            dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
                weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
                num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
                occupancy);
        }
        else
        {
            TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels");
        }
    }
    else
    {
        TLLM_THROW("Arch unsupported for MoE GEMM");
    }
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getMaxWorkspaceSize(int num_experts) const
{
    // Fast path: the cached workspace size is still valid for this expert count.
    if (num_experts == num_experts_)
    {
        return gemm_workspace_size_;
    }
    TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_);
    // Refresh the memoised values for the new expert count.
    // NOTE(review): num_experts_ and gemm_workspace_size_ are written inside a
    // const member function, so they are presumably declared mutable in the
    // class definition — confirm against the header.
    num_experts_ = num_experts;
    gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts);
    return gemm_workspace_size_;
}
// Computes the maximum workspace size (in bytes) needed by any Hopper GEMM
// configuration for the given expert count. Returns 0 when the Hopper
// specialisation is not supported at all for this runner.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspaceSize(int num_experts) const
{
    // No Hopper path means no extra workspace is needed.
    if (!supportsHopperSpecialisation())
    {
        return 0;
    }
    if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
    {
        auto configs = getHopperConfigs(sm_);
        size_t max_size = 0;
        bool has_config = false;
        // Take the maximum over every candidate config and both epilogue
        // fusions; individual configs that are unsupported throw and are
        // deliberately skipped rather than failing the whole calculation.
        for (auto conf : configs)
        {
#define CALC_SIZE_FUSION(FUSION)                                                                                       \
    do                                                                                                                 \
    {                                                                                                                  \
        try                                                                                                            \
        {                                                                                                              \
            size_t size = calcMaxWorkspaceSizeSM90<T, WeightType, OutputType, FUSION>(                                 \
                num_experts, conf, multi_processor_count_);                                                            \
            max_size = std::max(max_size, size);                                                                       \
            has_config = true;                                                                                         \
        }                                                                                                              \
        catch (tensorrt_llm::common::TllmException const& e)                                                           \
        {                                                                                                              \
            TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size");                          \
        }                                                                                                              \
    } while (0)
            CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::NONE);
            CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::FINALIZE);
        }
        // At least one config must have succeeded, otherwise the returned size
        // would be meaningless.
        TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size");
        return max_size;
    }
    else
    {
        TLLM_THROW("Attempting to calculate Hopper GEMM workspace size with unsupported weight combination");
        // Unreachable: TLLM_THROW does not return, but keeps the compiler happy.
        return 0;
    }
}
// Thin forwarder: runs the grouped MoE GEMM with the chosen configuration by
// delegating to the architecture dispatcher. The trailing nullptr is the
// occupancy out-parameter, which is not requested here.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::runGemm(T const* A, WeightType const* B,
    ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C,
    int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
    int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
    cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf)
{
    dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert,
        hopper_input, total_rows, gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array,
        stream, nullptr);
}
// Runs the MoE GEMM with the epilogue that implements the requested
// activation. All arguments other than activation_type are forwarded to
// runGemm unchanged.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(T const* A, WeightType const* B,
    ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C,
    int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
    int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf)
{
    // Local macro to avoid repeating the 16-argument call once per activation.
    // Note the gated activations (Swiglu/Geglu) use the same epilogue tags as
    // their non-gated counterparts (Silu/FtGelu).
#define RUN_GEMM_WITH_EPILOGUE(EpilogueTag)                                                                            \
    runGemm<cutlass_extensions::EpilogueTag>(A, B, weight_scales, biases, bias_is_broadcast, C,                        \
        total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,           \
        alpha_scale_ptr_array, stream, chosen_conf)
    switch (activation_type)
    {
    case ActivationType::Relu: RUN_GEMM_WITH_EPILOGUE(EpilogueOpDefaultReLU); break;
    case ActivationType::Gelu: RUN_GEMM_WITH_EPILOGUE(EpilogueOpDefaultFtGelu); break;
    case ActivationType::Silu: RUN_GEMM_WITH_EPILOGUE(EpilogueOpDefaultSilu); break;
    case ActivationType::Identity: RUN_GEMM_WITH_EPILOGUE(EpilogueOpDefault); break;
    case ActivationType::Swiglu: RUN_GEMM_WITH_EPILOGUE(EpilogueOpDefaultSilu); break;
    case ActivationType::Geglu: RUN_GEMM_WITH_EPILOGUE(EpilogueOpDefaultFtGelu); break;
    case ActivationType::InvalidType:
        TLLM_THROW("Activation type for fpA_intB must be valid.");
        break;
    default:
        TLLM_THROW("Invalid activation type.");
        break;
    }
#undef RUN_GEMM_WITH_EPILOGUE
}
// Bias-free MoE GEMM entry point: forwards to runGemm with the default
// epilogue, a null bias pointer, and bias_is_broadcast = true (the value is
// irrelevant when biases == nullptr, but must still be supplied).
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemm(T const* A, WeightType const* B,
    ScaleBiasType const* weight_scales, void* C, int64_t const* total_tokens_including_expert,
    HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
    cutlass_extensions::CutlassGemmConfig chosen_conf)
{
    runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, nullptr, true, C,
        total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
        alpha_scale_ptr_array, stream, chosen_conf);
}
}
// namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h
deleted
100644 → 0
View file @
9829e77e
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // __GNUC__
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // __GNUC__
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace
tensorrt_llm
{
using
EpilogueFusion
=
HopperGroupedGemmInput
::
EpilogueFusion
;
// Launches (or sizes) the SM90 grouped GEMM kernel for a fully resolved
// template configuration. When workspace_size is non-null this is a
// workspace-size query only, so hopper_input need not be valid.
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
    typename TileShape, typename ClusterShape>
void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
    cudaStream_t stream, int* occupancy, size_t* workspace_size)
{
    static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>(),
        "Invalid hopper configuration invoked, fallback to Sm80");
    // Only real launches (workspace_size == nullptr) require a valid input.
    TLLM_CHECK_WITH_INFO(workspace_size || hopper_input.isValid(),
        "Hopper specialisation is missing additional input information");
    // TODO(dastokes) Re-enable bias when CUTLASS supports it. The launcher is
    // currently always instantiated with its final template argument
    // (presumably the bias toggle) set to false.
    kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher<T, WeightType, OutputType, EpilogueTag, FUSION,
        TileShape, ClusterShape, false>(
        hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
}
/*
1x1x1 cluster shape is supported for any tile shape.
2x1x1 cluster shape is only supported when the M tile is at least 128.
1x2x1 cluster shape is only supported when the N tile is at least 128.
2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
We impose the above restrictions to improve compilation speed in TRT-LLM by pruning kernels
that may not be very useful in practice.
*/
template
<
typename
CTAShape
,
typename
ClusterShape
>
constexpr
bool
are_tile_shapes_supported
()
{
using
namespace
cute
;
[[
maybe_unused
]]
constexpr
int
cta_m
=
get
<
0
>
(
CTAShape
{});
[[
maybe_unused
]]
constexpr
int
cta_n
=
get
<
1
>
(
CTAShape
{});
constexpr
int
cga_m
=
get
<
0
>
(
ClusterShape
{});
constexpr
int
cga_n
=
get
<
1
>
(
ClusterShape
{});
if
constexpr
(
cga_m
==
_1
{}
&&
cga_n
==
_1
{})
{
return
true
;
}
else
if
constexpr
(
cga_m
==
_2
{}
&&
cga_n
==
_1
{}
&&
cta_m
>=
_128
{})
{
return
true
;
}
else
if
constexpr
(
cga_m
==
_1
{}
&&
cga_n
==
_2
{}
&&
cta_n
>=
_128
{})
{
return
true
;
}
else
if
constexpr
(
cga_m
==
_2
{}
&&
cga_n
==
_2
{}
&&
cta_m
>=
_128
{}
&&
cta_n
>=
_128
{})
{
return
true
;
}
else
{
return
false
;
}
}
// Maps the runtime cluster-shape enum in gemm_config to a compile-time cute
// Shape and forwards to the bias-selection layer. Combinations rejected by
// are_tile_shapes_supported() throw rather than instantiating a kernel.
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
    typename TileShape>
void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
    size_t* workspace_size)
{
    using namespace cute;
    switch (gemm_config.cluster_shape)
    {
// Expands one case per supported M x N x K cluster shape; the if constexpr
// prevents instantiating launchers for unsupported tile/cluster pairs.
#define SHAPE_CASE(M, N, K)                                                                                            \
    case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K:                                               \
    {                                                                                                                  \
        using ClusterShape = Shape<_##M, _##N, _##K>;                                                                  \
        if constexpr (are_tile_shapes_supported<TileShape, ClusterShape>())                                            \
        {                                                                                                              \
            dispatchMoeGemmSelectBiasSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape>(    \
                hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);                  \
            break;                                                                                                     \
        }                                                                                                              \
        else                                                                                                           \
        {                                                                                                              \
            TLLM_THROW("Unsupported tile and cluster shape combination");                                              \
        }                                                                                                              \
    }
        SHAPE_CASE(1, 1, 1)
        SHAPE_CASE(1, 2, 1)
        SHAPE_CASE(2, 1, 1)
        SHAPE_CASE(2, 2, 1)
#undef SHAPE_CASE
    default: TLLM_THROW("Unsupported config for MoE gemm.");
    }
}
// namespace tensorrt_llm
// Maps the runtime SM90 tile-config enum in gemm_config to a compile-time cute
// tile Shape and forwards to the cluster-shape selection layer.
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION>
void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
    size_t* workspace_size)
{
    using namespace cute;
    switch (gemm_config.tile_config_sm90)
    {
// The config names encode the K extent in bytes (trailing "B"); dividing by
// sizeof(T) converts it to elements for the cute tile shape.
// NOTE(review): the name KtileBytes is misleading — after the division it
// holds the K extent in *elements*, not bytes.
#define SHAPE_CASE(M, N, K)                                                                                            \
    case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B:                                        \
    {                                                                                                                  \
        constexpr int KtileBytes = K / sizeof(T);                                                                      \
        using KTileDim = Int<KtileBytes>;                                                                              \
        using TileShape = Shape<_##M, _##N, KTileDim>;                                                                 \
        dispatchMoeGemmSelectClusterShapeSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape>(              \
            hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size);         \
        break;                                                                                                         \
    }
        SHAPE_CASE(128, 16, 128)
        SHAPE_CASE(128, 32, 128)
        SHAPE_CASE(128, 64, 128)
        SHAPE_CASE(128, 128, 128)
        SHAPE_CASE(128, 256, 128)
        SHAPE_CASE(256, 128, 128)
#undef SHAPE_CASE
    case cutlass_extensions::CutlassTileConfigSM90::Undefined:
        TLLM_THROW("GEMM config undefined.");
        break;
    case cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default:
        TLLM_THROW("Unsupported config for MoE gemm.");
        break;
    }
}
// Returns the workspace size (in bytes) required by the SM90 grouped GEMM for
// the given expert count and configuration. Dispatches with an empty
// HopperGroupedGemmInput and a non-null workspace_size pointer so the launcher
// only computes the size instead of launching.
template <typename T, typename WeightType, typename OutputType, EpilogueFusion FUSION>
size_t calcMaxWorkspaceSizeSM90(
    int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count)
{
    // Fix: initialise to 0 so we never return an indeterminate value if the
    // dispatch path does not write through the workspace pointer.
    size_t count = 0;
    // Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat
    dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, cutlass_extensions::EpilogueOpDefault, FUSION>(
        HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
    return count;
}
}
// namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h
deleted
100644 → 0
View file @
9829e77e
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/mma_sm90.h"
#include "cutlass_extensions/epilogue_helpers.h"
namespace
tensorrt_llm
::
kernels
::
cutlass_kernels
{
// Hopper arch
// True when the Hopper grouped GEMM specialisation can be used: the weight
// type must be identical to the activation type and the epilogue must be the
// default one. Always false when Hopper kernel support is compiled out.
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidHopperMOESpecialisation()
{
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
    constexpr bool weights_match_activations = cutlass::platform::is_same<T, WeightType>::value;
    constexpr bool epilogue_is_default
        = cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
    return weights_match_activations && epilogue_is_default;
#else
    // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled
    return false;
#endif
}
// Ampere arch (comment fixed: this is the SM80 check, not Hopper)
// Every T/WeightType/EpilogueTag combination currently has an Ampere (SM80)
// implementation, so this is unconditionally true.
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidAmpereMOESpecialisation()
{
    return true; // Default to true
}
}
// namespace tensorrt_llm::kernels::cutlass_kernels
sgl-kernel/setup.py
View file @
3ee62235
...
...
@@ -39,8 +39,6 @@ cutlass_default = root / "3rdparty" / "cutlass"
cutlass
=
Path
(
os
.
environ
.
get
(
"CUSTOM_CUTLASS_SRC_DIR"
,
default
=
cutlass_default
))
flashinfer
=
root
/
"3rdparty"
/
"flashinfer"
turbomind
=
root
/
"3rdparty"
/
"turbomind"
tensorrt_llm_parent
=
root
/
"3rdparty"
tensorrt_llm
=
root
/
"3rdparty"
/
"tensorrt_llm"
include_dirs
=
[
cutlass
.
resolve
()
/
"include"
,
cutlass
.
resolve
()
/
"tools"
/
"util"
/
"include"
,
...
...
@@ -53,8 +51,6 @@ include_dirs = [
"cublasLt"
,
turbomind
.
resolve
(),
turbomind
.
resolve
()
/
"src"
,
tensorrt_llm_parent
.
resolve
(),
tensorrt_llm
.
resolve
()
/
"cutlass_extensions"
/
"include"
,
]
nvcc_flags
=
[
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment