"docs/vscode:/vscode.git/clone" did not exist on "8db3c9bc9fb3aed9fb8392945ec75e5237a351ba"
Unverified Commit e81d7f11 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

add tensorrt_llm moe_gemm as 3rdparty (#3217)

parent 222ce6f1
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define CUDA_LIB_NAME "cuda"
#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary("nv" name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
#else // For non-Windows platforms
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)
#include "cudaDriverWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
namespace tensorrt_llm::common
{
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
{
static std::mutex mutex;
static std::weak_ptr<CUDADriverWrapper> instance;
std::shared_ptr<CUDADriverWrapper> result = instance.lock();
if (result)
{
return result;
}
std::lock_guard<std::mutex> lock(mutex);
result = instance.lock();
if (!result)
{
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
instance = result;
}
return result;
}
CUDADriverWrapper::CUDADriverWrapper()
: handle(dllOpen(CUDA_LIB_NAME))
{
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
return ret;
};
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
}
CUDADriverWrapper::~CUDADriverWrapper()
{
dllClose(handle);
}
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
{
return (*_cuGetErrorName)(error, pStr);
}
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
{
return (*_cuGetErrorMessage)(error, pStr);
}
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
{
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
}
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
{
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
}
CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
{
return (*_cuModuleUnload)(hmod);
}
CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
{
return (*_cuLinkDestroy)(state);
}
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
{
return (*_cuModuleLoadData)(module, image);
}
CUresult CUDADriverWrapper::cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
{
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
}
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetFunction)(hfunc, hmod, name);
}
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
}
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
{
return (*_cuLaunchCooperativeKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
}
CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
{
return (*_cuLaunchKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
{
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
}
CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
{
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
#include <memory>
#include <mutex>
namespace tensorrt_llm::common
{
class CUDADriverWrapper
{
public:
static std::shared_ptr<CUDADriverWrapper> getInstance();
~CUDADriverWrapper();
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
CUresult cuModuleUnload(CUmodule hmod) const;
CUresult cuLinkDestroy(CUlinkState state) const;
CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
CUresult cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
CUjit_option* options, void** optionValues) const;
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
unsigned int numOptions, CUjit_option* options, void** optionValues) const;
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra) const;
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
private:
void* handle;
CUDADriverWrapper();
CUresult (*_cuGetErrorName)(CUresult, char const**);
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
CUresult (*_cuModuleUnload)(CUmodule);
CUresult (*_cuLinkDestroy)(CUlinkState);
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLinkAddData)(
CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, CUstream, void**);
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra);
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
};
template <typename T>
void checkDriver(
T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
{
if (result)
{
char const* errorName = nullptr;
char const* errorMsg = nullptr;
wrap.cuGetErrorName(result, &errorName);
wrap.cuGetErrorMessage(result, &errorMsg);
throw TllmException(
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
}
}
} // namespace tensorrt_llm::common
/*
* Macros compliant with TensorRT coding conventions
*/
#define TLLM_CU_CHECK(stat) \
do \
{ \
tensorrt_llm::common::checkDriver( \
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
} while (0)
#endif // CUDA_DRIVER_WRAPPER_H
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
typename EpilogueTag>
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
int* kernel_occupancy);
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include <cutlass_extensions/epilogue_helpers.h>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh>
#include <tensorrt_llm/common/cudaUtils.h>
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
typename EpilogueTag>
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
int* kernel_occupancy)
{
constexpr auto activation_type = fused_moe::EpilogueRouting<EpilogueTag>(true);
using GemmType = fused_moe::Fused_Moe_Kernel_sm80<ElementType_, CutlassWeightType_, ElementType_, MaxTileM_, TileN_,
TileK_, Stages_, activation_type>;
// make sure GPU has enough resources..
if (kernel_occupancy != nullptr)
{
constexpr int smem_size = GemmType::kSmemSize;
if (smem_size > (48 << 10))
{
cudaFuncAttributes attr{};
int device = 0;
int max_smem_per_block = 0;
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global<GemmType>));
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize,
// smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the
// heuristic to ignore this configuration.
*kernel_occupancy = 0;
return;
}
}
int max_active_blocks = -1;
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, fused_moe::run_global<GemmType>, GemmType::kThreadCount, smem_size));
*kernel_occupancy = max_active_blocks;
return;
}
int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks<GemmType>());
int const threadblock_count = multi_processor_count * occupancy;
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
using Arguments = typename GemmType::Arguments;
Arguments args{{const_cast<ElementType_*>(A), const_cast<CutlassWeightType_*>(B), const_cast<ElementType_*>(biases),
reinterpret_cast<ElementType_*>(C), total_tokens_including_expert, static_cast<int>(gemm_n),
static_cast<int>(gemm_k), num_experts, bias_is_broadcast},
num_experts, threadblock_count};
auto params = GemmType::to_underlying_arguments(args);
if (GemmType::kSmemSize >= (48 << 10))
{
cudaError_t result = cudaFuncSetAttribute(
fused_moe::run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize);
TLLM_CHECK_WITH_INFO(result == cudaSuccess,
"Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel");
}
dim3 grid(params.threadblock_count, 1, 1);
dim3 block(GemmType::kThreadCount);
fused_moe::run_global<GemmType><<<grid, block, GemmType::kSmemSize, stream>>>(params);
auto result = cudaGetLastError();
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result));
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
// Keep in sync with the signature generated by generate_kernels.py
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag,
HopperGroupedGemmInput::EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
// Hopper helper class for defining all the cutlass helper types
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, typename TileShape,
typename ClusterShape, bool BIAS, EpilogueFusion FUSION>
struct HopperGroupedGemmInfo
{
using Arch = cutlass::arch::Sm90;
// TODO Update once mixed input support is added
static_assert(cutlass::platform::is_same<T, WeightType>::value,
"CUTLASS does not currently have specialised SM90 support for quantized operations");
#ifdef ENABLE_FP8
constexpr static bool IsFP8
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
#else
constexpr static bool IsFP8 = false;
#endif
#ifdef ENABLE_BF16
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|| cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for bfloat16, half, float, fp8");
#else
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for half, float, fp8");
#endif
static_assert(cutlass::platform::is_same<T, WeightType>::value
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e4m3_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e5m2_t>::value,
"Unexpected quantization type");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter<WeightType>::type;
// For legacy reasons we convert unsigned 8-bit to signed
using CutlassWeightTypeMaybeUint8
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint4, cutlass::uint4b_t>, cutlass::int4b_t,
CutlassWeightTypeMaybeUint4>;
using CutlassWeightType
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint8, uint8_t>, int8_t, CutlassWeightTypeMaybeUint8>;
using ElementA = ElementType;
using ElementB = CutlassWeightType;
using ElementD = typename TllmToCutlassTypeAdapter<HopperGroupedGemmInput::OutputTypeAdaptor_t<OutputType>>::type;
using ElementFinalOutput = typename TllmToCutlassTypeAdapter<OutputType>::type;
// using ElementC = std::conditional_t<BIAS, ElementType, void>;
// using ElementCNoVoid = std::conditional_t<BIAS, ElementType, ElementD>;
using ElementC = void;
using ElementCNoVoid = ElementD;
using ElementAccumulator = float;
using ElementBias = ElementFinalOutput;
using ElementRouterScales = float;
// A matrix configuration - this is transposed and swapped with B
using LayoutA = HopperGroupedGemmInput::LayoutA;
constexpr static int AlignmentA
= 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units
// of elements (up to 16 bytes)
// B matrix configuration - this is transposed and swapped with A
using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand
constexpr static int AlignmentB
= 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units
// of elements (up to 16 bytes)
// C matrix configuration
using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand
using StrideC = HopperGroupedGemmInput::StrideC;
// Note we use ElementType here deliberately, so we don't break when BIAS is disabled
constexpr static int AlignmentC
= 128 / cutlass::sizeof_bits<ElementType>::value; // Memory access granularity/alignment of C matrix in units
// of elements (up to 16 bytes)
// D matrix configuration
using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD;
using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD;
constexpr static int AlignmentD
= 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix
// in units of elements (up to 16 bytes)
static_assert(cutlass::platform::is_same<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpDefault>::value,
"Hopper Grouped GEMM specialisation doesn't support fused activation");
using EpilogueOp
= cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
// TODO Add mode for fused activation once CUTLASS adds support
// using EpilogueSchedule = cutlass::platform::conditional_t<
// cutlass::platform::is_same<EpilogueOp, EpilogueOpDefault>::value,
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized,
// cutlass::epilogue::?????????????????? /// <<<<<< what supports activations
// >;
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;
// Epilogue For Default Finalize
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
TileShape, ClusterShape, //
cutlass::epilogue::collective::EpilogueTileAuto, //
ElementAccumulator, ElementAccumulator, //
ElementC, LayoutC*, AlignmentC, //
ElementD, LayoutD*, AlignmentD, //
EpilogueSchedule>::CollectiveOp;
// Epilogue For Fused Finalize
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< //
TileShape, //
ElementCNoVoid, StrideC*, //
ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, //
ElementAccumulator, //
ElementAccumulator, //
ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, //
ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales //
>::CollectiveOp;
using CollectiveEpilogue
= std::conditional_t<FUSION == EpilogueFusion::FINALIZE, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>;
using KernelSchedule
= std::conditional_t<IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here
ElementType, LayoutA*, AlignmentA, //
ElementAccumulator, //
TileShape, ClusterShape, //
StageCountAutoCarveout, KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<HopperGroupedGemmInput::ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
// Hopper specialised version
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size)
{
#ifdef COMPILE_HOPPER_TMA_GEMMS
using namespace cute;
if constexpr (!should_filter_sm90_gemm_problem_shape_v<TileShape, ClusterShape, T>)
{
using GemmInfo
= HopperGroupedGemmInfo<T, WeightType, OutputType, EpilogueTag, TileShape, ClusterShape, BIAS, FUSION>;
using ElementAccumulator = typename GemmInfo::ElementAccumulator;
using ElementA = typename GemmInfo::ElementA;
using ElementB = typename GemmInfo::ElementB;
using ElementC = typename GemmInfo::ElementC;
using ElementCNoVoid = typename GemmInfo::ElementCNoVoid;
using ElementD = typename GemmInfo::ElementD;
using CollectiveMainloop = typename GemmInfo::CollectiveMainloop;
using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue;
using GemmKernel = typename GemmInfo::GemmKernel;
using GemmGrouped = typename GemmInfo::GemmGrouped;
if (kernel_occupancy != nullptr)
{
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
return;
}
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = multi_processor_count;
GemmGrouped gemm;
if (workspace_size != nullptr)
{
// Make a mock problem shape with just the minimal information actually required to get the workspace size
// This makes some assumptions about CUTLASS's implementation which is suboptimal. We have a check later to
// catch future cutlass updates causing silent breakages, but that is not fool proof.
// The alternative is to wait until we have data and then dynamically allocate the workspace
typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr};
typename GemmGrouped::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info};
*workspace_size = gemm.get_workspace_size(args);
return;
}
using MainloopArguments = typename CollectiveMainloop::Arguments;
TLLM_CHECK(hopper_input.stride_a);
TLLM_CHECK(hopper_input.stride_b);
TLLM_CHECK(hopper_input.ptr_a);
TLLM_CHECK(hopper_input.ptr_b);
MainloopArguments const mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};
typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)};
epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
// TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
auto make_epi_args = [&]()
{
if constexpr (FUSION == EpilogueFusion::NONE)
{
auto epi_params = hopper_input.default_epilogue;
return EpilogueArguments{epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c),
hopper_input.stride_c, reinterpret_cast<ElementD**>(epi_params.ptr_d), epi_params.stride_d};
}
else if constexpr (FUSION == EpilogueFusion::FINALIZE)
{
// Parameters for fused finalize
auto epi_params = hopper_input.fused_finalize_epilogue;
return EpilogueArguments{
epilogue_scalars, // Parameters to underlying epilogue
reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c, // C params
reinterpret_cast<typename GemmInfo::ElementFinalOutput*>(epi_params.ptr_final_output),
epi_params.stride_final_output, // D (output) params
reinterpret_cast<typename GemmInfo::ElementBias const*>(epi_params.ptr_bias),
epi_params.stride_bias, // Bias params
epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales
epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales
epi_params.ptr_source_token_index, // Index of the source token to sum into
epi_params.num_rows_in_final_output // Number of tokens in the output buffer
};
}
else
{
static_assert(
sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher");
}
};
EpilogueArguments const epilogue_params = make_epi_args();
typename GemmKernel::TileScheduler::Arguments scheduler_args{
1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN};
typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
mainloop_params, epilogue_params, hw_info, scheduler_args};
size_t calculated_ws_size = gemm.get_workspace_size(args);
TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,
"Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size);
auto can_implement = gemm.can_implement(args);
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
"Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
auto init_status = gemm.initialize(args, hopper_input.gemm_workspace);
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
"Failed to initialize cutlass SM90 grouped gemm. Error: "
+ std::string(cutlassGetStatusString(init_status)));
auto run_status = gemm.run(stream);
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
"Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
sync_check_cuda_error();
}
else
{
TLLM_THROW("Configuration was disabled by FAST_BUILD");
}
#else // COMPILE_HOPPER_TMA_GEMMS
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
#endif // COMPILE_HOPPER_TMA_GEMMS
}
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/conv/convolution.h"
// Order matters here, packed_stride.hpp is missing cute and convolution includes
#include "cutlass/util/packed_stride.hpp"
#include "tensorrt_llm/common/logger.h"
namespace tensorrt_llm
{
std::array<size_t, 10> HopperGroupedGemmInput::workspaceBuffers(int num_experts)
{
size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts;
size_t stride_a_size = sizeof(StrideA) * num_experts;
size_t stride_b_size = sizeof(StrideB) * num_experts;
size_t stride_c_size = sizeof(StrideC) * num_experts;
size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts;
size_t ptr_buf_size = sizeof(void*) * num_experts;
size_t scale_buf_size = sizeof(float*) * num_experts;
return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size,
ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size};
}
size_t HopperGroupedGemmInput::workspaceSize(int num_experts)
{
auto buffers = workspaceBuffers(num_experts);
return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size());
}
void HopperGroupedGemmInput::configureWorkspace(
int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size)
{
auto buffers = workspaceBuffers(num_experts);
std::array<int8_t*, 10> pointers{};
TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers");
for (int i = 0; i < buffers.size(); i++)
{
pointers[i] = start_ptr;
start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]);
}
shape_info.num_groups = num_experts;
shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(pointers[0]);
shape_info.host_problem_shapes = nullptr;
stride_a = reinterpret_cast<StrideA*>(pointers[1]);
stride_b = reinterpret_cast<StrideB*>(pointers[2]);
stride_c = reinterpret_cast<StrideC*>(pointers[3]);
default_epilogue.stride_d = reinterpret_cast<DefaultEpilogue::StrideD*>(pointers[4]);
ptr_a = reinterpret_cast<void const**>(pointers[5]);
ptr_b = reinterpret_cast<void const**>(pointers[6]);
ptr_c = reinterpret_cast<void const**>(pointers[7]);
default_epilogue.ptr_d = reinterpret_cast<void**>(pointers[8]);
alpha_scale_ptr_array = reinterpret_cast<float const**>(pointers[9]);
this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace);
this->gemm_workspace_size = gemm_workspace_size;
}
void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens)
{
fused_finalize_epilogue.ptr_final_output = final_output;
fused_finalize_epilogue.ptr_router_scales = router_scales;
fused_finalize_epilogue.ptr_bias = bias;
fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset;
fused_finalize_epilogue.ptr_source_token_index = source_token_index;
fused_finalize_epilogue.stride_final_output
= cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{},
transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1)));
fused_finalize_epilogue.stride_bias
= transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size));
fused_finalize_epilogue.stride_router_scales = {};
fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens;
}
std::string HopperGroupedGemmInput::toString() const
{
std::stringstream ss;
ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n";
if (isValid())
{
ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n";
ss << "Epilogue Fusion: " << (int) fusion;
if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE)
{
ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output;
ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias;
ss << " with Stride: " << fused_finalize_epilogue.stride_bias;
ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales;
ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset;
ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index;
}
else
{
ss << ", Ptr D: " << default_epilogue.ptr_d;
}
ss << '\n';
ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n";
}
return ss.str();
}
} // namespace tensorrt_llm
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h"
#include <array>
#include <cuda_runtime_api.h>
#include <optional>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/layout/layout.h"
namespace tensorrt_llm
{
template <class T>
constexpr auto transpose_stride(T const& t)
{
return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
}
struct HopperGroupedGemmInput
{
template <class T>
using TransposeStride = decltype(transpose_stride<T>(T{}));
template <class Tag>
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
static_assert(std::is_same_v<cutlass::layout::RowMajor, TransposeLayoutTag<cutlass::layout::ColumnMajor>>);
static_assert(std::is_same_v<cutlass::layout::ColumnMajor, TransposeLayoutTag<cutlass::layout::RowMajor>>);
// Layout for A and B is transposed and then swapped in the implementation
// This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
using StrideA
= std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutA*>>; // Use B because they will be swapped
using StrideB
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
template <class T>
constexpr static bool IsFP8_v = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
// Currently this should always just be T
template <class T>
using OutputTypeAdaptor_t = std::conditional_t<IsFP8_v<T>, nv_bfloat16, T>;
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int64_t, int64_t, int64_t>>;
ProblemShape shape_info{};
StrideA* stride_a = nullptr;
StrideB* stride_b = nullptr;
void const** ptr_a = nullptr;
void const** ptr_b = nullptr;
// C is currently the same in both epilogues
StrideC* stride_c = nullptr;
void const** ptr_c = nullptr;
struct DefaultEpilogue
{
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
StrideD* stride_d = nullptr;
void** ptr_d = nullptr;
};
struct FusedFinalizeEpilogue
{
using StrideFinalOutput = DefaultEpilogue::StrideD;
using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;
void* ptr_final_output = nullptr;
StrideFinalOutput stride_final_output{};
void const* ptr_bias = nullptr;
StrideBias stride_bias{};
float const* ptr_router_scales = nullptr;
StrideRouterScales stride_router_scales{};
int64_t const* ptr_expert_first_token_offset = nullptr;
int const* ptr_source_token_index = nullptr;
size_t num_rows_in_final_output = 0;
};
DefaultEpilogue default_epilogue;
FusedFinalizeEpilogue fused_finalize_epilogue;
enum class EpilogueFusion
{
NONE,
ACTIVATION,
GATED_ACTIVATION,
FINALIZE
};
EpilogueFusion fusion = EpilogueFusion::NONE;
float const** alpha_scale_ptr_array = nullptr;
uint8_t* gemm_workspace = nullptr;
size_t gemm_workspace_size = 0;
static std::array<size_t, 10> workspaceBuffers(int num_experts);
static size_t workspaceSize(int num_experts);
void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size);
bool isValid() const
{
return stride_a != nullptr && ptr_a != nullptr;
}
void setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens);
std::string toString() const;
};
// Note update moe.py to match
enum class ActivationType
{
Gelu = 0,
Relu,
Silu,
Swiglu,
Geglu,
Identity,
InvalidType
};
constexpr bool isGatedActivation(ActivationType activation_type)
{
return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu;
}
template <typename T, /*The type used for activations/scales/compute*/
typename WeightType, /* The type for the MoE weights */
typename OutputType, /* The output type for the GEMM */
typename ScaleBiasType = OutputType /* The type for the scales/bias */
>
class MoeGemmRunner
{
public:
MoeGemmRunner();
#if defined(ENABLE_FP8)
static constexpr bool use_fp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
#else
static constexpr bool use_fp8 = false;
#endif
void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig chosen_conf);
void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C,
int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows,
int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf);
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm);
static std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs(int sm);
static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);
[[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const;
[[nodiscard]] bool supportsHopperSpecialisation() const;
[[nodiscard]] bool isFusedGatedActivation(
cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const;
[[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const;
size_t getMaxWorkspaceSize(int num_experts) const;
[[nodiscard]] int getSM() const;
private:
template <typename EpilogueTag>
void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array,
cudaStream_t stream, int* occupancy = nullptr);
template <typename EpilogueTag>
void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases,
bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig chosen_conf);
private:
int sm_{};
int multi_processor_count_{};
mutable int num_experts_ = 0;
mutable size_t gemm_workspace_size_ = 0;
size_t calcMaxWorkspaceSize(int num_experts) const;
};
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<half, half, half>;
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<half, cutlass::uint4b_t, half>;
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<half, uint8_t, half>;
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
template class MoeGemmRunner<float, float, float>;
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_FP8
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>;
#ifdef ENABLE_BF16
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
#endif
// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>;
#endif
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // __GNUC__
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // __GNUC__
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename TileShape, typename ClusterShape>
void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
cudaStream_t stream, int* occupancy, size_t* workspace_size)
{
static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>(),
"Invalid hopper configuration invoked, fallback to Sm80");
TLLM_CHECK_WITH_INFO(
workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information");
// auto func = hopper_input.ptr_c ?
// kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T, WeightType,
// cutlass::arch::Sm90, EpilogueTag, true>
// :
// kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T,
// WeightType,
// cutlass::arch::Sm90, EpilogueTag, false>;
// TODO(dastokes) Re-enable bias when CUTLASS supports it
auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher<T, WeightType, OutputType, EpilogueTag,
FUSION, TileShape, ClusterShape, false>;
func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
}
/*
1x1x1 cluster shape is are supported for any tile shape.
2x1x1 cluster shape is only supported for when the M tile is at least 128.
1x2x1 cluster shape is only supported when the N tile is at least 128.
2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
We make the above restrictions are to improve compilation speed in TRT-LLM by pruning kernels
that may not be very useful in practice.
*/
template <typename CTAShape, typename ClusterShape>
constexpr bool are_tile_shapes_supported()
{
using namespace cute;
[[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{});
[[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{});
constexpr int cga_m = get<0>(ClusterShape{});
constexpr int cga_n = get<1>(ClusterShape{});
if constexpr (cga_m == _1{} && cga_n == _1{})
{
return true;
}
else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{})
{
return true;
}
else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{})
{
return true;
}
else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{})
{
return true;
}
else
{
return false;
}
}
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename TileShape>
void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
size_t* workspace_size)
{
using namespace cute;
switch (gemm_config.cluster_shape)
{
#define SHAPE_CASE(M, N, K) \
case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \
{ \
using ClusterShape = Shape<_##M, _##N, _##K>; \
if constexpr (are_tile_shapes_supported<TileShape, ClusterShape>()) \
{ \
dispatchMoeGemmSelectBiasSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape>( \
hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \
break; \
} \
else \
{ \
TLLM_THROW("Unsupported tile and cluster shape combination"); \
} \
}
SHAPE_CASE(1, 1, 1)
SHAPE_CASE(1, 2, 1)
SHAPE_CASE(2, 1, 1)
SHAPE_CASE(2, 2, 1)
#undef SHAPE_CASE
default: TLLM_THROW("Unsupported config for MoE gemm.");
}
} // namespace tensorrt_llm
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION>
void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
size_t* workspace_size)
{
using namespace cute;
switch (gemm_config.tile_config_sm90)
{
#define SHAPE_CASE(M, N, K) \
case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \
{ \
constexpr int KtileBytes = K / sizeof(T); \
using KTileDim = Int<KtileBytes>; \
using TileShape = Shape<_##M, _##N, KTileDim>; \
dispatchMoeGemmSelectClusterShapeSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape>( \
hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \
break; \
}
SHAPE_CASE(128, 16, 128)
SHAPE_CASE(128, 32, 128)
SHAPE_CASE(128, 64, 128)
SHAPE_CASE(128, 128, 128)
SHAPE_CASE(128, 256, 128)
SHAPE_CASE(256, 128, 128)
#undef SHAPE_CASE
case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break;
case cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic:
TLLM_THROW("GEMM config should have already been set by heuristic.");
break;
default: TLLM_THROW("Unsupported config for MoE gemm."); break;
}
}
template <typename T, typename WeightType, typename OutputType, EpilogueFusion FUSION>
size_t calcMaxWorkspaceSizeSM90(
int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count)
{
size_t count;
// Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat
dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, cutlass_extensions::EpilogueOpDefault, FUSION>(
HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
return count;
}
} // namespace tensorrt_llm
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/mma_sm90.h"
#include "cutlass_extensions/epilogue_helpers.h"
namespace tensorrt_llm::kernels::cutlass_kernels
{
// Hopper arch
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidHopperMOESpecialisation()
{
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
return cutlass::platform::is_same<T, WeightType>::value
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
#else
return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled
#endif
}
// Hopper arch
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidAmpereMOESpecialisation()
{
return true; // Default to true
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment