sglang · Commits

Commit e81d7f11 (unverified)
Authored by Yineng Zhang on Jan 30, 2025; committed by GitHub on Jan 30, 2025

add tensorrt_llm moe_gemm as 3rdparty (#3217)
Parent: 222ce6f1

Showing 20 changed files with 2165 additions and 325 deletions (+2165 −325)
sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h  +21 −0
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp  +0 −187
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h  +0 −138
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h  +25 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl  +96 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h  +37 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl  +348 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu  +131 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h  +230 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu  +24 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu  +24 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu  +24 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu  +22 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu  +22 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu  +22 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu  +22 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu  +28 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h  +823 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h  +222 −0
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h  +44 −0
sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h  (new file, 0 → 100644)
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
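The header above lets every translation unit include bf16 support unconditionally: the real cuda_bf16.h is only pulled in when ENABLE_BF16 is defined at build time. A minimal sketch of how dependent code can guard bf16 use behind the same flag; the helper name below is hypothetical and not part of this commit:

#include "cudaBf16Wrapper.h"

// Hypothetical helper: converts a stored bf16 value when the build enables it,
// and degrades to a stub otherwise, so callers compile either way.
inline float bf16_to_float_or_zero(void const* p)
{
#ifdef ENABLE_BF16
    return __bfloat162float(*static_cast<__nv_bfloat16 const*>(p));
#else
    (void) p;
    return 0.0f; // bf16 disabled at build time
#endif
}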
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp  (deleted, 100644 → 0)
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define CUDA_LIB_NAME "cuda"
#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary("nv" name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
#else // For non-Windows platforms
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)
#include "cudaDriverWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
namespace tensorrt_llm::common
{

std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
{
    static std::mutex mutex;
    static std::weak_ptr<CUDADriverWrapper> instance;
    std::shared_ptr<CUDADriverWrapper> result = instance.lock();
    if (result)
    {
        return result;
    }

    std::lock_guard<std::mutex> lock(mutex);
    result = instance.lock();
    if (!result)
    {
        result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
        instance = result;
    }
    return result;
}

CUDADriverWrapper::CUDADriverWrapper()
    : handle(dllOpen(CUDA_LIB_NAME))
{
    TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");

    auto load_sym = [](void* handle, char const* name)
    {
        void* ret = dllGetSym(handle, name);
        return ret;
    };

    *reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
    *reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
    *reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
    *reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
    *reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
    *reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
    *reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
    *reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
    *reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
    *reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
    *reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
    *reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
    *reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
    *reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
    *reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
    *reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
}

CUDADriverWrapper::~CUDADriverWrapper()
{
    dllClose(handle);
}

CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
{
    return (*_cuGetErrorName)(error, pStr);
}

CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
{
    return (*_cuGetErrorMessage)(error, pStr);
}

CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
{
    return (*_cuFuncSetAttribute)(hfunc, attrib, value);
}

CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
{
    return (*_cuLinkComplete)(state, cubinOut, sizeOut);
}

CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
{
    return (*_cuModuleUnload)(hmod);
}

CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
{
    return (*_cuLinkDestroy)(state);
}

CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
{
    return (*_cuModuleLoadData)(module, image);
}

CUresult CUDADriverWrapper::cuLinkCreate(
    unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
{
    return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
}

CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
{
    return (*_cuModuleGetFunction)(hfunc, hmod, name);
}

CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
{
    return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
}

CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
    unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
    return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
}

CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
    char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
    return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
}

CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
    unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
    unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
{
    return (*_cuLaunchCooperativeKernel)(
        f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
}

CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
    unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
    unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
{
    return (*_cuLaunchKernel)(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes,
        hStream, kernelParams, extra);
}

CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
    cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
    CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
{
    return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
        boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
}

CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
{
    return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
}

} // namespace tensorrt_llm::common
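getInstance() above combines a weak_ptr cache with double-checked locking: the first lock() attempt skips the mutex when an instance is already alive, and keeping only a weak_ptr in static storage lets the wrapper (and the dlopen handle it owns) be destroyed once the last user releases its shared_ptr. A stripped-down sketch of the same idiom with the CUDA specifics removed; Resource is a placeholder name:

#include <memory>
#include <mutex>

class Resource
{
public:
    static std::shared_ptr<Resource> getInstance()
    {
        static std::mutex mutex;
        static std::weak_ptr<Resource> instance;
        // Fast path: reuse a live instance without taking the lock
        // (mirrors the wrapper above).
        std::shared_ptr<Resource> result = instance.lock();
        if (result)
        {
            return result;
        }
        // Slow path: re-check under the lock, then construct once.
        std::lock_guard<std::mutex> lock(mutex);
        result = instance.lock();
        if (!result)
        {
            result = std::shared_ptr<Resource>(new Resource());
            instance = result;
        }
        return result;
    }

private:
    Resource() = default; // acquire the underlying handle here
};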
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h  (deleted, 100644 → 0)
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
#include <memory>
#include <mutex>
namespace tensorrt_llm::common
{

class CUDADriverWrapper
{
public:
    static std::shared_ptr<CUDADriverWrapper> getInstance();

    ~CUDADriverWrapper();
    CUDADriverWrapper(CUDADriverWrapper const&) = delete;
    CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
    CUDADriverWrapper(CUDADriverWrapper&&) = delete;
    CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;

    CUresult cuGetErrorName(CUresult error, char const** pStr) const;

    CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;

    CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;

    CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;

    CUresult cuModuleUnload(CUmodule hmod) const;

    CUresult cuLinkDestroy(CUlinkState state) const;

    CUresult cuModuleLoadData(CUmodule* module, void const* image) const;

    CUresult cuLinkCreate(
        unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;

    CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;

    CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;

    CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
        CUjit_option* options, void** optionValues) const;

    CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
        unsigned int numOptions, CUjit_option* options, void** optionValues) const;

    CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
        unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
        unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;

    CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
        unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
        CUstream hStream, void** kernelParams, void** extra) const;

    CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
        void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
        cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
        CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;

    CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;

private:
    void* handle;
    CUDADriverWrapper();

    CUresult (*_cuGetErrorName)(CUresult, char const**);
    CUresult (*_cuGetErrorMessage)(CUresult, char const**);
    CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
    CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
    CUresult (*_cuModuleUnload)(CUmodule);
    CUresult (*_cuLinkDestroy)(CUlinkState);
    CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
    CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
    CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
    CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
    CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
    CUresult (*_cuLinkAddData)(
        CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
    CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
        unsigned int, unsigned int, unsigned int, CUstream, void**);
    CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
        unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
        CUstream hStream, void** kernelParams, void** extra);
    CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
        cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
        cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
        CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
    CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
};

template <typename T>
void checkDriver(
    T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
{
    if (result)
    {
        char const* errorName = nullptr;
        char const* errorMsg = nullptr;
        wrap.cuGetErrorName(result, &errorName);
        wrap.cuGetErrorMessage(result, &errorMsg);
        throw TllmException(
            file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
    }
}

} // namespace tensorrt_llm::common
/*
* Macros compliant with TensorRT coding conventions
*/
#define TLLM_CU_CHECK(stat) \
do \
{ \
tensorrt_llm::common::checkDriver( \
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
} while (0)
#endif // CUDA_DRIVER_WRAPPER_H
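checkDriver() converts any non-zero CUresult into a TllmException that carries the driver's own error name and message, and TLLM_CU_CHECK stringifies the failing expression and records file/line. A hedged usage sketch of the macro defined above; the cubin image and kernel name are placeholders:

// Illustrative only: load a cubin image and resolve a kernel through the wrapper.
void loadKernel(void const* cubin_image)
{
    auto driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
    CUmodule module;
    CUfunction kernel;
    TLLM_CU_CHECK(driver->cuModuleLoadData(&module, cubin_image));
    TLLM_CU_CHECK(driver->cuModuleGetFunction(&kernel, module, "my_kernel")); // placeholder name
    // ... launch via driver->cuLaunchKernel(...), then clean up:
    TLLM_CU_CHECK(driver->cuModuleUnload(module));
}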
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h  (new file, 0 → 100644)
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
    typename EpilogueTag>
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
    ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
    int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
    int* kernel_occupancy);
}
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl  (new file, 0 → 100644)
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include <cutlass_extensions/epilogue_helpers.h>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh>
#include <tensorrt_llm/common/cudaUtils.h>
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
typename EpilogueTag>
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
int* kernel_occupancy)
{
constexpr auto activation_type = fused_moe::EpilogueRouting<EpilogueTag>(true);
using GemmType = fused_moe::Fused_Moe_Kernel_sm80<ElementType_, CutlassWeightType_, ElementType_, MaxTileM_, TileN_,
TileK_, Stages_, activation_type>;
// Make sure the GPU has enough resources.
if (kernel_occupancy != nullptr)
{
constexpr int smem_size = GemmType::kSmemSize;
if (smem_size > (48 << 10))
{
cudaFuncAttributes attr{};
int device = 0;
int max_smem_per_block = 0;
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global<GemmType>));
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize,
// smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the
// heuristic to ignore this configuration.
*kernel_occupancy = 0;
return;
}
}
int max_active_blocks = -1;
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, fused_moe::run_global<GemmType>, GemmType::kThreadCount, smem_size));
*kernel_occupancy = max_active_blocks;
return;
}
int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks<GemmType>());
int const threadblock_count = multi_processor_count * occupancy;
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
using Arguments = typename GemmType::Arguments;
Arguments args{{const_cast<ElementType_*>(A), const_cast<CutlassWeightType_*>(B), const_cast<ElementType_*>(biases),
reinterpret_cast<ElementType_*>(C), total_tokens_including_expert, static_cast<int>(gemm_n),
static_cast<int>(gemm_k), num_experts, bias_is_broadcast},
num_experts, threadblock_count};
auto params = GemmType::to_underlying_arguments(args);
if (GemmType::kSmemSize >= (48 << 10))
{
cudaError_t result = cudaFuncSetAttribute(
fused_moe::run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize);
TLLM_CHECK_WITH_INFO(result == cudaSuccess,
"Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel");
}
dim3 grid(params.threadblock_count, 1, 1);
dim3 block(GemmType::kThreadCount);
fused_moe::run_global<GemmType><<<grid, block, GemmType::kSmemSize, stream>>>(params);
auto result = cudaGetLastError();
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result));
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
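The occupancy path above follows a standard CUDA recipe: kernels needing more than the default 48 KiB of dynamic shared memory must opt in with cudaFuncSetAttribute, and the launcher probes cudaOccupancyMaxActiveBlocksPerMultiprocessor before committing to a configuration, so unlaunchable tile shapes report an occupancy of 0 and get skipped by the heuristic. A self-contained sketch of that recipe outside the CUTLASS machinery; the kernel and sizes are illustrative:

#include <cuda_runtime.h>

__global__ void demoKernel(float* out)
{
    extern __shared__ float smem[];
    smem[threadIdx.x] = threadIdx.x;
    __syncthreads();
    out[threadIdx.x] = smem[0];
}

// Returns resident blocks per SM for the requested smem size, or 0 if unlaunchable.
int probeOccupancy(int smem_size, int block_threads)
{
    if (smem_size > (48 << 10))
    {
        // Opt in to large dynamic shared memory; fails if the device cannot grant it.
        if (cudaFuncSetAttribute(demoKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) != cudaSuccess)
            return 0;
    }
    int max_active_blocks = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, demoKernel, block_threads, smem_size);
    return max_active_blocks;
}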
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h  (new file, 0 → 100644)
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{

// Keep in sync with the signature generated by generate_kernels.py
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag,
    HopperGroupedGemmInput::EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
    int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);

} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl  (new file, 0 → 100644)
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
// Hopper helper class for defining all the cutlass helper types
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, typename TileShape,
typename ClusterShape, bool BIAS, EpilogueFusion FUSION>
struct HopperGroupedGemmInfo
{
using Arch = cutlass::arch::Sm90;
// TODO Update once mixed input support is added
static_assert(cutlass::platform::is_same<T, WeightType>::value,
"CUTLASS does not currently have specialised SM90 support for quantized operations");
#ifdef ENABLE_FP8
constexpr static bool IsFP8
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
#else
constexpr static bool IsFP8 = false;
#endif
#ifdef ENABLE_BF16
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|| cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for bfloat16, half, float, fp8");
#else
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for half, float, fp8");
#endif
static_assert(cutlass::platform::is_same<T, WeightType>::value
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e4m3_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e5m2_t>::value,
"Unexpected quantization type");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter<WeightType>::type;
// For legacy reasons we convert unsigned 8-bit to signed
using CutlassWeightTypeMaybeUint8
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint4, cutlass::uint4b_t>, cutlass::int4b_t,
CutlassWeightTypeMaybeUint4>;
using CutlassWeightType
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint8, uint8_t>, int8_t, CutlassWeightTypeMaybeUint8>;
using ElementA = ElementType;
using ElementB = CutlassWeightType;
using ElementD = typename TllmToCutlassTypeAdapter<HopperGroupedGemmInput::OutputTypeAdaptor_t<OutputType>>::type;
using ElementFinalOutput = typename TllmToCutlassTypeAdapter<OutputType>::type;
// using ElementC = std::conditional_t<BIAS, ElementType, void>;
// using ElementCNoVoid = std::conditional_t<BIAS, ElementType, ElementD>;
using ElementC = void;
using ElementCNoVoid = ElementD;
using ElementAccumulator = float;
using ElementBias = ElementFinalOutput;
using ElementRouterScales = float;
// A matrix configuration - this is transposed and swapped with B
using LayoutA = HopperGroupedGemmInput::LayoutA;
constexpr static int AlignmentA
= 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units
// of elements (up to 16 bytes)
// B matrix configuration - this is transposed and swapped with A
using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand
constexpr static int AlignmentB
= 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units
// of elements (up to 16 bytes)
// C matrix configuration
using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand
using StrideC = HopperGroupedGemmInput::StrideC;
// Note we use ElementType here deliberately, so we don't break when BIAS is disabled
constexpr static int AlignmentC
= 128 / cutlass::sizeof_bits<ElementType>::value; // Memory access granularity/alignment of C matrix in units
// of elements (up to 16 bytes)
// D matrix configuration
using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD;
using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD;
constexpr static int AlignmentD
= 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix
// in units of elements (up to 16 bytes)
static_assert(cutlass::platform::is_same<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpDefault>::value,
"Hopper Grouped GEMM specialisation doesn't support fused activation");
using EpilogueOp
= cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
// TODO Add mode for fused activation once CUTLASS adds support
// using EpilogueSchedule = cutlass::platform::conditional_t<
// cutlass::platform::is_same<EpilogueOp, EpilogueOpDefault>::value,
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized,
// cutlass::epilogue::?????????????????? /// <<<<<< what supports activations
// >;
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;
// Epilogue For Default Finalize
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
TileShape, ClusterShape, //
cutlass::epilogue::collective::EpilogueTileAuto, //
ElementAccumulator, ElementAccumulator, //
ElementC, LayoutC*, AlignmentC, //
ElementD, LayoutD*, AlignmentD, //
EpilogueSchedule>::CollectiveOp;
// Epilogue For Fused Finalize
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< //
TileShape, //
ElementCNoVoid, StrideC*, //
ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, //
ElementAccumulator, //
ElementAccumulator, //
ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, //
ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales //
>::CollectiveOp;
using CollectiveEpilogue
= std::conditional_t<FUSION == EpilogueFusion::FINALIZE, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>;
using KernelSchedule
= std::conditional_t<IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here
ElementType, LayoutA*, AlignmentA, //
ElementAccumulator, //
TileShape, ClusterShape, //
StageCountAutoCarveout, KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<HopperGroupedGemmInput::ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
// Hopper specialised version
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size)
{
#ifdef COMPILE_HOPPER_TMA_GEMMS
using namespace cute;
if constexpr (!should_filter_sm90_gemm_problem_shape_v<TileShape, ClusterShape, T>)
{
using GemmInfo
= HopperGroupedGemmInfo<T, WeightType, OutputType, EpilogueTag, TileShape, ClusterShape, BIAS, FUSION>;
using ElementAccumulator = typename GemmInfo::ElementAccumulator;
using ElementA = typename GemmInfo::ElementA;
using ElementB = typename GemmInfo::ElementB;
using ElementC = typename GemmInfo::ElementC;
using ElementCNoVoid = typename GemmInfo::ElementCNoVoid;
using ElementD = typename GemmInfo::ElementD;
using CollectiveMainloop = typename GemmInfo::CollectiveMainloop;
using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue;
using GemmKernel = typename GemmInfo::GemmKernel;
using GemmGrouped = typename GemmInfo::GemmGrouped;
if (kernel_occupancy != nullptr)
{
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
return;
}
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = multi_processor_count;
GemmGrouped gemm;
if (workspace_size != nullptr)
{
// Make a mock problem shape with just the minimal information actually required to get the workspace size
// This makes some assumptions about CUTLASS's implementation which is suboptimal. We have a check later to
// catch future cutlass updates causing silent breakages, but that is not fool proof.
// The alternative is to wait until we have data and then dynamically allocate the workspace
typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr};
typename GemmGrouped::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info};
*workspace_size = gemm.get_workspace_size(args);
return;
}
using MainloopArguments = typename CollectiveMainloop::Arguments;
TLLM_CHECK(hopper_input.stride_a);
TLLM_CHECK(hopper_input.stride_b);
TLLM_CHECK(hopper_input.ptr_a);
TLLM_CHECK(hopper_input.ptr_b);
MainloopArguments const mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};
typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)};
epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
// TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
auto make_epi_args = [&]()
{
if constexpr (FUSION == EpilogueFusion::NONE)
{
auto epi_params = hopper_input.default_epilogue;
return EpilogueArguments{epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c),
hopper_input.stride_c, reinterpret_cast<ElementD**>(epi_params.ptr_d), epi_params.stride_d};
}
else if constexpr (FUSION == EpilogueFusion::FINALIZE)
{
// Parameters for fused finalize
auto epi_params = hopper_input.fused_finalize_epilogue;
return EpilogueArguments{
epilogue_scalars, // Parameters to underlying epilogue
reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c, // C params
reinterpret_cast<typename GemmInfo::ElementFinalOutput*>(epi_params.ptr_final_output),
epi_params.stride_final_output, // D (output) params
reinterpret_cast<typename GemmInfo::ElementBias const*>(epi_params.ptr_bias),
epi_params.stride_bias, // Bias params
epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales
epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales
epi_params.ptr_source_token_index, // Index of the source token to sum into
epi_params.num_rows_in_final_output // Number of tokens in the output buffer
};
}
else
{
static_assert(
sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher");
}
};
EpilogueArguments const epilogue_params = make_epi_args();
typename GemmKernel::TileScheduler::Arguments scheduler_args{
1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN};
typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
mainloop_params, epilogue_params, hw_info, scheduler_args};
size_t calculated_ws_size = gemm.get_workspace_size(args);
TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,
"Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size);
auto can_implement = gemm.can_implement(args);
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
"Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
auto init_status = gemm.initialize(args, hopper_input.gemm_workspace);
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
"Failed to initialize cutlass SM90 grouped gemm. Error: "
+ std::string(cutlassGetStatusString(init_status)));
auto run_status = gemm.run(stream);
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
"Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
sync_check_cuda_error();
}
else
{
TLLM_THROW("Configuration was disabled by FAST_BUILD");
}
#else // COMPILE_HOPPER_TMA_GEMMS
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
#endif // COMPILE_HOPPER_TMA_GEMMS
}
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
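Note the two-phase workspace protocol above: at setup time the launcher is invoked with workspace_size non-null and a mock problem shape to obtain an upper bound, and at run time it re-queries get_workspace_size() for the real shape and checks it against what was allocated. A generic sketch of the run-time half of that contract, restricted to the calls the launcher itself makes; Gemm stands in for GemmGrouped:

// Illustrative: the query-check-initialize-run sequence used above.
template <typename Gemm>
void runWithWorkspace(
    typename Gemm::Arguments const& args, void* workspace, size_t workspace_size, cudaStream_t stream)
{
    Gemm gemm;
    size_t required = gemm.get_workspace_size(args);
    TLLM_CHECK_WITH_INFO(required <= workspace_size, "Workspace too small");
    TLLM_CHECK_WITH_INFO(gemm.can_implement(args) == cutlass::Status::kSuccess, "Unsupported configuration");
    TLLM_CHECK_WITH_INFO(gemm.initialize(args, workspace) == cutlass::Status::kSuccess, "Initialization failed");
    TLLM_CHECK_WITH_INFO(gemm.run(stream) == cutlass::Status::kSuccess, "Launch failed");
}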
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu  (new file, 0 → 100644)
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/conv/convolution.h"
// Order matters here, packed_stride.hpp is missing cute and convolution includes
#include "cutlass/util/packed_stride.hpp"
#include "tensorrt_llm/common/logger.h"
namespace tensorrt_llm
{

std::array<size_t, 10> HopperGroupedGemmInput::workspaceBuffers(int num_experts)
{
    size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts;
    size_t stride_a_size = sizeof(StrideA) * num_experts;
    size_t stride_b_size = sizeof(StrideB) * num_experts;
    size_t stride_c_size = sizeof(StrideC) * num_experts;
    size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts;

    size_t ptr_buf_size = sizeof(void*) * num_experts;
    size_t scale_buf_size = sizeof(float*) * num_experts;

    return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size,
        ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size};
}

size_t HopperGroupedGemmInput::workspaceSize(int num_experts)
{
    auto buffers = workspaceBuffers(num_experts);
    return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size());
}

void HopperGroupedGemmInput::configureWorkspace(
    int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size)
{
    auto buffers = workspaceBuffers(num_experts);
    std::array<int8_t*, 10> pointers{};
    TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers");
    for (int i = 0; i < buffers.size(); i++)
    {
        pointers[i] = start_ptr;
        start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]);
    }

    shape_info.num_groups = num_experts;
    shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(pointers[0]);
    shape_info.host_problem_shapes = nullptr;
    stride_a = reinterpret_cast<StrideA*>(pointers[1]);
    stride_b = reinterpret_cast<StrideB*>(pointers[2]);
    stride_c = reinterpret_cast<StrideC*>(pointers[3]);
    default_epilogue.stride_d = reinterpret_cast<DefaultEpilogue::StrideD*>(pointers[4]);

    ptr_a = reinterpret_cast<void const**>(pointers[5]);
    ptr_b = reinterpret_cast<void const**>(pointers[6]);
    ptr_c = reinterpret_cast<void const**>(pointers[7]);
    default_epilogue.ptr_d = reinterpret_cast<void**>(pointers[8]);

    alpha_scale_ptr_array = reinterpret_cast<float const**>(pointers[9]);

    this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace);
    this->gemm_workspace_size = gemm_workspace_size;
}

void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales,
    int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
    int num_output_tokens)
{
    fused_finalize_epilogue.ptr_final_output = final_output;
    fused_finalize_epilogue.ptr_router_scales = router_scales;
    fused_finalize_epilogue.ptr_bias = bias;
    fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset;
    fused_finalize_epilogue.ptr_source_token_index = source_token_index;

    fused_finalize_epilogue.stride_final_output = cutlass::make_cute_packed_stride(
        FusedFinalizeEpilogue::StrideFinalOutput{}, transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1)));
    fused_finalize_epilogue.stride_bias
        = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size));
    fused_finalize_epilogue.stride_router_scales = {};

    fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens;
}

std::string HopperGroupedGemmInput::toString() const
{
    std::stringstream ss;
    ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n";
    if (isValid())
    {
        ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n";
        ss << "Epilogue Fusion: " << (int) fusion;
        if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE)
        {
            ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output;
            ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
            ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias;
            ss << " with Stride: " << fused_finalize_epilogue.stride_bias;
            ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales;
            ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
            ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset;
            ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index;
        }
        else
        {
            ss << ", Ptr D: " << default_epilogue.ptr_d;
        }
        ss << '\n';
        ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n";
    }
    return ss.str();
}

} // namespace tensorrt_llm
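configureWorkspace() above carves one contiguous allocation into ten sub-buffers by walking a size list with nextWorkspacePtr. A minimal sketch of that suballocation idiom; the 16-byte alignment in next_ptr is an assumption standing in for the real tensorrt_llm::common helper:

#include <array>
#include <cstddef>
#include <cstdint>

// Stand-in for nextWorkspacePtr: advance past `size` bytes, then round up to 16 bytes.
inline int8_t* next_ptr(int8_t* ptr, size_t size)
{
    auto addr = reinterpret_cast<uintptr_t>(ptr) + size;
    addr = (addr + 15u) & ~uintptr_t{15u}; // assumed alignment, for illustration
    return reinterpret_cast<int8_t*>(addr);
}

// Partition one allocation into N sub-buffers given their sizes.
template <size_t N>
std::array<int8_t*, N> partitionWorkspace(int8_t* start, std::array<size_t, N> const& sizes)
{
    std::array<int8_t*, N> pointers{};
    for (size_t i = 0; i < N; i++)
    {
        pointers[i] = start;               // buffer i begins here
        start = next_ptr(start, sizes[i]); // the next one starts after its aligned end
    }
    return pointers;
}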
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h  (new file, 0 → 100644)
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h"
#include <array>
#include <cuda_runtime_api.h>
#include <optional>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/layout/layout.h"
namespace tensorrt_llm
{

template <class T>
constexpr auto transpose_stride(T const& t)
{
    return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
}

struct HopperGroupedGemmInput
{
    template <class T>
    using TransposeStride = decltype(transpose_stride<T>(T{}));
    template <class Tag>
    using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
        cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;

    static_assert(std::is_same_v<cutlass::layout::RowMajor, TransposeLayoutTag<cutlass::layout::ColumnMajor>>);
    static_assert(std::is_same_v<cutlass::layout::ColumnMajor, TransposeLayoutTag<cutlass::layout::RowMajor>>);

    // Layout for A and B is transposed and then swapped in the implementation
    // This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM
    using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>;    // Layout type for A matrix operand
    using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
    using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>;    // Layout type for C matrix operand

    using StrideA = std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutA*>>; // Use B because they will be swapped
    using StrideB = std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
    using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;

    template <class T>
    constexpr static bool IsFP8_v = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;

    // Currently this should always just be T
    template <class T>
    using OutputTypeAdaptor_t = std::conditional_t<IsFP8_v<T>, nv_bfloat16, T>;

    using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int64_t, int64_t, int64_t>>;

    ProblemShape shape_info{};
    StrideA* stride_a = nullptr;
    StrideB* stride_b = nullptr;
    void const** ptr_a = nullptr;
    void const** ptr_b = nullptr;

    // C is currently the same in both epilogues
    StrideC* stride_c = nullptr;
    void const** ptr_c = nullptr;

    struct DefaultEpilogue
    {
        using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
        using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;

        StrideD* stride_d = nullptr;
        void** ptr_d = nullptr;
    };

    struct FusedFinalizeEpilogue
    {
        using StrideFinalOutput = DefaultEpilogue::StrideD;
        using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
        using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;

        void* ptr_final_output = nullptr;
        StrideFinalOutput stride_final_output{};

        void const* ptr_bias = nullptr;
        StrideBias stride_bias{};

        float const* ptr_router_scales = nullptr;
        StrideRouterScales stride_router_scales{};

        int64_t const* ptr_expert_first_token_offset = nullptr;
        int const* ptr_source_token_index = nullptr;

        size_t num_rows_in_final_output = 0;
    };

    DefaultEpilogue default_epilogue;
    FusedFinalizeEpilogue fused_finalize_epilogue;

    enum class EpilogueFusion
    {
        NONE,
        ACTIVATION,
        GATED_ACTIVATION,
        FINALIZE
    };
    EpilogueFusion fusion = EpilogueFusion::NONE;

    float const** alpha_scale_ptr_array = nullptr;

    uint8_t* gemm_workspace = nullptr;
    size_t gemm_workspace_size = 0;

    static std::array<size_t, 10> workspaceBuffers(int num_experts);

    static size_t workspaceSize(int num_experts);

    void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size);

    bool isValid() const
    {
        return stride_a != nullptr && ptr_a != nullptr;
    }

    void setFinalizeFusionParams(void* final_output, float const* router_scales,
        int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
        int num_output_tokens);

    std::string toString() const;
};

// Note update moe.py to match
enum class ActivationType
{
    Gelu = 0,
    Relu,
    Silu,
    Swiglu,
    Geglu,
    Identity,
    InvalidType
};

constexpr bool isGatedActivation(ActivationType activation_type)
{
    return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu;
}

template <typename T,                       /*The type used for activations/scales/compute*/
    typename WeightType,                    /* The type for the MoE weights */
    typename OutputType,                    /* The output type for the GEMM */
    typename ScaleBiasType = OutputType     /* The type for the scales/bias */
    >
class MoeGemmRunner
{
public:
    MoeGemmRunner();

#if defined(ENABLE_FP8)
    static constexpr bool use_fp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
#else
    static constexpr bool use_fp8 = false;
#endif

    void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
        ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
        HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
        ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
        cutlass_extensions::CutlassGemmConfig chosen_conf);

    void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C,
        int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows,
        int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
        cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf);

    std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
    static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm);
    static std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs(int sm);
    static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);

    [[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const;
    [[nodiscard]] bool supportsHopperSpecialisation() const;
    [[nodiscard]] bool isFusedGatedActivation(
        cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const;
    [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const;

    size_t getMaxWorkspaceSize(int num_experts) const;

    [[nodiscard]] int getSM() const;

private:
    template <typename EpilogueTag>
    void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
        ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
        HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
        cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array,
        cudaStream_t stream, int* occupancy = nullptr);

    template <typename EpilogueTag>
    void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases,
        bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
        HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
        bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
        cutlass_extensions::CutlassGemmConfig chosen_conf);

private:
    int sm_{};
    int multi_processor_count_{};
    mutable int num_experts_ = 0;
    mutable size_t gemm_workspace_size_ = 0;
    size_t calcMaxWorkspaceSize(int num_experts) const;
};

} // namespace tensorrt_llm
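The layout comment above ("This uses B^T * A^T = (A * B)^T") is the whole trick: by swapping the A and B operands and flipping every layout tag, the kernel produces the transposed product while consuming the original row-major buffers unchanged. A tiny host-side check of the underlying identity on fixed 2x3 and 3x2 matrices (sizes are arbitrary, purely illustrative):

#include <cassert>

int main()
{
    double A[2][3] = {{1, 2, 3}, {4, 5, 6}};
    double B[3][2] = {{7, 8}, {9, 10}, {11, 12}};
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
        {
            double ab = 0;   // (A*B)[i][j]
            double btat = 0; // (B^T * A^T)[j][i]: sum over k of B^T[j][k] * A^T[k][i]
            for (int k = 0; k < 3; k++)
            {
                ab += A[i][k] * B[k][j];
                btat += B[k][j] * A[i][k];
            }
            // Identical scalar sums: (A*B)^T and B^T*A^T agree entry for entry.
            assert(ab == btat);
        }
    return 0;
}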
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu  (new file, 0 → 100644)
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<half, half, half>;
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<half, cutlass::uint4b_t, half>;
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<half, uint8_t, half>;
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<float, float, float>;
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_FP8
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>;
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
#endif
// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>;
#endif
} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#ifdef __GNUC__ // Restore GCC-specific diagnostics
#pragma GCC diagnostic pop
#endif
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "moe_gemm_kernels_template_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include <tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
namespace kernels::cutlass_kernels
{
// ============================= Variable batched Gemm things ===========================
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename ThreadblockShape, typename WarpShape, int Stages>
void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr)
{
#if defined(ENABLE_FP8)
    static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
            || cutlass::platform::is_same<T, __nv_fp8_e4m3>::value
            || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value || cutlass::platform::is_same<T, float>::value,
        "Specialized for fp8, bfloat16, half, float");
#elif defined(ENABLE_BF16)
    static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
            || cutlass::platform::is_same<T, float>::value,
        "Specialized for bfloat16, half, float");
#else
    static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
        "Specialized for half, float");
#endif
    static_assert(cutlass::platform::is_same<T, WeightType>::value
            || cutlass::platform::is_same<WeightType, uint8_t>::value
            || cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
        "");
    static_assert(!cutlass::platform::is_same<arch, cutlass::arch::Sm90>::value,
        "Sm90 architecture should use specialised kernels");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
    using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
    using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter<GemmOutputType>::type;
    using CutlassWeightType = typename TllmToCutlassTypeAdapter<WeightType>::type;
    if (!use_fused_moe)
    {
        // We need separate config for each architecture since we will target different tensorcore instructions. For
        // float, we do not target TCs.
        using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
        using ElementAccumulator = typename MixedGemmArchTraits::AccType;

        using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue<CutlassGemmOutputType,
            MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;

        typename EpilogueOp::Params epilogue_op(
            ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f));
#if defined(ENABLE_FP8)
        if constexpr ((std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>)
            && std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefault>)
        {
            TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array,
                "weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 "
                "Ada");
            epilogue_op.alpha_ptr_array = alpha_scale_ptr_array;
        }
#endif
        // Finally, set up the kernel.
        using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<ElementType, cutlass::layout::RowMajor,
            cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType,
            typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone,
            MixedGemmArchTraits::ElementsPerAccessB, CutlassGemmOutputType, cutlass::layout::RowMajor,
            ElementAccumulator, typename MixedGemmArchTraits::OperatorClass, arch, ThreadblockShape, WarpShape,
            typename MixedGemmArchTraits::InstructionShape, EpilogueOp,
            cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages,
            cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel;

        using GemmKernel = cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma, typename GemmKernel_::Epilogue,
            typename GemmKernel_::ThreadblockSwizzle,
            arch, // Ensure top level arch is used for dispatch
            GemmKernel_::kGroupScheduleMode>;

        using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
        if (kernel_occupancy != nullptr)
        {
            *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel>();
            return;
        }
        int occupancy = std::min(2, GemmGrouped::maximum_active_blocks());
        TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel");
        int const threadblock_count = multi_processor_count * occupancy;

        int const group_size = gemm_k;
        typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op,
            reinterpret_cast<ElementType const*>(A), reinterpret_cast<CutlassWeightType const*>(B),
            reinterpret_cast<CutlassGemmOutputType const*>(weight_scales),
            reinterpret_cast<CutlassGemmOutputType const*>(biases), bias_is_broadcast,
            reinterpret_cast<CutlassGemmOutputType*>(C), total_tokens_including_expert, gemm_n, gemm_k);
        GemmGrouped gemm;

        auto can_implement = gemm.can_implement(args);
        TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
            "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));

        auto init_status = gemm.initialize(args);
        TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
            "Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status)));

        auto run_status = gemm.run(stream);
        TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
            "Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
    }
    else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2
        && (std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultSilu>
            || std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultFtGelu>))
    // use fused moe gemm kernel (only supports fp16 or bf16)
    {
        sm80_generic_fused_moe_gemm_kernelLauncher<ElementType, CutlassWeightType, ThreadblockShape::kM,
            ThreadblockShape::kN, ThreadblockShape::kK, Stages, EpilogueTag>(
            reinterpret_cast<ElementType const*>(A), reinterpret_cast<CutlassWeightType const*>(B),
            reinterpret_cast<ElementType const*>(biases), bias_is_broadcast, reinterpret_cast<ElementType*>(C),
            total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, multi_processor_count, stream,
            kernel_occupancy);
    }
}

} // namespace kernels::cutlass_kernels
template <typename T, typename WeightType, typename GemmOutputType, typename Arch, typename EpilogueTag,
    typename ThreadblockShape, typename WarpShape, int Stages>
static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    static_assert(!std::is_same_v<Arch, cutlass::arch::Sm90>, "Use TMA specialised functions for arch SM90");
#if defined(ENABLE_FP8)
    constexpr bool isFp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
#else
    constexpr bool isFp8 = false;
#endif

    if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80)
        && (!isFp8 || std::is_same_v<Arch, cutlass::arch::Sm89>))
    {
        kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, GemmOutputType, Arch, EpilogueTag,
            ThreadblockShape, WarpShape, Stages>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
            use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
    }
    else
    {
        TLLM_THROW(
            "Cutlass gemm. Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages);
    }
}
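// Note on the guard above (a reading aid, not upstream code): `Stages == 2 || Arch::kMinComputeCapability >= 80`
// means deeper pipelines (3 or 4 stages) are only instantiated for SM80 and newer, where the async-copy hardware
// such schedules rely on is available; and `!isFp8 || std::is_same_v<Arch, cutlass::arch::Sm89>` restricts FP8
// instantiations to the SM89 (Ada) path selected by dispatchToArch.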
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename ThreadblockShape, typename WarpShape>
void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    switch (gemm_config.stages)
    {
    case 2:
        dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B,
            weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k,
            num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case 3:
        dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B,
            weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k,
            num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case 4:
        dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B,
            weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k,
            num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break;
    }
}
// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
// This overload is only enabled when T == WeightType.
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename std::enable_if<!std::is_same<T, float>::value
#if defined(ENABLE_FP8)
        && !std::is_same<T, __nv_fp8_e4m3>::value && !std::is_same<T, __nv_fp8_e5m2>::value
#endif
        && std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    switch (gemm_config.tile_config)
    {
    case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
        TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
        if constexpr (arch::kMinComputeCapability >= 75)
        {
            dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
                cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
                total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
                multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        }
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
        TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
        if constexpr (arch::kMinComputeCapability >= 75)
        {
            dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
                cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
                total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
                multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        }
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
            cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
            cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
            cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
    case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break;
    }
}
// Tensorop GEMM overload
// Overload for quantized MoE GEMMs. We disable some warp configs here since they will not be used, which improves
// compile time.
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    switch (gemm_config.tile_config)
    {
    case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
        TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
        if constexpr (arch::kMinComputeCapability >= 75)
        {
            dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
                cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
                total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
                multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        }
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
        TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
        if constexpr (arch::kMinComputeCapability >= 75)
        {
            dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
                cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
                total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
                multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        }
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
            cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
            cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
            cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
    case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break;
    }
}
// This overload will handle tensorop gemms.
// This overload is only enabled when T == WeightType and T is __nv_fp8_e4m3 or __nv_fp8_e5m2.
#if defined(ENABLE_FP8)
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename std::enable_if<(std::is_same<T, __nv_fp8_e4m3>::value || std::is_same<T, __nv_fp8_e5m2>::value)
        && std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    switch (gemm_config.tile_config)
    {
    case cutlass_extensions::CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 128>,
            cutlass::gemm::GemmShape<16, 64, 128>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
            cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
            cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 64, 128>,
            cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 64, 64>,
            cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 256, 64>,
            cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<256, 128, 64>,
            cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
    case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break;
    }
}
#endif
// This overload will handle simt gemms. It is disabled via SFINAE for tensorop.
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    switch (gemm_config.tile_config)
    {
    case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 8>,
            cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
            multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
    case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default: TLLM_THROW("Unsupported config for float MoE gemm."); break;
    }
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs() const
{
    return getConfigs(sm_);
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(int sm)
{
    std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs = getHopperConfigs(sm);
    std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs = getAmpereConfigs(sm);
    std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs));
    return candidate_configs;
}
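// Sketch (illustrative only): the returned list is simply the Hopper candidates followed by the Ampere-compatible
// ones, so a profiler can iterate the list and use isHopperSpecialised(conf) to route each candidate:
//
//     for (auto const& conf : runner.getConfigs())
//     {
//         bool hopper_path = runner.isHopperSpecialised(conf); // conf.is_sm90 && runner supports SM90
//         // ... time the GEMM under `conf` and keep the fastest ...
//     }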
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getAmpereConfigs(int sm)
{
    using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
    static constexpr auto weight_only_flag
        = std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
    static constexpr auto simt_only_flag
        = std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
    static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
    int const max_split_k = 1;
    int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
    int const enable_hopper = CutlassGemmConfig::NONE;

    auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
        weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);

    if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>())
    {
        return {};
    }

    std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs
        = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
    return ampere_configs;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getHopperConfigs(int sm)
{
    using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
    static constexpr auto weight_only_flag
        = std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
    static constexpr auto simt_only_flag
        = std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
    int const max_split_k = 1;
    int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
    int const enable_hopper = CutlassGemmConfig::HOPPER;
    static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;

    auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
        weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);

    if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
    {
        return {};
    }

    std::vector<cutlass_extensions::CutlassGemmConfig> hopper_configs
        = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
    return hopper_configs;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isHopperSpecialised(
    cutlass_extensions::CutlassGemmConfig gemm_config) const
{
    bool config_is_sm90 = gemm_config.is_sm90;
    return supportsHopperSpecialisation() && config_is_sm90;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsHopperSpecialisation() const
{
    return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>();
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
int MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getSM() const
{
    return this->sm_;
}
// Currently supports sm80 bf16/fp16 gated activation; only sets the predication tensor for the m direction.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsFusedGatedActivation(
    bool is_gated_activation, int gemm_n, int gemm_k) const
{
    constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true;
    return is_gated_activation && std::is_same_v<T, WeightType> && !std::is_same_v<T, float> && !use_fp8
        && (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION;
}
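// Worked example (illustrative): for a half/half runner on SM80 with no FP8, a gated activation with
// gemm_n = 128 and gemm_k = 64 satisfies every conjunct above and returns true, while gemm_k = 96 fails the
// `gemm_k % 64 == 0` requirement and returns false.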
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isFusedGatedActivation(
    cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const
{
    return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_sm90;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::MoeGemmRunner()
{
    int device{-1};
    tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
    sm_ = tensorrt_llm::common::getSMVersion();
    tensorrt_llm::common::check_cuda_error(
        cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch<EpilogueTag>(T const* A,
    WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast,
    void* C_void, int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input,
    int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array,
    cudaStream_t stream, int* occupancy)
{
    static_assert(std::is_same_v<ScaleBiasType, OutputType>,
        "Separate Scale/Bias type is not supported. This is assumed to be the gemm output type");
    // For now we always cast this to output type.
    // In the future this will vary based on what fusions are applied for FP8
    auto* C = reinterpret_cast<OutputType*>(C_void);

    TLLM_CHECK_WITH_INFO(
        sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation");
    TLLM_CHECK_WITH_INFO(
        sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture");
    if (sm_ >= 75 && sm_ < 80)
    {
        dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm75, EpilogueTag>(A, B, weight_scales,
            biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts,
            gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
    }
    else if (sm_ >= 80 && sm_ < 90)
    {
        if constexpr (use_fp8)
        {
#if defined(ENABLE_FP8)
            static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
                "FP8 GEMM Output not supported");
#endif
            TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
            dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(A, B,
                weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n,
                gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array,
                stream, occupancy);
        }
        else
        {
            dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
                weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n,
                gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array,
                stream, occupancy);
        }
    }
    else if (sm_ >= 90)
    {
        if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>())
        {
            // We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of
            // tokens SM80 is faster. We check here to see which is selected.
            if (gemm_config.is_sm90)
            {
                TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr,
                    "Input biases and hopper input disagree if bias is enabled");
                TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config");

                // Select the appropriate fusion function
                auto select_function = [&]()
                {
                    switch (hopper_input.fusion)
                    {
                    case HopperGroupedGemmInput::EpilogueFusion::FINALIZE:
                        return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
                            HopperGroupedGemmInput::EpilogueFusion::FINALIZE>;
                    case HopperGroupedGemmInput::EpilogueFusion::NONE:
                        return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
                            HopperGroupedGemmInput::EpilogueFusion::NONE>;
                    case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION:
                    case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION:
                    default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion);
                    };
                };
                auto selected_func = select_function();
                selected_func(
                    hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr);
                return;
            }

            // Fallthrough to SM80 impl below
        }
        // Do the Ampere case instead
        if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType, EpilogueTag>())
        {
            TLLM_CHECK_WITH_INFO(!hopper_input.isValid(),
                "Non-specialised Hopper implementation is being rerouted to fallback implementation so input "
                "information is not required");
            TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90,
                "GEMM config is for an SM90 configuration, but this configuration is not valid for Hopper");
            dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
                weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n,
                gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array,
                stream, occupancy);
        }
        else
        {
            TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels");
        }
    }
    else
    {
        TLLM_THROW("Arch unsupported for MoE GEMM");
    }
}
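// Routing summary for dispatchToArch (a reading aid): sm 75-79 -> Sm75 kernels; sm 80-88 -> Sm80 kernels;
// sm 89 -> Sm89 kernels (the only pre-Hopper arch where use_fp8 is allowed); sm >= 90 -> the TMA-specialised SM90
// path when the chosen config has is_sm90 set, otherwise the Sm80 grouped-GEMM fallback.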
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getMaxWorkspaceSize(int num_experts) const
{
    if (num_experts != num_experts_)
    {
        TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_);
        num_experts_ = num_experts;
        gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts);
    }
    return gemm_workspace_size_;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspaceSize(int num_experts) const
{
    if (!supportsHopperSpecialisation())
    {
        return 0;
    }
    if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
    {
        auto configs = getHopperConfigs(sm_);
        size_t max_size = 0;
        bool has_config = false;
        for (auto conf : configs)
        {
#define CALC_SIZE_FUSION(FUSION)                                                                                       \
    do                                                                                                                 \
    {                                                                                                                  \
        try                                                                                                            \
        {                                                                                                              \
            size_t size = calcMaxWorkspaceSizeSM90<T, WeightType, OutputType, FUSION>(                                 \
                num_experts, conf, multi_processor_count_);                                                            \
            max_size = std::max(max_size, size);                                                                       \
            has_config = true;                                                                                         \
        }                                                                                                              \
        catch (tensorrt_llm::common::TllmException const& e)                                                           \
        {                                                                                                              \
            TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size");                         \
        }                                                                                                              \
    } while (0)

            CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::NONE);
            CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::FINALIZE);
        }
        TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size");
        return max_size;
    }
    else
    {
        TLLM_THROW("Attempting to calculate Hopper GEMM workspace size with unsupported weight combination");
        return 0;
    }
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::runGemm(T const* A, WeightType const* B,
    ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C,
    int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
    int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
    cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf)
{
    dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert,
        hopper_input, total_rows, gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array,
        stream, nullptr);
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(T const* A, WeightType const* B,
    ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C,
    int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
    int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf)
{
    switch (activation_type)
    {
    case ActivationType::Relu:
        runGemm<cutlass_extensions::EpilogueOpDefaultReLU>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::Gelu:
        runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::Silu:
        runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::Identity:
        runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::Swiglu:
        runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::Geglu:
        runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
    default: TLLM_THROW("Invalid activation type."); break;
    }
}
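// Activation-to-epilogue mapping used above (summary): Relu -> EpilogueOpDefaultReLU; Gelu and Geglu ->
// EpilogueOpDefaultFtGelu; Silu and Swiglu -> EpilogueOpDefaultSilu; Identity -> EpilogueOpDefault. Note the gated
// variants (Swiglu, Geglu) reuse the corresponding ungated epilogue tags.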
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemm(T const* A, WeightType const* B,
    ScaleBiasType const* weight_scales, void* C, int64_t const* total_tokens_including_expert,
    HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
    cutlass_extensions::CutlassGemmConfig chosen_conf)
{
    runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, nullptr, true, C,
        total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
        alpha_scale_ptr_array, stream, chosen_conf);
}

} // namespace tensorrt_llm
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // __GNUC__
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // __GNUC__
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
    typename TileShape, typename ClusterShape>
void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
    cudaStream_t stream, int* occupancy, size_t* workspace_size)
{
    static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>(),
        "Invalid hopper configuration invoked, fallback to Sm80");
    TLLM_CHECK_WITH_INFO(
        workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information");

    // auto func = hopper_input.ptr_c
    //     ? kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T, WeightType, cutlass::arch::Sm90,
    //           EpilogueTag, true>
    //     : kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T, WeightType, cutlass::arch::Sm90,
    //           EpilogueTag, false>;
    // TODO(dastokes) Re-enable bias when CUTLASS supports it
    auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher<T, WeightType, OutputType, EpilogueTag,
        FUSION, TileShape, ClusterShape, false>;
    func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
}
/*
   A 1x1x1 cluster shape is supported for any tile shape.
   A 2x1x1 cluster shape is only supported when the M tile is at least 128.
   A 1x2x1 cluster shape is only supported when the N tile is at least 128.
   A 2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
   We impose the above restrictions to improve compilation speed in TRT-LLM by pruning kernels
   that may not be very useful in practice.
*/
template <typename CTAShape, typename ClusterShape>
constexpr bool are_tile_shapes_supported()
{
    using namespace cute;
    [[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{});
    [[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{});
    constexpr int cga_m = get<0>(ClusterShape{});
    constexpr int cga_n = get<1>(ClusterShape{});

    if constexpr (cga_m == _1{} && cga_n == _1{})
    {
        return true;
    }
    else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{})
    {
        return true;
    }
    else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{})
    {
        return true;
    }
    else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{})
    {
        return true;
    }
    else
    {
        return false;
    }
}
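The constexpr branches map one-to-one onto the rules in the comment above. As a quick editorial illustration (not part of the vendored file), they can be spot-checked at compile time with cute's static shapes:

// Editorial sketch: compile-time spot checks of the tile/cluster rules above.
// Uses only cute's Shape/_N integral constants and the function defined above.
static_assert(are_tile_shapes_supported<cute::Shape<cute::_64, cute::_16, cute::_128>,
                  cute::Shape<cute::_1, cute::_1, cute::_1>>(),
    "1x1x1 clusters work with any tile");
static_assert(are_tile_shapes_supported<cute::Shape<cute::_128, cute::_16, cute::_128>,
                  cute::Shape<cute::_2, cute::_1, cute::_1>>(),
    "2x1x1 clusters need an M tile of at least 128");
static_assert(!are_tile_shapes_supported<cute::Shape<cute::_64, cute::_16, cute::_128>,
                  cute::Shape<cute::_2, cute::_1, cute::_1>>(),
    "an M tile of 64 is too small for a 2x1x1 cluster");
static_assert(!are_tile_shapes_supported<cute::Shape<cute::_128, cute::_64, cute::_128>,
                  cute::Shape<cute::_2, cute::_2, cute::_1>>(),
    "2x2x1 clusters also need an N tile of at least 128");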
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
    typename TileShape>
void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
    size_t* workspace_size)
{
    using namespace cute;
    switch (gemm_config.cluster_shape)
    {
#define SHAPE_CASE(M, N, K) \
case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \
{ \
using ClusterShape = Shape<_##M, _##N, _##K>; \
if constexpr (are_tile_shapes_supported<TileShape, ClusterShape>()) \
{ \
dispatchMoeGemmSelectBiasSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape>( \
hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \
break; \
} \
else \
{ \
TLLM_THROW("Unsupported tile and cluster shape combination"); \
} \
}
        SHAPE_CASE(1, 1, 1)
        SHAPE_CASE(1, 2, 1)
        SHAPE_CASE(2, 1, 1)
        SHAPE_CASE(2, 2, 1)
#undef SHAPE_CASE
    default: TLLM_THROW("Unsupported config for MoE gemm.");
    }
}
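As an editorial aid (not part of the vendored file), SHAPE_CASE(2, 1, 1) in the switch above expands to roughly the following, which shows how the compile-time guard and the runtime dispatch interact:

// Editorial sketch: approximate preprocessor expansion of SHAPE_CASE(2, 1, 1).
case cutlass_extensions::ClusterShape::ClusterShape_2x1x1:
{
    using ClusterShape = Shape<_2, _1, _1>;
    // Pruned at compile time unless TileShape satisfies are_tile_shapes_supported
    if constexpr (are_tile_shapes_supported<TileShape, ClusterShape>())
    {
        dispatchMoeGemmSelectBiasSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape>(
            hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
        break;
    }
    else
    {
        TLLM_THROW("Unsupported tile and cluster shape combination");
    }
}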
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION>
void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
    size_t* workspace_size)
{
    using namespace cute;
    switch (gemm_config.tile_config_sm90)
    {
#define SHAPE_CASE(M, N, K) \
case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \
{ \
constexpr int KtileBytes = K / sizeof(T); \
using KTileDim = Int<KtileBytes>; \
using TileShape = Shape<_##M, _##N, KTileDim>; \
dispatchMoeGemmSelectClusterShapeSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape>( \
hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \
break; \
}
        SHAPE_CASE(128, 16, 128)
        SHAPE_CASE(128, 32, 128)
        SHAPE_CASE(128, 64, 128)
        SHAPE_CASE(128, 128, 128)
        SHAPE_CASE(128, 256, 128)
        SHAPE_CASE(256, 128, 128)
#undef SHAPE_CASE
    case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break;
    case cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default: TLLM_THROW("Unsupported config for MoE gemm."); break;
    }
}
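Note that the trailing K in the CtaShape names is measured in bytes (hence the B suffix), so the element-wise K tile depends on sizeof(T). A worked editorial example for T = cutlass::half_t, mirroring the macro's KtileBytes computation:

// Editorial sketch: how the byte-denominated K extent becomes an element count
// for SHAPE_CASE(128, 128, 128) with T = cutlass::half_t (2 bytes per element).
constexpr int kTileElemsExample = 128 / sizeof(cutlass::half_t);                   // 128 B / 2 B = 64 elements
using KTileDimExample = cute::Int<kTileElemsExample>;                              // cute::Int<64>
using TileShapeExample = cute::Shape<cute::_128, cute::_128, KTileDimExample>;     // a 128 x 128 x 64 tile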
template <typename T, typename WeightType, typename OutputType, EpilogueFusion FUSION>
size_t calcMaxWorkspaceSizeSM90(
    int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count)
{
    size_t count;
    // Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat
    dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, cutlass_extensions::EpilogueOpDefault, FUSION>(
        HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
    return count;
}
} // namespace tensorrt_llm
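A minimal editorial sketch of how a caller might use calcMaxWorkspaceSizeSM90, assuming a gemm_config already resolved by the heuristic (so neither tile_config_sm90 nor cluster_shape is ChooseWithHeuristic) and assuming a NONE enumerator on HopperGroupedGemmInput::EpilogueFusion:

// Editorial sketch: querying the SM90 MoE GEMM workspace size for fp16 experts.
#include <cuda_runtime_api.h>

size_t exampleQueryMoeWorkspace(cutlass_extensions::CutlassGemmConfig gemm_config, int num_experts)
{
    int device = 0;
    int sm_count = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);

    // fp16 activations, weights, and outputs; no epilogue fusion.
    return tensorrt_llm::calcMaxWorkspaceSizeSM90<half, half, half,
        tensorrt_llm::HopperGroupedGemmInput::EpilogueFusion::NONE>(num_experts, gemm_config, sm_count);
}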
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h
0 → 100644
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/mma_sm90.h"
#include "cutlass_extensions/epilogue_helpers.h"
namespace tensorrt_llm::kernels::cutlass_kernels
{

// Hopper arch
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidHopperMOESpecialisation()
{
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
    return cutlass::platform::is_same<T, WeightType>::value
        && cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
#else
    return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled
#endif
}
// Ampere arch
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidAmpereMOESpecialisation()
{
    return true; // Default to true
}

} // namespace tensorrt_llm::kernels::cutlass_kernels
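These traits let callers prune the SM90 path at compile time. A minimal editorial sketch of the gating pattern, with dispatchHopperPath / dispatchAmperePath as hypothetical stand-ins for the real launchers (they are not TRT-LLM APIs):

// Editorial sketch: compile-time fallback from the Hopper to the Ampere path.
template <typename T, typename WeightType, typename EpilogueTag>
void dispatchHopperPath() { /* launch the SM90 grouped GEMM */ }

template <typename T, typename WeightType, typename EpilogueTag>
void dispatchAmperePath() { /* launch the SM80 grouped GEMM */ }

template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
void dispatchMoeGemmExample()
{
    using namespace tensorrt_llm::kernels::cutlass_kernels;
    if constexpr (isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>())
    {
        // Taken only when T == WeightType, the epilogue is the default, and
        // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is defined.
        dispatchHopperPath<T, WeightType, EpilogueTag>();
    }
    else
    {
        dispatchAmperePath<T, WeightType, EpilogueTag>();
    }
}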