Unverified Commit e81d7f11 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

add tensorrt_llm moe_gemm as 3rdparty (#3217)

parent 222ce6f1
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define CUDA_LIB_NAME "cuda"
#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary("nv" name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
#else // For non-Windows platforms
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)
#include "cudaDriverWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
namespace tensorrt_llm::common
{
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
{
static std::mutex mutex;
static std::weak_ptr<CUDADriverWrapper> instance;
std::shared_ptr<CUDADriverWrapper> result = instance.lock();
if (result)
{
return result;
}
std::lock_guard<std::mutex> lock(mutex);
result = instance.lock();
if (!result)
{
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
instance = result;
}
return result;
}
CUDADriverWrapper::CUDADriverWrapper()
: handle(dllOpen(CUDA_LIB_NAME))
{
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
return ret;
};
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
}
CUDADriverWrapper::~CUDADriverWrapper()
{
dllClose(handle);
}
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
{
return (*_cuGetErrorName)(error, pStr);
}
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
{
return (*_cuGetErrorMessage)(error, pStr);
}
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
{
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
}
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
{
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
}
CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
{
return (*_cuModuleUnload)(hmod);
}
CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
{
return (*_cuLinkDestroy)(state);
}
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
{
return (*_cuModuleLoadData)(module, image);
}
CUresult CUDADriverWrapper::cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
{
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
}
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetFunction)(hfunc, hmod, name);
}
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
}
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
{
return (*_cuLaunchCooperativeKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
}
CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
{
return (*_cuLaunchKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
{
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
}
CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
{
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
#include <memory>
#include <mutex>
namespace tensorrt_llm::common
{
class CUDADriverWrapper
{
public:
static std::shared_ptr<CUDADriverWrapper> getInstance();
~CUDADriverWrapper();
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
CUresult cuModuleUnload(CUmodule hmod) const;
CUresult cuLinkDestroy(CUlinkState state) const;
CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
CUresult cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
CUjit_option* options, void** optionValues) const;
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
unsigned int numOptions, CUjit_option* options, void** optionValues) const;
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra) const;
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
private:
void* handle;
CUDADriverWrapper();
CUresult (*_cuGetErrorName)(CUresult, char const**);
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
CUresult (*_cuModuleUnload)(CUmodule);
CUresult (*_cuLinkDestroy)(CUlinkState);
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLinkAddData)(
CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, CUstream, void**);
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra);
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
};
template <typename T>
void checkDriver(
T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
{
if (result)
{
char const* errorName = nullptr;
char const* errorMsg = nullptr;
wrap.cuGetErrorName(result, &errorName);
wrap.cuGetErrorMessage(result, &errorMsg);
throw TllmException(
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
}
}
} // namespace tensorrt_llm::common
/*
* Macros compliant with TensorRT coding conventions
*/
#define TLLM_CU_CHECK(stat) \
do \
{ \
tensorrt_llm::common::checkDriver( \
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
} while (0)
#endif // CUDA_DRIVER_WRAPPER_H
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
typename EpilogueTag>
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
int* kernel_occupancy);
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include <cutlass_extensions/epilogue_helpers.h>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh>
#include <tensorrt_llm/common/cudaUtils.h>
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
typename EpilogueTag>
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
int* kernel_occupancy)
{
constexpr auto activation_type = fused_moe::EpilogueRouting<EpilogueTag>(true);
using GemmType = fused_moe::Fused_Moe_Kernel_sm80<ElementType_, CutlassWeightType_, ElementType_, MaxTileM_, TileN_,
TileK_, Stages_, activation_type>;
// make sure GPU has enough resources..
if (kernel_occupancy != nullptr)
{
constexpr int smem_size = GemmType::kSmemSize;
if (smem_size > (48 << 10))
{
cudaFuncAttributes attr{};
int device = 0;
int max_smem_per_block = 0;
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global<GemmType>));
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize,
// smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the
// heuristic to ignore this configuration.
*kernel_occupancy = 0;
return;
}
}
int max_active_blocks = -1;
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, fused_moe::run_global<GemmType>, GemmType::kThreadCount, smem_size));
*kernel_occupancy = max_active_blocks;
return;
}
int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks<GemmType>());
int const threadblock_count = multi_processor_count * occupancy;
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
using Arguments = typename GemmType::Arguments;
Arguments args{{const_cast<ElementType_*>(A), const_cast<CutlassWeightType_*>(B), const_cast<ElementType_*>(biases),
reinterpret_cast<ElementType_*>(C), total_tokens_including_expert, static_cast<int>(gemm_n),
static_cast<int>(gemm_k), num_experts, bias_is_broadcast},
num_experts, threadblock_count};
auto params = GemmType::to_underlying_arguments(args);
if (GemmType::kSmemSize >= (48 << 10))
{
cudaError_t result = cudaFuncSetAttribute(
fused_moe::run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize);
TLLM_CHECK_WITH_INFO(result == cudaSuccess,
"Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel");
}
dim3 grid(params.threadblock_count, 1, 1);
dim3 block(GemmType::kThreadCount);
fused_moe::run_global<GemmType><<<grid, block, GemmType::kSmemSize, stream>>>(params);
auto result = cudaGetLastError();
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result));
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
// Keep in sync with the signature generated by generate_kernels.py
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag,
HopperGroupedGemmInput::EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
// Hopper helper class for defining all the cutlass helper types
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, typename TileShape,
typename ClusterShape, bool BIAS, EpilogueFusion FUSION>
struct HopperGroupedGemmInfo
{
using Arch = cutlass::arch::Sm90;
// TODO Update once mixed input support is added
static_assert(cutlass::platform::is_same<T, WeightType>::value,
"CUTLASS does not currently have specialised SM90 support for quantized operations");
#ifdef ENABLE_FP8
constexpr static bool IsFP8
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
#else
constexpr static bool IsFP8 = false;
#endif
#ifdef ENABLE_BF16
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|| cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for bfloat16, half, float, fp8");
#else
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for half, float, fp8");
#endif
static_assert(cutlass::platform::is_same<T, WeightType>::value
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e4m3_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e5m2_t>::value,
"Unexpected quantization type");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter<WeightType>::type;
// For legacy reasons we convert unsigned 8-bit to signed
using CutlassWeightTypeMaybeUint8
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint4, cutlass::uint4b_t>, cutlass::int4b_t,
CutlassWeightTypeMaybeUint4>;
using CutlassWeightType
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint8, uint8_t>, int8_t, CutlassWeightTypeMaybeUint8>;
using ElementA = ElementType;
using ElementB = CutlassWeightType;
using ElementD = typename TllmToCutlassTypeAdapter<HopperGroupedGemmInput::OutputTypeAdaptor_t<OutputType>>::type;
using ElementFinalOutput = typename TllmToCutlassTypeAdapter<OutputType>::type;
// using ElementC = std::conditional_t<BIAS, ElementType, void>;
// using ElementCNoVoid = std::conditional_t<BIAS, ElementType, ElementD>;
using ElementC = void;
using ElementCNoVoid = ElementD;
using ElementAccumulator = float;
using ElementBias = ElementFinalOutput;
using ElementRouterScales = float;
// A matrix configuration - this is transposed and swapped with B
using LayoutA = HopperGroupedGemmInput::LayoutA;
constexpr static int AlignmentA
= 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units
// of elements (up to 16 bytes)
// B matrix configuration - this is transposed and swapped with A
using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand
constexpr static int AlignmentB
= 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units
// of elements (up to 16 bytes)
// C matrix configuration
using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand
using StrideC = HopperGroupedGemmInput::StrideC;
// Note we use ElementType here deliberately, so we don't break when BIAS is disabled
constexpr static int AlignmentC
= 128 / cutlass::sizeof_bits<ElementType>::value; // Memory access granularity/alignment of C matrix in units
// of elements (up to 16 bytes)
// D matrix configuration
using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD;
using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD;
constexpr static int AlignmentD
= 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix
// in units of elements (up to 16 bytes)
static_assert(cutlass::platform::is_same<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpDefault>::value,
"Hopper Grouped GEMM specialisation doesn't support fused activation");
using EpilogueOp
= cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
// TODO Add mode for fused activation once CUTLASS adds support
// using EpilogueSchedule = cutlass::platform::conditional_t<
// cutlass::platform::is_same<EpilogueOp, EpilogueOpDefault>::value,
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized,
// cutlass::epilogue::?????????????????? /// <<<<<< what supports activations
// >;
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;
// Epilogue For Default Finalize
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
TileShape, ClusterShape, //
cutlass::epilogue::collective::EpilogueTileAuto, //
ElementAccumulator, ElementAccumulator, //
ElementC, LayoutC*, AlignmentC, //
ElementD, LayoutD*, AlignmentD, //
EpilogueSchedule>::CollectiveOp;
// Epilogue For Fused Finalize
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< //
TileShape, //
ElementCNoVoid, StrideC*, //
ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, //
ElementAccumulator, //
ElementAccumulator, //
ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, //
ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales //
>::CollectiveOp;
using CollectiveEpilogue
= std::conditional_t<FUSION == EpilogueFusion::FINALIZE, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>;
using KernelSchedule
= std::conditional_t<IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here
ElementType, LayoutA*, AlignmentA, //
ElementAccumulator, //
TileShape, ClusterShape, //
StageCountAutoCarveout, KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<HopperGroupedGemmInput::ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
// Hopper specialised version
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size)
{
#ifdef COMPILE_HOPPER_TMA_GEMMS
using namespace cute;
if constexpr (!should_filter_sm90_gemm_problem_shape_v<TileShape, ClusterShape, T>)
{
using GemmInfo
= HopperGroupedGemmInfo<T, WeightType, OutputType, EpilogueTag, TileShape, ClusterShape, BIAS, FUSION>;
using ElementAccumulator = typename GemmInfo::ElementAccumulator;
using ElementA = typename GemmInfo::ElementA;
using ElementB = typename GemmInfo::ElementB;
using ElementC = typename GemmInfo::ElementC;
using ElementCNoVoid = typename GemmInfo::ElementCNoVoid;
using ElementD = typename GemmInfo::ElementD;
using CollectiveMainloop = typename GemmInfo::CollectiveMainloop;
using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue;
using GemmKernel = typename GemmInfo::GemmKernel;
using GemmGrouped = typename GemmInfo::GemmGrouped;
if (kernel_occupancy != nullptr)
{
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
return;
}
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = multi_processor_count;
GemmGrouped gemm;
if (workspace_size != nullptr)
{
// Make a mock problem shape with just the minimal information actually required to get the workspace size
// This makes some assumptions about CUTLASS's implementation which is suboptimal. We have a check later to
// catch future cutlass updates causing silent breakages, but that is not fool proof.
// The alternative is to wait until we have data and then dynamically allocate the workspace
typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr};
typename GemmGrouped::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info};
*workspace_size = gemm.get_workspace_size(args);
return;
}
using MainloopArguments = typename CollectiveMainloop::Arguments;
TLLM_CHECK(hopper_input.stride_a);
TLLM_CHECK(hopper_input.stride_b);
TLLM_CHECK(hopper_input.ptr_a);
TLLM_CHECK(hopper_input.ptr_b);
MainloopArguments const mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};
typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)};
epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
// TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
auto make_epi_args = [&]()
{
if constexpr (FUSION == EpilogueFusion::NONE)
{
auto epi_params = hopper_input.default_epilogue;
return EpilogueArguments{epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c),
hopper_input.stride_c, reinterpret_cast<ElementD**>(epi_params.ptr_d), epi_params.stride_d};
}
else if constexpr (FUSION == EpilogueFusion::FINALIZE)
{
// Parameters for fused finalize
auto epi_params = hopper_input.fused_finalize_epilogue;
return EpilogueArguments{
epilogue_scalars, // Parameters to underlying epilogue
reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c, // C params
reinterpret_cast<typename GemmInfo::ElementFinalOutput*>(epi_params.ptr_final_output),
epi_params.stride_final_output, // D (output) params
reinterpret_cast<typename GemmInfo::ElementBias const*>(epi_params.ptr_bias),
epi_params.stride_bias, // Bias params
epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales
epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales
epi_params.ptr_source_token_index, // Index of the source token to sum into
epi_params.num_rows_in_final_output // Number of tokens in the output buffer
};
}
else
{
static_assert(
sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher");
}
};
EpilogueArguments const epilogue_params = make_epi_args();
typename GemmKernel::TileScheduler::Arguments scheduler_args{
1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN};
typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
mainloop_params, epilogue_params, hw_info, scheduler_args};
size_t calculated_ws_size = gemm.get_workspace_size(args);
TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,
"Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size);
auto can_implement = gemm.can_implement(args);
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
"Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
auto init_status = gemm.initialize(args, hopper_input.gemm_workspace);
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
"Failed to initialize cutlass SM90 grouped gemm. Error: "
+ std::string(cutlassGetStatusString(init_status)));
auto run_status = gemm.run(stream);
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
"Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
sync_check_cuda_error();
}
else
{
TLLM_THROW("Configuration was disabled by FAST_BUILD");
}
#else // COMPILE_HOPPER_TMA_GEMMS
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
#endif // COMPILE_HOPPER_TMA_GEMMS
}
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/conv/convolution.h"
// Order matters here, packed_stride.hpp is missing cute and convolution includes
#include "cutlass/util/packed_stride.hpp"
#include "tensorrt_llm/common/logger.h"
namespace tensorrt_llm
{
std::array<size_t, 10> HopperGroupedGemmInput::workspaceBuffers(int num_experts)
{
size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts;
size_t stride_a_size = sizeof(StrideA) * num_experts;
size_t stride_b_size = sizeof(StrideB) * num_experts;
size_t stride_c_size = sizeof(StrideC) * num_experts;
size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts;
size_t ptr_buf_size = sizeof(void*) * num_experts;
size_t scale_buf_size = sizeof(float*) * num_experts;
return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size,
ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size};
}
size_t HopperGroupedGemmInput::workspaceSize(int num_experts)
{
auto buffers = workspaceBuffers(num_experts);
return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size());
}
void HopperGroupedGemmInput::configureWorkspace(
int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size)
{
auto buffers = workspaceBuffers(num_experts);
std::array<int8_t*, 10> pointers{};
TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers");
for (int i = 0; i < buffers.size(); i++)
{
pointers[i] = start_ptr;
start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]);
}
shape_info.num_groups = num_experts;
shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(pointers[0]);
shape_info.host_problem_shapes = nullptr;
stride_a = reinterpret_cast<StrideA*>(pointers[1]);
stride_b = reinterpret_cast<StrideB*>(pointers[2]);
stride_c = reinterpret_cast<StrideC*>(pointers[3]);
default_epilogue.stride_d = reinterpret_cast<DefaultEpilogue::StrideD*>(pointers[4]);
ptr_a = reinterpret_cast<void const**>(pointers[5]);
ptr_b = reinterpret_cast<void const**>(pointers[6]);
ptr_c = reinterpret_cast<void const**>(pointers[7]);
default_epilogue.ptr_d = reinterpret_cast<void**>(pointers[8]);
alpha_scale_ptr_array = reinterpret_cast<float const**>(pointers[9]);
this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace);
this->gemm_workspace_size = gemm_workspace_size;
}
void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens)
{
fused_finalize_epilogue.ptr_final_output = final_output;
fused_finalize_epilogue.ptr_router_scales = router_scales;
fused_finalize_epilogue.ptr_bias = bias;
fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset;
fused_finalize_epilogue.ptr_source_token_index = source_token_index;
fused_finalize_epilogue.stride_final_output
= cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{},
transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1)));
fused_finalize_epilogue.stride_bias
= transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size));
fused_finalize_epilogue.stride_router_scales = {};
fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens;
}
std::string HopperGroupedGemmInput::toString() const
{
std::stringstream ss;
ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n";
if (isValid())
{
ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n";
ss << "Epilogue Fusion: " << (int) fusion;
if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE)
{
ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output;
ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias;
ss << " with Stride: " << fused_finalize_epilogue.stride_bias;
ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales;
ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset;
ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index;
}
else
{
ss << ", Ptr D: " << default_epilogue.ptr_d;
}
ss << '\n';
ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n";
}
return ss.str();
}
} // namespace tensorrt_llm
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h"
#include <array>
#include <cuda_runtime_api.h>
#include <optional>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/layout/layout.h"
namespace tensorrt_llm
{
template <class T>
constexpr auto transpose_stride(T const& t)
{
return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
}
struct HopperGroupedGemmInput
{
template <class T>
using TransposeStride = decltype(transpose_stride<T>(T{}));
template <class Tag>
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
static_assert(std::is_same_v<cutlass::layout::RowMajor, TransposeLayoutTag<cutlass::layout::ColumnMajor>>);
static_assert(std::is_same_v<cutlass::layout::ColumnMajor, TransposeLayoutTag<cutlass::layout::RowMajor>>);
// Layout for A and B is transposed and then swapped in the implementation
// This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
using StrideA
= std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutA*>>; // Use B because they will be swapped
using StrideB
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
template <class T>
constexpr static bool IsFP8_v = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
// Currently this should always just be T
template <class T>
using OutputTypeAdaptor_t = std::conditional_t<IsFP8_v<T>, nv_bfloat16, T>;
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int64_t, int64_t, int64_t>>;
ProblemShape shape_info{};
StrideA* stride_a = nullptr;
StrideB* stride_b = nullptr;
void const** ptr_a = nullptr;
void const** ptr_b = nullptr;
// C is currently the same in both epilogues
StrideC* stride_c = nullptr;
void const** ptr_c = nullptr;
struct DefaultEpilogue
{
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
StrideD* stride_d = nullptr;
void** ptr_d = nullptr;
};
struct FusedFinalizeEpilogue
{
using StrideFinalOutput = DefaultEpilogue::StrideD;
using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;
void* ptr_final_output = nullptr;
StrideFinalOutput stride_final_output{};
void const* ptr_bias = nullptr;
StrideBias stride_bias{};
float const* ptr_router_scales = nullptr;
StrideRouterScales stride_router_scales{};
int64_t const* ptr_expert_first_token_offset = nullptr;
int const* ptr_source_token_index = nullptr;
size_t num_rows_in_final_output = 0;
};
DefaultEpilogue default_epilogue;
FusedFinalizeEpilogue fused_finalize_epilogue;
enum class EpilogueFusion
{
NONE,
ACTIVATION,
GATED_ACTIVATION,
FINALIZE
};
EpilogueFusion fusion = EpilogueFusion::NONE;
float const** alpha_scale_ptr_array = nullptr;
uint8_t* gemm_workspace = nullptr;
size_t gemm_workspace_size = 0;
static std::array<size_t, 10> workspaceBuffers(int num_experts);
static size_t workspaceSize(int num_experts);
void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size);
bool isValid() const
{
return stride_a != nullptr && ptr_a != nullptr;
}
void setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens);
std::string toString() const;
};
// Note update moe.py to match
enum class ActivationType
{
Gelu = 0,
Relu,
Silu,
Swiglu,
Geglu,
Identity,
InvalidType
};
constexpr bool isGatedActivation(ActivationType activation_type)
{
return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu;
}
template <typename T, /*The type used for activations/scales/compute*/
typename WeightType, /* The type for the MoE weights */
typename OutputType, /* The output type for the GEMM */
typename ScaleBiasType = OutputType /* The type for the scales/bias */
>
class MoeGemmRunner
{
public:
MoeGemmRunner();
#if defined(ENABLE_FP8)
static constexpr bool use_fp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
#else
static constexpr bool use_fp8 = false;
#endif
void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig chosen_conf);
void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C,
int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows,
int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf);
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm);
static std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs(int sm);
static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);
[[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const;
[[nodiscard]] bool supportsHopperSpecialisation() const;
[[nodiscard]] bool isFusedGatedActivation(
cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const;
[[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const;
size_t getMaxWorkspaceSize(int num_experts) const;
[[nodiscard]] int getSM() const;
private:
template <typename EpilogueTag>
void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array,
cudaStream_t stream, int* occupancy = nullptr);
template <typename EpilogueTag>
void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases,
bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig chosen_conf);
private:
int sm_{};
int multi_processor_count_{};
mutable int num_experts_ = 0;
mutable size_t gemm_workspace_size_ = 0;
size_t calcMaxWorkspaceSize(int num_experts) const;
};
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<half, half, half>;
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<half, cutlass::uint4b_t, half>;
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<half, uint8_t, half>;
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<float, float, float>;
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_FP8
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>;
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
#endif
// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>;
#endif
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#ifdef __GNUC__ // Restore GCC-specific diagnostics
#pragma GCC diagnostic pop
#endif
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "moe_gemm_kernels_template_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include <tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
namespace kernels::cutlass_kernels
{
// ============================= Variable batched Gemm things ===========================
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape, int Stages>
void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe,
float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr)
{
#if defined(ENABLE_FP8)
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|| cutlass::platform::is_same<T, __nv_fp8_e4m3>::value
|| cutlass::platform::is_same<T, __nv_fp8_e5m2>::value || cutlass::platform::is_same<T, float>::value,
"Specialized for fp8, bfloat16, half, float");
#elif defined(ENABLE_BF16)
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|| cutlass::platform::is_same<T, float>::value,
"Specialized for bfloat16, half, float");
#else
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
"Specialized for half, float");
#endif
static_assert(cutlass::platform::is_same<T, WeightType>::value
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
"");
static_assert(!cutlass::platform::is_same<arch, cutlass::arch::Sm90>::value,
"Sm90 architecture should use specialised kernels");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter<GemmOutputType>::type;
using CutlassWeightType = typename TllmToCutlassTypeAdapter<WeightType>::type;
if (!use_fused_moe)
{
// We need separate config for each architecture since we will target different tensorcore instructions. For
// float, we do not target TCs.
using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue<CutlassGemmOutputType,
MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;
typename EpilogueOp::Params epilogue_op(
ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f));
#if defined(ENABLE_FP8)
if constexpr ((std::is_same_v<T, __nv_fp8_e4m3>
|| std::is_same_v<T, __nv_fp8_e5m2>) &&std::is_same_v<EpilogueTag,
cutlass_extensions::EpilogueOpDefault>)
{
TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array,
"weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 "
"Ada");
epilogue_op.alpha_ptr_array = alpha_scale_ptr_array;
}
#endif
// Finally, set up the kernel.
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<ElementType, cutlass::layout::RowMajor,
cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType,
typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone,
MixedGemmArchTraits::ElementsPerAccessB, CutlassGemmOutputType, cutlass::layout::RowMajor,
ElementAccumulator, typename MixedGemmArchTraits::OperatorClass, arch, ThreadblockShape, WarpShape,
typename MixedGemmArchTraits::InstructionShape, EpilogueOp,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel;
using GemmKernel = cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma, typename GemmKernel_::Epilogue,
typename GemmKernel_::ThreadblockSwizzle,
arch, // Ensure top level arch is used for dispatch
GemmKernel_::kGroupScheduleMode>;
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
if (kernel_occupancy != nullptr)
{
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel>();
return;
}
int occupancy = std::min(2, GemmGrouped::maximum_active_blocks());
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel");
int const threadblock_count = multi_processor_count * occupancy;
int const group_size = gemm_k;
typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op,
reinterpret_cast<ElementType const*>(A), reinterpret_cast<CutlassWeightType const*>(B),
reinterpret_cast<CutlassGemmOutputType const*>(weight_scales),
reinterpret_cast<CutlassGemmOutputType const*>(biases), bias_is_broadcast,
reinterpret_cast<CutlassGemmOutputType*>(C), total_tokens_including_expert, gemm_n, gemm_k);
GemmGrouped gemm;
auto can_implement = gemm.can_implement(args);
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
"MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
auto init_status = gemm.initialize(args);
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
"Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status)));
auto run_status = gemm.run(stream);
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
"Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
}
else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2
&& (std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultSilu>
|| std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultFtGelu>) ) // use fused moe gemm
// kernel.. (only support
// fp16 or bf16)
{
sm80_generic_fused_moe_gemm_kernelLauncher<ElementType, CutlassWeightType, ThreadblockShape::kM,
ThreadblockShape::kN, ThreadblockShape::kK, Stages, EpilogueTag>(reinterpret_cast<ElementType const*>(A),
reinterpret_cast<CutlassWeightType const*>(B), reinterpret_cast<ElementType const*>(biases),
bias_is_broadcast, reinterpret_cast<ElementType*>(C), total_tokens_including_expert, num_rows, gemm_n,
gemm_k, num_experts, multi_processor_count, stream, kernel_occupancy);
}
}
} // namespace kernels::cutlass_kernels
template <typename T, typename WeightType, typename GemmOutputType, typename Arch, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape, int Stages>
static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales, GemmOutputType const* biases,
bool bias_is_broadcast, GemmOutputType* C, int64_t const* total_tokens_including_expert, int64_t num_rows,
int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config,
int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
int* occupancy = nullptr)
{
static_assert(!std::is_same_v<Arch, cutlass::arch::Sm90>, "Use TMA specialised functions for arch SM90");
#if defined(ENABLE_FP8)
constexpr bool isFp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
#else
constexpr bool isFp8 = false;
#endif
if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80)
&& (!isFp8 || std::is_same_v<Arch, cutlass::arch::Sm89>) )
{
kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, GemmOutputType, Arch, EpilogueTag,
ThreadblockShape, WarpShape, Stages>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
}
else
{
TLLM_THROW(
"Cutlass gemm. Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages);
}
}
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape>
void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
switch (gemm_config.stages)
{
case 2:
dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B, weight_scales,
biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts,
gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case 3:
dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B, weight_scales,
biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts,
gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case 4:
dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B, weight_scales,
biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts,
gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break;
}
}
// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
// This overload is only enabled when T == WeightType.
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value
#if defined(ENABLE_FP8)
&& !std::is_same<T, __nv_fp8_e4m3>::value && !std::is_same<T, __nv_fp8_e5m2>::value
#endif
&& std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
switch (gemm_config.tile_config)
{
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
TLLM_THROW("GEMM config should have already been set by heuristic.");
break;
default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break;
}
}
// Tensorop GEMM overload
// Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve
// compile time
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
switch (gemm_config.tile_config)
{
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
TLLM_THROW("GEMM config should have already been set by heuristic.");
break;
default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break;
}
}
// This overload will handle tensorop gemms.
// This overload is only enabled when T == WeightType and T == __nv_fp8_e4m3 or __nv_fp8_e5m2
#if defined(ENABLE_FP8)
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
typename std::enable_if<(std::is_same<T, __nv_fp8_e4m3>::value || std::is_same<T, __nv_fp8_e5m2>::value)
&& std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
switch (gemm_config.tile_config)
{
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 128>,
cutlass::gemm::GemmShape<16, 64, 128>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 256, 64>,
cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
TLLM_THROW("GEMM config should have already been set by heuristic.");
break;
default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break;
}
}
#endif
// This overload will handle simt gemms. It is disabled via SFINAE for tensorop.
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
switch (gemm_config.tile_config)
{
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 8>,
cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
TLLM_THROW("GEMM config should have already been set by heuristic.");
break;
default: TLLM_THROW("Unsupported config for float MoE gemm."); break;
}
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs() const
{
return getConfigs(sm_);
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(
int sm)
{
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs = getHopperConfigs(sm);
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs = getAmpereConfigs(sm);
std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs));
return candidate_configs;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getAmpereConfigs(int sm)
{
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
static constexpr auto weight_only_flag
= std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
static constexpr auto simt_only_flag
= std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
int const max_split_k = 1;
int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
int const enable_hopper = CutlassGemmConfig::NONE;
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);
if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>())
{
return {};
}
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs
= kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
return ampere_configs;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getHopperConfigs(int sm)
{
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
static constexpr auto weight_only_flag
= std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
static constexpr auto simt_only_flag
= std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
int const max_split_k = 1;
int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
int const enable_hopper = CutlassGemmConfig::HOPPER;
static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);
if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
{
return {};
}
std::vector<cutlass_extensions::CutlassGemmConfig> hopper_configs
= kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
return hopper_configs;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isHopperSpecialised(
cutlass_extensions::CutlassGemmConfig gemm_config) const
{
bool config_is_sm90 = gemm_config.is_sm90;
return supportsHopperSpecialisation() && config_is_sm90;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsHopperSpecialisation() const
{
return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>();
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
int MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getSM() const
{
return this->sm_;
}
// currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsFusedGatedActivation(
bool is_gated_activation, int gemm_n, int gemm_k) const
{
constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true;
return is_gated_activation && std::is_same_v<T, WeightType> && !std::is_same_v<T, float> && !use_fp8
&& (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isFusedGatedActivation(
cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const
{
return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_sm90;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::MoeGemmRunner()
{
int device{-1};
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
sm_ = tensorrt_llm::common::getSMVersion();
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch<EpilogueTag>(T const* A,
WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast,
void* C_void, int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config,
bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy)
{
static_assert(std::is_same_v<ScaleBiasType, OutputType>,
"Separate Scale/Bias type is not supported. This is assumed to be the gemm output type");
// For now we always cast this to output type.
// In the future this will vary based on what fusions are applied for FP8
auto* C = reinterpret_cast<OutputType*>(C_void);
TLLM_CHECK_WITH_INFO(
sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation");
TLLM_CHECK_WITH_INFO(
sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture");
if (sm_ >= 75 && sm_ < 80)
{
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm75, EpilogueTag>(A, B, weight_scales,
biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts,
gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
}
else if (sm_ >= 80 && sm_ < 90)
{
if constexpr (use_fp8)
{
#if defined(ENABLE_FP8)
static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
"FP8 GEMM Output not supported");
#endif
TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(A, B,
weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
occupancy);
}
else
{
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
occupancy);
}
}
else if (sm_ >= 90)
{
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>())
{
// We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens
// SM80 is faster. We check here to see which is selected
if (gemm_config.is_sm90)
{
TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr,
"Input biases and hopper input disagree if bias is enabled");
TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config");
// Select the appropriate fusion function
auto select_function = [&]()
{
switch (hopper_input.fusion)
{
case HopperGroupedGemmInput::EpilogueFusion::FINALIZE:
return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
HopperGroupedGemmInput::EpilogueFusion::FINALIZE>;
case HopperGroupedGemmInput::EpilogueFusion::NONE:
return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
HopperGroupedGemmInput::EpilogueFusion::NONE>;
case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION:
case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION:
default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion);
};
};
auto selected_func = select_function();
selected_func(
hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr);
return;
}
// Fallthrough to SM80 impl below
}
// Do Ampere case instead
if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType, EpilogueTag>())
{
TLLM_CHECK_WITH_INFO(!hopper_input.isValid(),
"Non-specialised Hopper implementation is being rerouted to fallback implementation so input "
"information is not required");
TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90,
"GEMM config is for SM90 configuration, but this configuration is not valid for Hppper");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
occupancy);
}
else
{
TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels");
}
}
else
{
TLLM_THROW("Arch unsupported for MoE GEMM");
}
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getMaxWorkspaceSize(int num_experts) const
{
if (num_experts != num_experts_)
{
TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_);
num_experts_ = num_experts;
gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts);
}
return gemm_workspace_size_;
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspaceSize(int num_experts) const
{
if (!supportsHopperSpecialisation())
{
return 0;
}
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
{
auto configs = getHopperConfigs(sm_);
size_t max_size = 0;
bool has_config = false;
for (auto conf : configs)
{
#define CALC_SIZE_FUSION(FUSION) \
do \
{ \
try \
{ \
size_t size = calcMaxWorkspaceSizeSM90<T, WeightType, OutputType, FUSION>( \
num_experts, conf, multi_processor_count_); \
max_size = std::max(max_size, size); \
has_config = true; \
} \
catch (tensorrt_llm::common::TllmException const& e) \
{ \
TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size"); \
} \
} while (0)
CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::NONE);
CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::FINALIZE);
}
TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size");
return max_size;
}
else
{
TLLM_THROW("Attempting to calculate Hopper GEMM workspace size with unsupported weight combination");
return 0;
}
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::runGemm(T const* A, WeightType const* B,
ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C,
int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf)
{
dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert,
hopper_input, total_rows, gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array,
stream, nullptr);
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(T const* A, WeightType const* B,
ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C,
int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe,
float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf)
{
switch (activation_type)
{
case ActivationType::Relu:
runGemm<cutlass_extensions::EpilogueOpDefaultReLU>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
alpha_scale_ptr_array, stream, chosen_conf);
break;
case ActivationType::Gelu:
runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
alpha_scale_ptr_array, stream, chosen_conf);
break;
case ActivationType::Silu:
runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
alpha_scale_ptr_array, stream, chosen_conf);
break;
case ActivationType::Identity:
runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
alpha_scale_ptr_array, stream, chosen_conf);
break;
case ActivationType::Swiglu:
runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
alpha_scale_ptr_array, stream, chosen_conf);
break;
case ActivationType::Geglu:
runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, bias_is_broadcast, C,
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
alpha_scale_ptr_array, stream, chosen_conf);
break;
case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
default: TLLM_THROW("Invalid activation type."); break;
}
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemm(T const* A, WeightType const* B,
ScaleBiasType const* weight_scales, void* C, int64_t const* total_tokens_including_expert,
HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig chosen_conf)
{
runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, nullptr, true, C, total_tokens_including_expert,
hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream,
chosen_conf);
}
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // __GNUC__
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // __GNUC__
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename TileShape, typename ClusterShape>
void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
cudaStream_t stream, int* occupancy, size_t* workspace_size)
{
static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>(),
"Invalid hopper configuration invoked, fallback to Sm80");
TLLM_CHECK_WITH_INFO(
workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information");
// auto func = hopper_input.ptr_c ?
// kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T, WeightType,
// cutlass::arch::Sm90, EpilogueTag, true>
// :
// kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T,
// WeightType,
// cutlass::arch::Sm90, EpilogueTag, false>;
// TODO(dastokes) Re-enable bias when CUTLASS supports it
auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher<T, WeightType, OutputType, EpilogueTag,
FUSION, TileShape, ClusterShape, false>;
func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
}
/*
1x1x1 cluster shape is are supported for any tile shape.
2x1x1 cluster shape is only supported for when the M tile is at least 128.
1x2x1 cluster shape is only supported when the N tile is at least 128.
2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
We make the above restrictions are to improve compilation speed in TRT-LLM by pruning kernels
that may not be very useful in practice.
*/
template <typename CTAShape, typename ClusterShape>
constexpr bool are_tile_shapes_supported()
{
using namespace cute;
[[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{});
[[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{});
constexpr int cga_m = get<0>(ClusterShape{});
constexpr int cga_n = get<1>(ClusterShape{});
if constexpr (cga_m == _1{} && cga_n == _1{})
{
return true;
}
else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{})
{
return true;
}
else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{})
{
return true;
}
else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{})
{
return true;
}
else
{
return false;
}
}
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename TileShape>
void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
size_t* workspace_size)
{
using namespace cute;
switch (gemm_config.cluster_shape)
{
#define SHAPE_CASE(M, N, K) \
case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \
{ \
using ClusterShape = Shape<_##M, _##N, _##K>; \
if constexpr (are_tile_shapes_supported<TileShape, ClusterShape>()) \
{ \
dispatchMoeGemmSelectBiasSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape>( \
hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \
break; \
} \
else \
{ \
TLLM_THROW("Unsupported tile and cluster shape combination"); \
} \
}
SHAPE_CASE(1, 1, 1)
SHAPE_CASE(1, 2, 1)
SHAPE_CASE(2, 1, 1)
SHAPE_CASE(2, 2, 1)
#undef SHAPE_CASE
default: TLLM_THROW("Unsupported config for MoE gemm.");
}
} // namespace tensorrt_llm
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION>
void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
size_t* workspace_size)
{
using namespace cute;
switch (gemm_config.tile_config_sm90)
{
#define SHAPE_CASE(M, N, K) \
case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \
{ \
constexpr int KtileBytes = K / sizeof(T); \
using KTileDim = Int<KtileBytes>; \
using TileShape = Shape<_##M, _##N, KTileDim>; \
dispatchMoeGemmSelectClusterShapeSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape>( \
hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \
break; \
}
SHAPE_CASE(128, 16, 128)
SHAPE_CASE(128, 32, 128)
SHAPE_CASE(128, 64, 128)
SHAPE_CASE(128, 128, 128)
SHAPE_CASE(128, 256, 128)
SHAPE_CASE(256, 128, 128)
#undef SHAPE_CASE
case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break;
case cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic:
TLLM_THROW("GEMM config should have already been set by heuristic.");
break;
default: TLLM_THROW("Unsupported config for MoE gemm."); break;
}
}
template <typename T, typename WeightType, typename OutputType, EpilogueFusion FUSION>
size_t calcMaxWorkspaceSizeSM90(
int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count)
{
size_t count;
// Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat
dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, cutlass_extensions::EpilogueOpDefault, FUSION>(
HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
return count;
}
} // namespace tensorrt_llm
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/mma_sm90.h"
#include "cutlass_extensions/epilogue_helpers.h"
namespace tensorrt_llm::kernels::cutlass_kernels
{
// Hopper arch
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidHopperMOESpecialisation()
{
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
return cutlass::platform::is_same<T, WeightType>::value
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
#else
return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled
#endif
}
// Hopper arch
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidAmpereMOESpecialisation()
{
return true; // Default to true
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment