#pragma once #include "common.h" #include "Tensor.h" #include template inline auto dispatchFloat(Tensor::ScalarType scalarType, F &&func) { switch (scalarType) { case Tensor::BF16: return func.template operator()<__nv_bfloat16>(); case Tensor::FP16: return func.template operator()(); case Tensor::FP32: return func.template operator()(); default: assert(false); throw std::invalid_argument("scalarType is not a floating type"); } } template inline auto dispatchFloat16(Tensor::ScalarType scalarType, F &&func) { switch (scalarType) { case Tensor::BF16: return func.template operator()<__nv_bfloat16>(); case Tensor::FP16: return func.template operator()(); default: assert(false); throw std::invalid_argument("scalarType is not a float16 type"); } } template inline auto dispatch(Tensor::ScalarType scalarType, F &&func) { switch (scalarType) { case Tensor::BF16: return func.template operator()<__nv_bfloat16>(); case Tensor::FP16: return func.template operator()(); case Tensor::FP32: return func.template operator()(); case Tensor::INT8: return func.template operator()(); case Tensor::INT32: return func.template operator()(); case Tensor::INT64: return func.template operator()(); default: throw std::runtime_error("Unsupported scalar type"); } } #pragma nv_diagnostic push // warning #445-D: template parameter "scalar_t" is not used in declaring the parameter types of function template "lambda []()->auto::operator auto (*)()" #pragma nv_diag_suppress 445 template inline bool isTypeMatch(Tensor::ScalarType scalarType) { return dispatch(scalarType, []() { return std::is_same_v; }); } #pragma nv_diagnostic pop template inline auto dispatchVal(int val, std::integer_sequence, F &&func) { auto call = [&]() { if (val == i) { func.template operator()(); } }; (call.template operator()(), ...); } template inline auto dispatchBool(bool val, F &&func) { if (val) { func.template operator()(); } else { func.template operator()(); } } #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) dispatchFloat(TYPE, [&]() { __VA_ARGS__(); });